refactor(server): replace rawHTTPResponseWriter with stdlib http.Server
- Rewrite server to use net/http.Server with ServeHTTP handler instead of raw TCP listener with hand-written HTTP responses. Control plane errors now get proper Content-Type, Content-Length, and chunked encoding via http.Error(). Tunnel traffic hijacks the connection and re-serializes the request for forwarding. - Simplify reconstructedConn to accept variative io.Readers - Delete raw_http_response_writer.go (no longer needed) - Fix TCP forwarder: bare host:port (e.g. "127.0.0.1:5432") now works correctly instead of failing on url.Parse. Only HTTP/HTTPS schemes go through URL parsing; everything else is treated as raw TCP. - Add 8 new e2e tests: HTTP response quality, 1MB response body, 512KB request body, TCP echo, TCP large payload, concurrent single-tunnel, concurrent multi-tunnel (16 tests total, all passing)
This commit is contained in:
357
e2e_test.go
357
e2e_test.go
@@ -505,3 +505,360 @@ func TestServerGracefulShutdown(t *testing.T) {
|
|||||||
t.Error("expected server port to be closed after shutdown")
|
t.Error("expected server port to be closed after shutdown")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------- HTTP Response Quality Tests ----------
|
||||||
|
|
||||||
|
func TestHTTPResponseHasProperHeaders(t *testing.T) {
|
||||||
|
apiKey := "test-key-headers"
|
||||||
|
|
||||||
|
// Start Target HTTP Server
|
||||||
|
targetAddr, stopTarget := startHTTPTarget(t)
|
||||||
|
defer stopTarget()
|
||||||
|
|
||||||
|
// Start Conduit Server
|
||||||
|
serverAddr, stopServer := startConduitServer(t, apiKey)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// Connect Tunnel
|
||||||
|
stopTunnel := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", targetAddr), "hdr-test", apiKey)
|
||||||
|
defer stopTunnel()
|
||||||
|
|
||||||
|
// Send Request Through Tunnel
|
||||||
|
resp := sendHTTPViaTunnel(t, serverAddr, "hdr-test", "GET", "/", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Verify Proper HTTP Semantics
|
||||||
|
if resp.Proto != "HTTP/1.1" {
|
||||||
|
t.Errorf("expected HTTP/1.1, got %s", resp.Proto)
|
||||||
|
}
|
||||||
|
if resp.Header.Get("X-Test-Header") != "present" {
|
||||||
|
t.Errorf("expected X-Test-Header: present, got %q", resp.Header.Get("X-Test-Header"))
|
||||||
|
}
|
||||||
|
if resp.ContentLength <= 0 && resp.TransferEncoding == nil {
|
||||||
|
t.Errorf("expected Content-Length or Transfer-Encoding, got neither")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPControlEndpointResponseQuality(t *testing.T) {
|
||||||
|
apiKey := "test-key-ctrl-quality"
|
||||||
|
|
||||||
|
// Start Conduit Server
|
||||||
|
serverAddr, stopServer := startConduitServer(t, apiKey)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// 404 on Unknown Tunnel — Verify stdlib response format
|
||||||
|
resp := sendHTTPViaTunnel(t, serverAddr, "nope", "GET", "/", "")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("expected 404, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Content-Type Should Be Set by http.Error
|
||||||
|
ct := resp.Header.Get("Content-Type")
|
||||||
|
if !strings.Contains(ct, "text/plain") {
|
||||||
|
t.Errorf("expected text/plain Content-Type, got %q", ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Content-Length Should Be Present
|
||||||
|
if resp.ContentLength <= 0 {
|
||||||
|
t.Errorf("expected positive Content-Length, got %d", resp.ContentLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info Endpoint — JSON response quality
|
||||||
|
url := fmt.Sprintf("http://%s/_conduit/info?apiKey=%s", serverAddr, apiKey)
|
||||||
|
req, _ := http.NewRequest("GET", url, nil)
|
||||||
|
req.Host = serverAddr
|
||||||
|
|
||||||
|
resp2, err := (&http.Client{Timeout: 5 * time.Second}).Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("info request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Body.Close()
|
||||||
|
|
||||||
|
if resp2.Header.Get("Content-Type") != "application/json" {
|
||||||
|
t.Errorf("expected application/json, got %q", resp2.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Large Body Tests ----------
|
||||||
|
|
||||||
|
func TestHTTPLargeResponseBody(t *testing.T) {
|
||||||
|
apiKey := "test-key-large-resp"
|
||||||
|
|
||||||
|
// Start Target That Returns a Large Body
|
||||||
|
largeBody := strings.Repeat("A", 1024*1024) // 1 MB
|
||||||
|
port := getFreePort(t)
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/large", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(largeBody))
|
||||||
|
})
|
||||||
|
srv := &http.Server{Addr: addr, Handler: mux}
|
||||||
|
go srv.ListenAndServe()
|
||||||
|
defer srv.Close()
|
||||||
|
waitForPort(t, addr, 3*time.Second)
|
||||||
|
|
||||||
|
// Start Conduit Server
|
||||||
|
serverAddr, stopServer := startConduitServer(t, apiKey)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// Connect Tunnel
|
||||||
|
stopTunnel := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", addr), "large-resp", apiKey)
|
||||||
|
defer stopTunnel()
|
||||||
|
|
||||||
|
// Request Large Response
|
||||||
|
resp := sendHTTPViaTunnel(t, serverAddr, "large-resp", "GET", "/large", "")
|
||||||
|
body := readBody(t, resp)
|
||||||
|
|
||||||
|
if len(body) != len(largeBody) {
|
||||||
|
t.Errorf("expected %d bytes, got %d", len(largeBody), len(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPLargeRequestBody(t *testing.T) {
|
||||||
|
apiKey := "test-key-large-req"
|
||||||
|
|
||||||
|
// Start Target That Echoes Body Size
|
||||||
|
port := getFreePort(t)
|
||||||
|
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
data, _ := io.ReadAll(r.Body)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, "size:%d", len(data))
|
||||||
|
})
|
||||||
|
srv := &http.Server{Addr: addr, Handler: mux}
|
||||||
|
go srv.ListenAndServe()
|
||||||
|
defer srv.Close()
|
||||||
|
waitForPort(t, addr, 3*time.Second)
|
||||||
|
|
||||||
|
// Start Conduit Server
|
||||||
|
serverAddr, stopServer := startConduitServer(t, apiKey)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// Connect Tunnel
|
||||||
|
stopTunnel := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", addr), "large-req", apiKey)
|
||||||
|
defer stopTunnel()
|
||||||
|
|
||||||
|
// Send Large Request Body (512 KB)
|
||||||
|
largePayload := strings.Repeat("B", 512*1024)
|
||||||
|
resp := sendHTTPViaTunnel(t, serverAddr, "large-req", "POST", "/upload", largePayload)
|
||||||
|
body := readBody(t, resp)
|
||||||
|
|
||||||
|
expected := fmt.Sprintf("size:%d", len(largePayload))
|
||||||
|
if body != expected {
|
||||||
|
t.Errorf("expected %q, got %q", expected, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- TCP Tunnel Tests ----------
|
||||||
|
|
||||||
|
func TestTCPTunnelEcho(t *testing.T) {
|
||||||
|
apiKey := "test-key-tcp"
|
||||||
|
|
||||||
|
// Start TCP Echo Server
|
||||||
|
tcpAddr, stopTCP := startTCPEchoTarget(t)
|
||||||
|
defer stopTCP()
|
||||||
|
|
||||||
|
// Start Conduit Server
|
||||||
|
serverAddr, stopServer := startConduitServer(t, apiKey)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// Connect TCP Tunnel (bare host:port — the realistic way users would specify it)
|
||||||
|
stopTunnel := connectTunnel(t, serverAddr, tcpAddr, "tcp-test", apiKey)
|
||||||
|
defer stopTunnel()
|
||||||
|
|
||||||
|
// Send Raw HTTP Request Through Tunnel to TCP Echo
|
||||||
|
conn, err := net.DialTimeout("tcp", serverAddr, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Write a Raw HTTP Request (the TCP echo will bounce it back)
|
||||||
|
reqLine := fmt.Sprintf("GET / HTTP/1.1\r\nHost: tcp-test.%s\r\n\r\n", serverAddr)
|
||||||
|
_, err = conn.Write([]byte(reqLine))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read Echoed Data
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read echo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response := string(buf[:n])
|
||||||
|
if !strings.Contains(response, "GET / HTTP/1.1") {
|
||||||
|
t.Errorf("expected echoed request, got: %q", response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPTunnelLargePayload(t *testing.T) {
|
||||||
|
apiKey := "test-key-tcp-large"
|
||||||
|
|
||||||
|
// Start TCP Echo Server
|
||||||
|
tcpAddr, stopTCP := startTCPEchoTarget(t)
|
||||||
|
defer stopTCP()
|
||||||
|
|
||||||
|
// Start Conduit Server
|
||||||
|
serverAddr, stopServer := startConduitServer(t, apiKey)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// Connect TCP Tunnel (bare host:port)
|
||||||
|
stopTunnel := connectTunnel(t, serverAddr, tcpAddr, "tcp-large", apiKey)
|
||||||
|
defer stopTunnel()
|
||||||
|
|
||||||
|
// Connect and Send Large Payload
|
||||||
|
conn, err := net.DialTimeout("tcp", serverAddr, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Route via Host Header Then Send 64 KB Payload
|
||||||
|
header := fmt.Sprintf("POST /data HTTP/1.1\r\nHost: tcp-large.%s\r\nContent-Length: 65536\r\n\r\n", serverAddr)
|
||||||
|
payload := header + strings.Repeat("X", 64*1024)
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte(payload))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read All Echoed Data
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
var received []byte
|
||||||
|
buf := make([]byte, 8192)
|
||||||
|
for len(received) < len(payload) {
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
received = append(received, buf[:n]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(received) != len(payload) {
|
||||||
|
t.Errorf("expected %d bytes echoed, got %d", len(payload), len(received))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Concurrency Tests ----------
|
||||||
|
|
||||||
|
func TestConcurrentHTTPRequests(t *testing.T) {
|
||||||
|
apiKey := "test-key-concurrent"
|
||||||
|
|
||||||
|
// Start Target HTTP Server
|
||||||
|
targetAddr, stopTarget := startHTTPTarget(t)
|
||||||
|
defer stopTarget()
|
||||||
|
|
||||||
|
// Start Conduit Server
|
||||||
|
serverAddr, stopServer := startConduitServer(t, apiKey)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// Connect Tunnel
|
||||||
|
stopTunnel := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", targetAddr), "conc-test", apiKey)
|
||||||
|
defer stopTunnel()
|
||||||
|
|
||||||
|
// Fire 20 Concurrent Requests
|
||||||
|
const numRequests = 20
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errors := make(chan string, numRequests)
|
||||||
|
|
||||||
|
for i := range numRequests {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
path := fmt.Sprintf("/item/%d", idx)
|
||||||
|
resp := sendHTTPViaTunnel(t, serverAddr, "conc-test", "GET", path, "")
|
||||||
|
body := readBody(t, resp)
|
||||||
|
|
||||||
|
expected := fmt.Sprintf("echo: GET %s", path)
|
||||||
|
if !strings.Contains(body, expected) {
|
||||||
|
errors <- fmt.Sprintf("request %d: expected %q, got %q", idx, expected, body)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
errors <- fmt.Sprintf("request %d: expected 200, got %d", idx, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errors)
|
||||||
|
|
||||||
|
for errMsg := range errors {
|
||||||
|
t.Error(errMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentMultiTunnelRequests(t *testing.T) {
|
||||||
|
apiKey := "test-key-conc-multi"
|
||||||
|
|
||||||
|
// Start Two Target Servers
|
||||||
|
target1Addr, stopTarget1 := startHTTPTarget(t)
|
||||||
|
defer stopTarget1()
|
||||||
|
|
||||||
|
port2 := getFreePort(t)
|
||||||
|
addr2 := fmt.Sprintf("127.0.0.1:%d", port2)
|
||||||
|
mux2 := http.NewServeMux()
|
||||||
|
mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, "server-two: %s", r.URL.Path)
|
||||||
|
})
|
||||||
|
srv2 := &http.Server{Addr: addr2, Handler: mux2}
|
||||||
|
go srv2.ListenAndServe()
|
||||||
|
defer srv2.Close()
|
||||||
|
waitForPort(t, addr2, 3*time.Second)
|
||||||
|
|
||||||
|
// Start Conduit Server
|
||||||
|
serverAddr, stopServer := startConduitServer(t, apiKey)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// Connect Two Tunnels
|
||||||
|
stopTunnel1 := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", target1Addr), "cm-one", apiKey)
|
||||||
|
defer stopTunnel1()
|
||||||
|
|
||||||
|
stopTunnel2 := connectTunnel(t, serverAddr, fmt.Sprintf("http://%s", addr2), "cm-two", apiKey)
|
||||||
|
defer stopTunnel2()
|
||||||
|
|
||||||
|
// Fire Concurrent Requests to Both Tunnels
|
||||||
|
const perTunnel = 10
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errors := make(chan string, perTunnel*2)
|
||||||
|
|
||||||
|
for i := range perTunnel {
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
// Requests to Tunnel One
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
path := fmt.Sprintf("/a/%d", idx)
|
||||||
|
resp := sendHTTPViaTunnel(t, serverAddr, "cm-one", "GET", path, "")
|
||||||
|
body := readBody(t, resp)
|
||||||
|
if !strings.Contains(body, fmt.Sprintf("echo: GET %s", path)) {
|
||||||
|
errors <- fmt.Sprintf("tunnel-one req %d: got %q", idx, body)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
|
||||||
|
// Requests to Tunnel Two
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
path := fmt.Sprintf("/b/%d", idx)
|
||||||
|
resp := sendHTTPViaTunnel(t, serverAddr, "cm-two", "GET", path, "")
|
||||||
|
body := readBody(t, resp)
|
||||||
|
if !strings.Contains(body, fmt.Sprintf("server-two: %s", path)) {
|
||||||
|
errors <- fmt.Sprintf("tunnel-two req %d: got %q", idx, body)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(errors)
|
||||||
|
|
||||||
|
for errMsg := range errors {
|
||||||
|
t.Error(errMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ http.ResponseWriter = (*rawHTTPResponseWriter)(nil)
|
|
||||||
|
|
||||||
type rawHTTPResponseWriter struct {
|
|
||||||
conn net.Conn
|
|
||||||
header http.Header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *rawHTTPResponseWriter) Header() http.Header {
|
|
||||||
if f.header == nil {
|
|
||||||
f.header = make(http.Header)
|
|
||||||
}
|
|
||||||
return f.header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *rawHTTPResponseWriter) Write(data []byte) (int, error) {
|
|
||||||
return f.conn.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *rawHTTPResponseWriter) WriteHeader(statusCode int) {
|
|
||||||
// Write Status
|
|
||||||
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
|
|
||||||
_, _ = f.conn.Write([]byte(status))
|
|
||||||
|
|
||||||
// Write Headers
|
|
||||||
for key, values := range f.header {
|
|
||||||
for _, value := range values {
|
|
||||||
_, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// End Headers
|
|
||||||
_, _ = f.conn.Write([]byte("\r\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *rawHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
// Return Raw Connection & ReadWriter
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
|
|
||||||
return f.conn, rw, nil
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
@@ -14,17 +13,17 @@ type reconstructedConn struct {
|
|||||||
reader io.Reader
|
reader io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads from the reconstructed reader (captured data + original conn).
|
// Read reads from the reconstructed reader (prepended data + original conn).
|
||||||
func (rc *reconstructedConn) Read(p []byte) (n int, err error) {
|
func (rc *reconstructedConn) Read(p []byte) (n int, err error) {
|
||||||
return rc.reader.Read(p)
|
return rc.reader.Read(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newReconstructedConn creates a reconstructed connection that replays captured data
|
// newReconstructedConn creates a reconstructed connection that replays the provided
|
||||||
// before reading from the original connection.
|
// readers in order before reading from the underlying connection.
|
||||||
func newReconstructedConn(conn net.Conn, capturedData *bytes.Buffer) net.Conn {
|
func newReconstructedConn(conn net.Conn, readers ...io.Reader) net.Conn {
|
||||||
allReader := io.MultiReader(capturedData, conn)
|
allReaders := append(readers, conn)
|
||||||
return &reconstructedConn{
|
return &reconstructedConn{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
reader: allReader,
|
reader: io.MultiReader(allReaders...),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
187
server/server.go
187
server/server.go
@@ -1,14 +1,11 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -62,37 +59,104 @@ func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
// Raw TCP Listener - This is necessary so we can conditionally either relay
|
// HTTP Server - Uses stdlib http.Server for proper HTTP response handling
|
||||||
// the raw TCP connection, or handle conduit control server API requests.
|
// including Content-Length, chunked encoding, and keep-alive semantics.
|
||||||
listener, err := net.Listen("tcp", s.cfg.BindAddress)
|
httpServer := &http.Server{
|
||||||
if err != nil {
|
Addr: s.cfg.BindAddress,
|
||||||
return err
|
Handler: s,
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
|
||||||
|
|
||||||
// Context Cancellation - Close the listener when the context is cancelled
|
// Context Cancellation - Gracefully shut down when the context is cancelled.
|
||||||
// so that Accept() unblocks and the loop exits cleanly.
|
|
||||||
go func() {
|
go func() {
|
||||||
<-s.ctx.Done()
|
<-s.ctx.Done()
|
||||||
listener.Close()
|
log.Info("conduit server shutting down")
|
||||||
|
httpServer.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Start Listening
|
// Start Server
|
||||||
log.Infof("conduit server listening on %s", s.cfg.BindAddress)
|
log.Infof("conduit server listening on %s", s.cfg.BindAddress)
|
||||||
for {
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
conn, err := listener.Accept()
|
return err
|
||||||
if err != nil {
|
}
|
||||||
// Expected Error on Shutdown
|
|
||||||
if s.ctx.Err() != nil {
|
|
||||||
log.Info("conduit server shutting down")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
log.WithError(err).Error("error accepting connection")
|
|
||||||
continue
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get True Host
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
r.RemoteAddr = xff
|
||||||
}
|
}
|
||||||
|
|
||||||
go s.handleRawConnection(conn)
|
// Validate Host
|
||||||
|
if !strings.Contains(r.Host, s.host) {
|
||||||
|
http.Error(w, fmt.Sprintf("unknown host: %s", r.Host), http.StatusBadRequest)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract Subdomain
|
||||||
|
tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
||||||
|
if strings.Count(tunnelName, ".") != 0 {
|
||||||
|
http.Error(w, fmt.Sprintf("cannot tunnel nested subdomains: %s", r.Host), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Control Endpoints
|
||||||
|
if tunnelName == "" {
|
||||||
|
s.handleAsHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Tunnel Requests
|
||||||
|
s.handleTunnelRequest(w, r, tunnelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTunnelRequest(w http.ResponseWriter, r *http.Request, tunnelName string) {
|
||||||
|
// Get Tunnel
|
||||||
|
conduitTunnel, exists := s.tunnels.Get(tunnelName)
|
||||||
|
if !exists {
|
||||||
|
http.Error(w, fmt.Sprintf("unknown tunnel: %s", tunnelName), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack Connection - Take over the raw TCP connection from the HTTP server
|
||||||
|
// so we can forward the full request (including body) through the tunnel.
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "hijack not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, bufrw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("hijack failed: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-Serialize Request Headers - The HTTP server already consumed the request
|
||||||
|
// from the connection. We re-serialize it so the tunnel client receives a
|
||||||
|
// complete HTTP request to forward to the local target.
|
||||||
|
var reqBuf bytes.Buffer
|
||||||
|
fmt.Fprintf(&reqBuf, "%s %s %s\r\n", r.Method, r.RequestURI, r.Proto)
|
||||||
|
fmt.Fprintf(&reqBuf, "Host: %s\r\n", r.Host)
|
||||||
|
_ = r.Header.Write(&reqBuf)
|
||||||
|
reqBuf.WriteString("\r\n")
|
||||||
|
|
||||||
|
// Reconstruct Connection - Combine re-serialized headers with any buffered
|
||||||
|
// body data (from the hijacked reader) and the raw connection.
|
||||||
|
reconstructedConn := newReconstructedConn(conn, &reqBuf, bufrw)
|
||||||
|
|
||||||
|
// Create Stream
|
||||||
|
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
||||||
|
tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr, conduitTunnel.Source())
|
||||||
|
|
||||||
|
// Add Stream
|
||||||
|
if err := conduitTunnel.AddStream(tunnelStream, streamID); err != nil {
|
||||||
|
log.WithError(err).Error("failed to add stream")
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start Stream
|
||||||
|
conduitTunnel.StartStream(tunnelStream, streamID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
||||||
@@ -122,79 +186,6 @@ func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
|||||||
_, _ = w.Write(d)
|
_, _ = w.Write(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleRawConnection(conn net.Conn) {
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// Capture Consumed Data - When determining where to route the request, we
|
|
||||||
// have to read the host headers. This requires reading from the buffer, so
|
|
||||||
// if we later decide to tunnel the TCP connection we need to reconstruct the
|
|
||||||
// data from the buffer.
|
|
||||||
var capturedData bytes.Buffer
|
|
||||||
teeReader := io.TeeReader(conn, &capturedData)
|
|
||||||
bufReader := bufio.NewReader(teeReader)
|
|
||||||
|
|
||||||
// Create HTTP Request & Writer
|
|
||||||
w := &rawHTTPResponseWriter{conn: conn}
|
|
||||||
r, err := http.ReadRequest(bufReader)
|
|
||||||
if err != nil {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
|
|
||||||
// Validate Host
|
|
||||||
if !strings.Contains(r.Host, s.host) {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, _ = fmt.Fprintf(w, "unknown host: %s", r.Host)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract Subdomain
|
|
||||||
tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
|
||||||
if strings.Count(tunnelName, ".") != 0 {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get True Host
|
|
||||||
remoteHost := conn.RemoteAddr().String()
|
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
||||||
remoteHost = xff
|
|
||||||
}
|
|
||||||
r.RemoteAddr = remoteHost
|
|
||||||
|
|
||||||
// Handle Control Endpoints
|
|
||||||
if tunnelName == "" {
|
|
||||||
s.handleAsHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle Tunnels
|
|
||||||
conduitTunnel, exists := s.tunnels.Get(tunnelName)
|
|
||||||
if !exists {
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create Stream
|
|
||||||
reconstructedConn := newReconstructedConn(conn, &capturedData)
|
|
||||||
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
|
||||||
tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr, conduitTunnel.Source())
|
|
||||||
|
|
||||||
// Add Stream
|
|
||||||
if err := conduitTunnel.AddStream(tunnelStream, streamID); err != nil {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
_, _ = fmt.Fprintf(w, "failed to add stream: %v", err)
|
|
||||||
log.WithError(err).Error("failed to add stream")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start Stream
|
|
||||||
conduitTunnel.StartStream(tunnelStream, streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// Authorize Control Endpoints
|
// Authorize Control Endpoints
|
||||||
apiKey := r.URL.Query().Get("apiKey")
|
apiKey := r.URL.Query().Get("apiKey")
|
||||||
@@ -219,15 +210,13 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Get Tunnel Name
|
// Get Tunnel Name
|
||||||
tunnelName := r.URL.Query().Get("tunnelName")
|
tunnelName := r.URL.Query().Get("tunnelName")
|
||||||
if tunnelName == "" {
|
if tunnelName == "" {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
http.Error(w, "Missing tunnelName parameter", http.StatusBadRequest)
|
||||||
_, _ = w.Write([]byte("Missing tunnelName parameter"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate Unique
|
// Validate Unique
|
||||||
if _, exists := s.tunnels.Get(tunnelName); exists {
|
if _, exists := s.tunnels.Get(tunnelName); exists {
|
||||||
w.WriteHeader(http.StatusConflict)
|
http.Error(w, "Tunnel already registered", http.StatusConflict)
|
||||||
_, _ = w.Write([]byte("Tunnel already registered"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,23 +21,15 @@ type Forwarder interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) {
|
func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) {
|
||||||
// Get Target URL
|
// Only parse as URL for HTTP targets. Bare host:port (e.g., "127.0.0.1:5432")
|
||||||
|
// is not a valid URL and should be treated as a raw TCP target.
|
||||||
targetURL, err := url.Parse(target)
|
targetURL, err := url.Parse(target)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get Connection Builder
|
|
||||||
var forwarder Forwarder
|
|
||||||
switch targetURL.Scheme {
|
switch targetURL.Scheme {
|
||||||
case "http", "https":
|
case "http", "https":
|
||||||
forwarder, err = newHTTPForwarder(targetURL, tunnelStore)
|
return newHTTPForwarder(targetURL, tunnelStore)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
forwarder = newTCPForwarder(target, tunnelStore)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return forwarder, nil
|
return newTCPForwarder(target, tunnelStore), nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user