clean
This commit is contained in:
parent
2fba07f4b3
commit
08e1191ba3
@ -19,8 +19,13 @@ var serveCmd = &cobra.Command{
|
|||||||
log.Fatal("failed to get server config:", err)
|
log.Fatal("failed to get server config:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create Server
|
||||||
|
srv, err := server.NewServer(cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("failed to create server:", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Start Server
|
// Start Server
|
||||||
srv := server.NewServer(cfg)
|
|
||||||
if err := srv.Start(); err != nil {
|
if err := srv.Start(); err != nil {
|
||||||
log.Fatal("failed to start server:", err)
|
log.Fatal("failed to start server:", err)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -17,13 +19,19 @@ type ConfigDef struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type BaseConfig struct {
|
type BaseConfig struct {
|
||||||
ServerAddress string `json:"address" description:"Conduit server address" default:"http://localhost:8080"`
|
ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"`
|
||||||
APIKey string `json:"api_key" description:"API Key for the conduit API"`
|
APIKey string `json:"api_key" description:"API Key for the conduit API"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BaseConfig) Validate() error {
|
func (c *BaseConfig) Validate() error {
|
||||||
if c.APIKey == "" {
|
if c.APIKey == "" {
|
||||||
return fmt.Errorf("api_key is required")
|
return errors.New("api_key is required")
|
||||||
|
}
|
||||||
|
if c.ServerAddress == "" {
|
||||||
|
return errors.New("server is required")
|
||||||
|
}
|
||||||
|
if _, err := url.Parse(c.ServerAddress); err != nil {
|
||||||
|
return fmt.Errorf("server is invalid: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -78,7 +86,7 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
|
|||||||
|
|
||||||
cfg := &ClientConfig{
|
cfg := &ClientConfig{
|
||||||
BaseConfig: BaseConfig{
|
BaseConfig: BaseConfig{
|
||||||
ServerAddress: cfgValues["address"],
|
ServerAddress: cfgValues["server"],
|
||||||
APIKey: cfgValues["api_key"],
|
APIKey: cfgValues["api_key"],
|
||||||
},
|
},
|
||||||
TunnelName: cfgValues["name"],
|
TunnelName: cfgValues["name"],
|
||||||
|
118
server/server.go
118
server/server.go
@ -3,10 +3,12 @@ package server
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -24,22 +26,32 @@ type TunnelConnection struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
tunnels map[string]*TunnelConnection
|
host string
|
||||||
upgrader websocket.Upgrader
|
|
||||||
cfg *config.ServerConfig
|
cfg *config.ServerConfig
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
upgrader websocket.Upgrader
|
||||||
|
tunnels map[string]*TunnelConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
||||||
|
serverURL, err := url.Parse(cfg.ServerAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse server address: %v", err)
|
||||||
|
} else if serverURL.Host == "" {
|
||||||
|
return nil, errors.New("invalid server address")
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(cfg *config.ServerConfig) *Server {
|
|
||||||
return &Server{
|
return &Server{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
host: serverURL.Host,
|
||||||
tunnels: make(map[string]*TunnelConnection),
|
tunnels: make(map[string]*TunnelConnection),
|
||||||
upgrader: websocket.Upgrader{
|
upgrader: websocket.Upgrader{
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
@ -64,29 +76,6 @@ func (s *Server) Start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) extractSubdomain(peakReader io.Reader) string {
|
|
||||||
// Read Request
|
|
||||||
req, err := http.ReadRequest(bufio.NewReader(peakReader))
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
defer req.Body.Close()
|
|
||||||
|
|
||||||
// Extract Host
|
|
||||||
host := req.Host
|
|
||||||
if idx := strings.Index(host, ":"); idx != -1 {
|
|
||||||
host = host[:idx]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract Subdomain
|
|
||||||
parts := strings.Split(host, ".")
|
|
||||||
if len(parts) > 1 {
|
|
||||||
return parts[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
|
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
count := len(s.tunnels)
|
count := len(s.tunnels)
|
||||||
@ -156,48 +145,61 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) {
|
|
||||||
peek := make([]byte, 8192)
|
|
||||||
n, err := conn.Read(peek)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peekedData := peek[:n]
|
|
||||||
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
|
|
||||||
return bytes.NewReader(peekedData), combinedReader, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleRawConnection(conn net.Conn) {
|
func (s *Server) handleRawConnection(conn net.Conn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
// Detect Tunnel
|
// Capture Consumed Data - When determining where to route the request, we
|
||||||
peakReader, allReader, _ := s.peekData(conn)
|
// have to read the host headers. This requires reading from the buffer, so
|
||||||
if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
|
// if we later decide to tunnel the TCP connection we need to reconstruct the
|
||||||
s.mu.RLock()
|
// data from the buffer.
|
||||||
tunnelConn, exists := s.tunnels[subdomain]
|
var capturedData bytes.Buffer
|
||||||
s.mu.RUnlock()
|
teeReader := io.TeeReader(conn, &capturedData)
|
||||||
|
bufReader := bufio.NewReader(teeReader)
|
||||||
|
|
||||||
if exists {
|
|
||||||
log.Infof("relaying %s to tunnel", subdomain)
|
|
||||||
s.proxyRawConnection(conn, tunnelConn, allReader)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Control Endpoints
|
|
||||||
s.handleAsHTTP(conn, allReader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
|
|
||||||
// Create HTTP Request & Writer
|
// Create HTTP Request & Writer
|
||||||
w := &connResponseWriter{conn: conn}
|
w := &connResponseWriter{conn: conn}
|
||||||
r, err := http.ReadRequest(bufio.NewReader(allReader))
|
r, err := http.ReadRequest(bufReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
|
||||||
|
// Validate Host
|
||||||
|
if !strings.Contains(r.Host, s.host) {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = fmt.Fprintf(w, "unknown host: %s", r.Host)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract Subdomain
|
||||||
|
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
||||||
|
if strings.Count(subdomain, ".") != 0 {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Control Endpoints
|
||||||
|
if subdomain == "" {
|
||||||
|
s.handleAsHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Tunnels
|
||||||
|
s.mu.RLock()
|
||||||
|
tunnelConn, exists := s.tunnels[subdomain]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
if exists {
|
||||||
|
log.Infof("relaying %s to tunnel", subdomain)
|
||||||
|
|
||||||
|
// Reconstruct Data & Proxy Connection
|
||||||
|
allReader := io.MultiReader(&capturedData, r.Body)
|
||||||
|
s.proxyRawConnection(conn, tunnelConn, allReader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// Authorize Control Endpoints
|
// Authorize Control Endpoints
|
||||||
apiKey := r.URL.Query().Get("apiKey")
|
apiKey := r.URL.Query().Get("apiKey")
|
||||||
if apiKey != s.cfg.APIKey {
|
if apiKey != s.cfg.APIKey {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user