diff --git a/e2e_test.go b/e2e_test.go index 8a948e6..9f705cc 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -505,3 +505,360 @@ func TestServerGracefulShutdown(t *testing.T) { t.Error("expected server port to be closed after shutdown") } } + +// ---------- HTTP Response Quality Tests ---------- + +func TestHTTPResponseHasProperHeaders(t *testing.T) { + apiKey := "test-key-headers" + + // Start Target HTTP Server + targetAddr, stopTarget := startHTTPTarget(t) + defer stopTarget() + + // Start Conduit Server + serverAddr, stopServer := startConduitServer(t, apiKey) + defer stopServer() + + // Connect Tunnel + stopTunnel := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", targetAddr), "hdr-test", apiKey) + defer stopTunnel() + + // Send Request Through Tunnel + resp := sendHTTPViaTunnel(t, serverAddr, "hdr-test", "GET", "/", "") + defer resp.Body.Close() + + // Verify Proper HTTP Semantics + if resp.Proto != "HTTP/1.1" { + t.Errorf("expected HTTP/1.1, got %s", resp.Proto) + } + if resp.Header.Get("X-Test-Header") != "present" { + t.Errorf("expected X-Test-Header: present, got %q", resp.Header.Get("X-Test-Header")) + } + if resp.ContentLength <= 0 && resp.TransferEncoding == nil { + t.Errorf("expected Content-Length or Transfer-Encoding, got neither") + } +} + +func TestHTTPControlEndpointResponseQuality(t *testing.T) { + apiKey := "test-key-ctrl-quality" + + // Start Conduit Server + serverAddr, stopServer := startConduitServer(t, apiKey) + defer stopServer() + + // 404 on Unknown Tunnel — Verify stdlib response format + resp := sendHTTPViaTunnel(t, serverAddr, "nope", "GET", "/", "") + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("expected 404, got %d", resp.StatusCode) + } + + // Content-Type Should Be Set by http.Error + ct := resp.Header.Get("Content-Type") + if !strings.Contains(ct, "text/plain") { + t.Errorf("expected text/plain Content-Type, got %q", ct) + } + + // Content-Length Should Be Present + if resp.ContentLength <= 0 { + t.Errorf("expected positive Content-Length, got %d", resp.ContentLength) + } + + // Info Endpoint — JSON response quality + url := fmt.Sprintf("http://%s/_conduit/info?apiKey=%s", serverAddr, apiKey) + req, _ := http.NewRequest("GET", url, nil) + req.Host = serverAddr + + resp2, err := (&http.Client{Timeout: 5 * time.Second}).Do(req) + if err != nil { + t.Fatalf("info request failed: %v", err) + } + defer resp2.Body.Close() + + if resp2.Header.Get("Content-Type") != "application/json" { + t.Errorf("expected application/json, got %q", resp2.Header.Get("Content-Type")) + } +} + +// ---------- Large Body Tests ---------- + +func TestHTTPLargeResponseBody(t *testing.T) { + apiKey := "test-key-large-resp" + + // Start Target That Returns a Large Body + largeBody := strings.Repeat("A", 1024*1024) // 1 MB + port := getFreePort(t) + addr := fmt.Sprintf("127.0.0.1:%d", port) + mux := http.NewServeMux() + mux.HandleFunc("/large", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(largeBody)) + }) + srv := &http.Server{Addr: addr, Handler: mux} + go srv.ListenAndServe() + defer srv.Close() + waitForPort(t, addr, 3*time.Second) + + // Start Conduit Server + serverAddr, stopServer := startConduitServer(t, apiKey) + defer stopServer() + + // Connect Tunnel + stopTunnel := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", addr), "large-resp", apiKey) + defer stopTunnel() + + // Request Large Response + resp := sendHTTPViaTunnel(t, serverAddr, "large-resp", "GET", "/large", "") + body := readBody(t, resp) + + if len(body) != len(largeBody) { + t.Errorf("expected %d bytes, got %d", len(largeBody), len(body)) + } +} + +func TestHTTPLargeRequestBody(t *testing.T) { + apiKey := "test-key-large-req" + + // Start Target That Echoes Body Size + port := getFreePort(t) + addr := fmt.Sprintf("127.0.0.1:%d", port) + mux := http.NewServeMux() + mux.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) { + data, _ := io.ReadAll(r.Body) + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "size:%d", len(data)) + }) + srv := &http.Server{Addr: addr, Handler: mux} + go srv.ListenAndServe() + defer srv.Close() + waitForPort(t, addr, 3*time.Second) + + // Start Conduit Server + serverAddr, stopServer := startConduitServer(t, apiKey) + defer stopServer() + + // Connect Tunnel + stopTunnel := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", addr), "large-req", apiKey) + defer stopTunnel() + + // Send Large Request Body (512 KB) + largePayload := strings.Repeat("B", 512*1024) + resp := sendHTTPViaTunnel(t, serverAddr, "large-req", "POST", "/upload", largePayload) + body := readBody(t, resp) + + expected := fmt.Sprintf("size:%d", len(largePayload)) + if body != expected { + t.Errorf("expected %q, got %q", expected, body) + } +} + +// ---------- TCP Tunnel Tests ---------- + +func TestTCPTunnelEcho(t *testing.T) { + apiKey := "test-key-tcp" + + // Start TCP Echo Server + tcpAddr, stopTCP := startTCPEchoTarget(t) + defer stopTCP() + + // Start Conduit Server + serverAddr, stopServer := startConduitServer(t, apiKey) + defer stopServer() + + // Connect TCP Tunnel (bare host:port — the realistic way users would specify it) + stopTunnel := connectTunnel(t, serverAddr, tcpAddr, "tcp-test", apiKey) + defer stopTunnel() + + // Send Raw HTTP Request Through Tunnel to TCP Echo + conn, err := net.DialTimeout("tcp", serverAddr, 5*time.Second) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + // Write a Raw HTTP Request (the TCP echo will bounce it back) + reqLine := fmt.Sprintf("GET / HTTP/1.1\r\nHost: tcp-test.%s\r\n\r\n", serverAddr) + _, err = conn.Write([]byte(reqLine)) + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + // Read Echoed Data + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + buf := make([]byte, 4096) + n, err := conn.Read(buf) + if err != nil { + t.Fatalf("failed to read echo: %v", err) + } + + response := string(buf[:n]) + if !strings.Contains(response, "GET / HTTP/1.1") { + t.Errorf("expected echoed request, got: %q", response) + } +} + +func TestTCPTunnelLargePayload(t *testing.T) { + apiKey := "test-key-tcp-large" + + // Start TCP Echo Server + tcpAddr, stopTCP := startTCPEchoTarget(t) + defer stopTCP() + + // Start Conduit Server + serverAddr, stopServer := startConduitServer(t, apiKey) + defer stopServer() + + // Connect TCP Tunnel (bare host:port) + stopTunnel := connectTunnel(t, serverAddr, tcpAddr, "tcp-large", apiKey) + defer stopTunnel() + + // Connect and Send Large Payload + conn, err := net.DialTimeout("tcp", serverAddr, 5*time.Second) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + // Route via Host Header Then Send 64 KB Payload + header := fmt.Sprintf("POST /data HTTP/1.1\r\nHost: tcp-large.%s\r\nContent-Length: 65536\r\n\r\n", serverAddr) + payload := header + strings.Repeat("X", 64*1024) + + _, err = conn.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write: %v", err) + } + + // Read All Echoed Data + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + var received []byte + buf := make([]byte, 8192) + for len(received) < len(payload) { + n, err := conn.Read(buf) + if err != nil { + break + } + received = append(received, buf[:n]...) + } + + if len(received) != len(payload) { + t.Errorf("expected %d bytes echoed, got %d", len(payload), len(received)) + } +} + +// ---------- Concurrency Tests ---------- + +func TestConcurrentHTTPRequests(t *testing.T) { + apiKey := "test-key-concurrent" + + // Start Target HTTP Server + targetAddr, stopTarget := startHTTPTarget(t) + defer stopTarget() + + // Start Conduit Server + serverAddr, stopServer := startConduitServer(t, apiKey) + defer stopServer() + + // Connect Tunnel + stopTunnel := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", targetAddr), "conc-test", apiKey) + defer stopTunnel() + + // Fire 20 Concurrent Requests + const numRequests = 20 + var wg sync.WaitGroup + errors := make(chan string, numRequests) + + for i := range numRequests { + wg.Add(1) + go func(idx int) { + defer wg.Done() + path := fmt.Sprintf("/item/%d", idx) + resp := sendHTTPViaTunnel(t, serverAddr, "conc-test", "GET", path, "") + body := readBody(t, resp) + + expected := fmt.Sprintf("echo: GET %s", path) + if !strings.Contains(body, expected) { + errors <- fmt.Sprintf("request %d: expected %q, got %q", idx, expected, body) + } + if resp.StatusCode != http.StatusOK { + errors <- fmt.Sprintf("request %d: expected 200, got %d", idx, resp.StatusCode) + } + }(i) + } + + wg.Wait() + close(errors) + + for errMsg := range errors { + t.Error(errMsg) + } +} + +func TestConcurrentMultiTunnelRequests(t *testing.T) { + apiKey := "test-key-conc-multi" + + // Start Two Target Servers + target1Addr, stopTarget1 := startHTTPTarget(t) + defer stopTarget1() + + port2 := getFreePort(t) + addr2 := fmt.Sprintf("127.0.0.1:%d", port2) + mux2 := http.NewServeMux() + mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "server-two: %s", r.URL.Path) + }) + srv2 := &http.Server{Addr: addr2, Handler: mux2} + go srv2.ListenAndServe() + defer srv2.Close() + waitForPort(t, addr2, 3*time.Second) + + // Start Conduit Server + serverAddr, stopServer := startConduitServer(t, apiKey) + defer stopServer() + + // Connect Two Tunnels + stopTunnel1 := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", target1Addr), "cm-one", apiKey) + defer stopTunnel1() + + stopTunnel2 := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", addr2), "cm-two", apiKey) + defer stopTunnel2() + + // Fire Concurrent Requests to Both Tunnels + const perTunnel = 10 + var wg sync.WaitGroup + errors := make(chan string, perTunnel*2) + + for i := range perTunnel { + wg.Add(2) + + // Requests to Tunnel One + go func(idx int) { + defer wg.Done() + path := fmt.Sprintf("/a/%d", idx) + resp := sendHTTPViaTunnel(t, serverAddr, "cm-one", "GET", path, "") + body := readBody(t, resp) + if !strings.Contains(body, fmt.Sprintf("echo: GET %s", path)) { + errors <- fmt.Sprintf("tunnel-one req %d: got %q", idx, body) + } + }(i) + + // Requests to Tunnel Two + go func(idx int) { + defer wg.Done() + path := fmt.Sprintf("/b/%d", idx) + resp := sendHTTPViaTunnel(t, serverAddr, "cm-two", "GET", path, "") + body := readBody(t, resp) + if !strings.Contains(body, fmt.Sprintf("server-two: %s", path)) { + errors <- fmt.Sprintf("tunnel-two req %d: got %q", idx, body) + } + }(i) + } + + wg.Wait() + close(errors) + + for errMsg := range errors { + t.Error(errMsg) + } +} diff --git a/server/raw_http_response_writer.go b/server/raw_http_response_writer.go deleted file mode 100644 index 79babcc..0000000 --- a/server/raw_http_response_writer.go +++ /dev/null @@ -1,48 +0,0 @@ -package server - -import ( - "bufio" - "fmt" - "net" - "net/http" -) - -var _ http.ResponseWriter = (*rawHTTPResponseWriter)(nil) - -type rawHTTPResponseWriter struct { - conn net.Conn - header http.Header -} - -func (f *rawHTTPResponseWriter) Header() http.Header { - if f.header == nil { - f.header = make(http.Header) - } - return f.header -} - -func (f *rawHTTPResponseWriter) Write(data []byte) (int, error) { - return f.conn.Write(data) -} - -func (f *rawHTTPResponseWriter) WriteHeader(statusCode int) { - // Write Status - status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) - _, _ = f.conn.Write([]byte(status)) - - // Write Headers - for key, values := range f.header { - for _, value := range values { - _, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value) - } - } - - // End Headers - _, _ = f.conn.Write([]byte("\r\n")) -} - -func (f *rawHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - // Return Raw Connection & ReadWriter - rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) - return f.conn, rw, nil -} diff --git a/server/reconstructed_conn.go b/server/reconstructed_conn.go index 5993c69..f537e47 100644 --- a/server/reconstructed_conn.go +++ b/server/reconstructed_conn.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "io" "net" ) @@ -14,17 +13,17 @@ type reconstructedConn struct { reader io.Reader } -// Read reads from the reconstructed reader (captured data + original conn). +// Read reads from the reconstructed reader (prepended data + original conn). func (rc *reconstructedConn) Read(p []byte) (n int, err error) { return rc.reader.Read(p) } -// newReconstructedConn creates a reconstructed connection that replays captured data -// before reading from the original connection. -func newReconstructedConn(conn net.Conn, capturedData *bytes.Buffer) net.Conn { - allReader := io.MultiReader(capturedData, conn) +// newReconstructedConn creates a reconstructed connection that replays the provided +// readers in order before reading from the underlying connection. +func newReconstructedConn(conn net.Conn, readers ...io.Reader) net.Conn { + allReaders := append(readers, conn) return &reconstructedConn{ Conn: conn, - reader: allReader, + reader: io.MultiReader(allReaders...), } } diff --git a/server/server.go b/server/server.go index 38a057c..9be4608 100644 --- a/server/server.go +++ b/server/server.go @@ -1,14 +1,11 @@ package server import ( - "bufio" "bytes" "context" "encoding/json" "errors" "fmt" - "io" - "net" "net/http" "net/url" "strings" @@ -62,37 +59,104 @@ func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) { } func (s *Server) Start() error { - // Raw TCP Listener - This is necessary so we can conditionally either relay - // the raw TCP connection, or handle conduit control server API requests. - listener, err := net.Listen("tcp", s.cfg.BindAddress) - if err != nil { - return err + // HTTP Server - Uses stdlib http.Server for proper HTTP response handling + // including Content-Length, chunked encoding, and keep-alive semantics. + httpServer := &http.Server{ + Addr: s.cfg.BindAddress, + Handler: s, } - defer listener.Close() - // Context Cancellation - Close the listener when the context is cancelled - // so that Accept() unblocks and the loop exits cleanly. + // Context Cancellation - Gracefully shut down when the context is cancelled. go func() { <-s.ctx.Done() - listener.Close() + log.Info("conduit server shutting down") + httpServer.Close() }() - // Start Listening + // Start Server log.Infof("conduit server listening on %s", s.cfg.BindAddress) - for { - conn, err := listener.Accept() - if err != nil { - // Expected Error on Shutdown - if s.ctx.Err() != nil { - log.Info("conduit server shutting down") - return nil - } - log.WithError(err).Error("error accepting connection") - continue - } - - go s.handleRawConnection(conn) + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return err } + return nil +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Get True Host + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + r.RemoteAddr = xff + } + + // Validate Host + if !strings.Contains(r.Host, s.host) { + http.Error(w, fmt.Sprintf("unknown host: %s", r.Host), http.StatusBadRequest) + return + } + + // Extract Subdomain + tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") + if strings.Count(tunnelName, ".") != 0 { + http.Error(w, fmt.Sprintf("cannot tunnel nested subdomains: %s", r.Host), http.StatusBadRequest) + return + } + + // Handle Control Endpoints + if tunnelName == "" { + s.handleAsHTTP(w, r) + return + } + + // Handle Tunnel Requests + s.handleTunnelRequest(w, r, tunnelName) +} + +func (s *Server) handleTunnelRequest(w http.ResponseWriter, r *http.Request, tunnelName string) { + // Get Tunnel + conduitTunnel, exists := s.tunnels.Get(tunnelName) + if !exists { + http.Error(w, fmt.Sprintf("unknown tunnel: %s", tunnelName), http.StatusNotFound) + return + } + + // Hijack Connection - Take over the raw TCP connection from the HTTP server + // so we can forward the full request (including body) through the tunnel. + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "hijack not supported", http.StatusInternalServerError) + return + } + conn, bufrw, err := hj.Hijack() + if err != nil { + http.Error(w, fmt.Sprintf("hijack failed: %v", err), http.StatusInternalServerError) + return + } + + // Re-Serialize Request Headers - The HTTP server already consumed the request + // from the connection. We re-serialize it so the tunnel client receives a + // complete HTTP request to forward to the local target. + var reqBuf bytes.Buffer + fmt.Fprintf(&reqBuf, "%s %s %s\r\n", r.Method, r.RequestURI, r.Proto) + fmt.Fprintf(&reqBuf, "Host: %s\r\n", r.Host) + _ = r.Header.Write(&reqBuf) + reqBuf.WriteString("\r\n") + + // Reconstruct Connection - Combine re-serialized headers with any buffered + // body data (from the hijacked reader) and the raw connection. + reconstructedConn := newReconstructedConn(conn, &reqBuf, bufrw) + + // Create Stream + streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) + tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr, conduitTunnel.Source()) + + // Add Stream + if err := conduitTunnel.AddStream(tunnelStream, streamID); err != nil { + log.WithError(err).Error("failed to add stream") + conn.Close() + return + } + + // Start Stream + conduitTunnel.StartStream(tunnelStream, streamID) } func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) { @@ -122,79 +186,6 @@ func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write(d) } -func (s *Server) handleRawConnection(conn net.Conn) { - defer conn.Close() - - // Capture Consumed Data - When determining where to route the request, we - // have to read the host headers. This requires reading from the buffer, so - // if we later decide to tunnel the TCP connection we need to reconstruct the - // data from the buffer. - var capturedData bytes.Buffer - teeReader := io.TeeReader(conn, &capturedData) - bufReader := bufio.NewReader(teeReader) - - // Create HTTP Request & Writer - w := &rawHTTPResponseWriter{conn: conn} - r, err := http.ReadRequest(bufReader) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - defer r.Body.Close() - - // Validate Host - if !strings.Contains(r.Host, s.host) { - w.WriteHeader(http.StatusBadRequest) - _, _ = fmt.Fprintf(w, "unknown host: %s", r.Host) - return - } - - // Extract Subdomain - tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") - if strings.Count(tunnelName, ".") != 0 { - w.WriteHeader(http.StatusBadRequest) - _, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host) - return - } - - // Get True Host - remoteHost := conn.RemoteAddr().String() - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - remoteHost = xff - } - r.RemoteAddr = remoteHost - - // Handle Control Endpoints - if tunnelName == "" { - s.handleAsHTTP(w, r) - return - } - - // Handle Tunnels - conduitTunnel, exists := s.tunnels.Get(tunnelName) - if !exists { - w.WriteHeader(http.StatusNotFound) - _, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName) - return - } - - // Create Stream - reconstructedConn := newReconstructedConn(conn, &capturedData) - streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) - tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr, conduitTunnel.Source()) - - // Add Stream - if err := conduitTunnel.AddStream(tunnelStream, streamID); err != nil { - w.WriteHeader(http.StatusInternalServerError) - _, _ = fmt.Fprintf(w, "failed to add stream: %v", err) - log.WithError(err).Error("failed to add stream") - return - } - - // Start Stream - conduitTunnel.StartStream(tunnelStream, streamID) -} - func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { // Authorize Control Endpoints apiKey := r.URL.Query().Get("apiKey") @@ -219,15 +210,13 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { // Get Tunnel Name tunnelName := r.URL.Query().Get("tunnelName") if tunnelName == "" { - w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte("Missing tunnelName parameter")) + http.Error(w, "Missing tunnelName parameter", http.StatusBadRequest) return } // Validate Unique if _, exists := s.tunnels.Get(tunnelName); exists { - w.WriteHeader(http.StatusConflict) - _, _ = w.Write([]byte("Tunnel already registered")) + http.Error(w, "Tunnel already registered", http.StatusConflict) return } diff --git a/tunnel/forwarder.go b/tunnel/forwarder.go index a8e5f01..ce8bbf6 100644 --- a/tunnel/forwarder.go +++ b/tunnel/forwarder.go @@ -21,23 +21,15 @@ type Forwarder interface { } func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) { - // Get Target URL + // Only parse as URL for HTTP targets. Bare host:port (e.g., "127.0.0.1:5432") + // is not a valid URL and should be treated as a raw TCP target. targetURL, err := url.Parse(target) - if err != nil { - return nil, err - } - - // Get Connection Builder - var forwarder Forwarder - switch targetURL.Scheme { - case "http", "https": - forwarder, err = newHTTPForwarder(targetURL, tunnelStore) - if err != nil { - return nil, err + if err == nil { + switch targetURL.Scheme { + case "http", "https": + return newHTTPForwarder(targetURL, tunnelStore) } - default: - forwarder = newTCPForwarder(target, tunnelStore) } - return forwarder, nil + return newTCPForwarder(target, tunnelStore), nil }