diff --git a/cmd/serve.go b/cmd/serve.go index a9d5a78..c63a4ff 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -19,8 +19,13 @@ var serveCmd = &cobra.Command{ log.Fatal("failed to get server config:", err) } + // Create Server + srv, err := server.NewServer(cfg) + if err != nil { + log.Fatal("failed to create server:", err) + } + // Start Server - srv := server.NewServer(cfg) if err := srv.Start(); err != nil { log.Fatal("failed to start server:", err) } diff --git a/config/config.go b/config/config.go index bb08335..58e6abf 100644 --- a/config/config.go +++ b/config/config.go @@ -1,7 +1,9 @@ package config import ( + "errors" "fmt" + "net/url" "os" "reflect" "strings" @@ -17,13 +19,19 @@ type ConfigDef struct { } type BaseConfig struct { - ServerAddress string `json:"address" description:"Conduit server address" default:"http://localhost:8080"` + ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"` APIKey string `json:"api_key" description:"API Key for the conduit API"` } func (c *BaseConfig) Validate() error { if c.APIKey == "" { - return fmt.Errorf("api_key is required") + return errors.New("api_key is required") + } + if c.ServerAddress == "" { + return errors.New("server is required") + } + if _, err := url.Parse(c.ServerAddress); err != nil { + return fmt.Errorf("server is invalid: %w", err) } return nil } @@ -78,7 +86,7 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) { cfg := &ClientConfig{ BaseConfig: BaseConfig{ - ServerAddress: cfgValues["address"], + ServerAddress: cfgValues["server"], APIKey: cfgValues["api_key"], }, TunnelName: cfgValues["name"], diff --git a/server/server.go b/server/server.go index 7d7a9cb..aa4dd24 100644 --- a/server/server.go +++ b/server/server.go @@ -3,10 +3,12 @@ package server import ( "bufio" "bytes" + "errors" "fmt" "io" "net" "net/http" + "net/url" "strings" "sync" "time" @@ -24,22 +26,32 @@ type TunnelConnection struct { } type Server struct { - tunnels map[string]*TunnelConnection + host string + cfg *config.ServerConfig + mu sync.RWMutex + upgrader websocket.Upgrader - cfg *config.ServerConfig - mu sync.RWMutex + tunnels map[string]*TunnelConnection } -func NewServer(cfg *config.ServerConfig) *Server { +func NewServer(cfg *config.ServerConfig) (*Server, error) { + serverURL, err := url.Parse(cfg.ServerAddress) + if err != nil { + return nil, fmt.Errorf("failed to parse server address: %v", err) + } else if serverURL.Host == "" { + return nil, errors.New("invalid server address") + } + return &Server{ cfg: cfg, + host: serverURL.Host, tunnels: make(map[string]*TunnelConnection), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, }, - } + }, nil } func (s *Server) Start() error { @@ -64,29 +76,6 @@ func (s *Server) Start() error { } } -func (s *Server) extractSubdomain(peakReader io.Reader) string { - // Read Request - req, err := http.ReadRequest(bufio.NewReader(peakReader)) - if err != nil { - return "" - } - defer req.Body.Close() - - // Extract Host - host := req.Host - if idx := strings.Index(host, ":"); idx != -1 { - host = host[:idx] - } - - // Extract Subdomain - parts := strings.Split(host, ".") - if len(parts) > 1 { - return parts[0] - } - - return "" -} - func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) { s.mu.RLock() count := len(s.tunnels) @@ -156,48 +145,61 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne } } -func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) { - peek := make([]byte, 8192) - n, err := conn.Read(peek) - if err != nil { - return nil, nil, err - } - - peekedData := peek[:n] - combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn) - return bytes.NewReader(peekedData), combinedReader, nil -} - func (s *Server) handleRawConnection(conn net.Conn) { defer conn.Close() - // Detect Tunnel - peakReader, allReader, _ := s.peekData(conn) - if subdomain := s.extractSubdomain(peakReader); subdomain != "" { - s.mu.RLock() - tunnelConn, exists := s.tunnels[subdomain] - s.mu.RUnlock() + // Capture Consumed Data - When determining where to route the request, we + // have to read the host headers. This requires reading from the buffer, so + // if we later decide to tunnel the TCP connection we need to reconstruct the + // data from the buffer. + var capturedData bytes.Buffer + teeReader := io.TeeReader(conn, &capturedData) + bufReader := bufio.NewReader(teeReader) - if exists { - log.Infof("relaying %s to tunnel", subdomain) - s.proxyRawConnection(conn, tunnelConn, allReader) - } - return - } - - // Control Endpoints - s.handleAsHTTP(conn, allReader) -} - -func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) { // Create HTTP Request & Writer w := &connResponseWriter{conn: conn} - r, err := http.ReadRequest(bufio.NewReader(allReader)) + r, err := http.ReadRequest(bufReader) if err != nil { w.WriteHeader(http.StatusBadRequest) return } + defer r.Body.Close() + // Validate Host + if !strings.Contains(r.Host, s.host) { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, "unknown host: %s", r.Host) + return + } + + // Extract Subdomain + subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") + if strings.Count(subdomain, ".") != 0 { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host) + return + } + + // Handle Control Endpoints + if subdomain == "" { + s.handleAsHTTP(w, r) + return + } + + // Handle Tunnels + s.mu.RLock() + tunnelConn, exists := s.tunnels[subdomain] + s.mu.RUnlock() + if exists { + log.Infof("relaying %s to tunnel", subdomain) + + // Reconstruct Data & Proxy Connection + allReader := io.MultiReader(&capturedData, r.Body) + s.proxyRawConnection(conn, tunnelConn, allReader) + } +} + +func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { // Authorize Control Endpoints apiKey := r.URL.Query().Get("apiKey") if apiKey != s.cfg.APIKey {