From 2e736897624e8d529f8e441deba123b1dbfbb594 Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Mon, 22 Sep 2025 23:04:15 -0400 Subject: [PATCH] http vs tcp tunnel --- client/client.go | 2 +- cmd/tunnel.go | 3 -- tunnel/http.go | 104 +++++++++++++++++++++++++++++++++++++++++++++++ tunnel/tunnel.go | 39 ++++++++++++++---- 4 files changed, 135 insertions(+), 13 deletions(-) create mode 100644 tunnel/http.go diff --git a/client/client.go b/client/client.go index 1e0aeec..c9258f1 100644 --- a/client/client.go +++ b/client/client.go @@ -41,5 +41,5 @@ func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) { return nil, fmt.Errorf("failed to connect: %v", err) } - return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn), nil + return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn) } diff --git a/cmd/tunnel.go b/cmd/tunnel.go index 48ace87..1643366 100644 --- a/cmd/tunnel.go +++ b/cmd/tunnel.go @@ -22,9 +22,6 @@ var tunnelCmd = &cobra.Command{ if err != nil { log.Fatal("failed to create tunnel:", err) } - - // Start Tunnel - log.Infof("creating TCP tunnel: %s -> %s", cfg.TunnelName, cfg.TunnelTarget) tunnel.Start() }, } diff --git a/tunnel/http.go b/tunnel/http.go new file mode 100644 index 0000000..dfb90e3 --- /dev/null +++ b/tunnel/http.go @@ -0,0 +1,104 @@ +package tunnel + +import ( + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "sync" +) + +func HTTPConnectionBuilder(targetURL *url.URL) (ConnBuilder, error) { + multiConnListener := newMultiConnListener() + + // Create Reverse Proxy + proxy := &httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.Host = targetURL.Host + req.URL.Host = targetURL.Host + req.URL.Scheme = targetURL.Scheme + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway) + }, + } + + // Start HTTP Proxy + go func() { + defer multiConnListener.Close() + _ = http.Serve(multiConnListener, proxy) + }() + + // Return Connection Builder + return func() (conn io.ReadWriteCloser, err error) { + clientConn, serverConn := net.Pipe() + + if err := multiConnListener.addConn(serverConn); err != nil { + _ = clientConn.Close() + _ = serverConn.Close() + return nil, err + } + + return clientConn, nil + }, nil +} + +type multiConnListener struct { + connCh chan net.Conn + closed chan struct{} + once sync.Once +} + +func newMultiConnListener() *multiConnListener { + return &multiConnListener{ + connCh: make(chan net.Conn, 100), + closed: make(chan struct{}), + } +} + +func (l *multiConnListener) Accept() (net.Conn, error) { + select { + case conn := <-l.connCh: + if conn == nil { + return nil, fmt.Errorf("listener closed") + } + return conn, nil + case <-l.closed: + return nil, fmt.Errorf("listener closed") + } +} + +func (l *multiConnListener) Close() error { + l.once.Do(func() { + close(l.closed) + // Drain any remaining connections + go func() { + for conn := range l.connCh { + if conn != nil { + conn.Close() + } + } + }() + close(l.connCh) + }) + return nil +} + +func (l *multiConnListener) Addr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (l *multiConnListener) addConn(conn net.Conn) error { + select { + case l.connCh <- conn: + return nil + case <-l.closed: + conn.Close() + return fmt.Errorf("listener is closed") + default: + conn.Close() + return fmt.Errorf("connection queue full") + } +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 785a9c8..19d2893 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "net/url" "sync" "github.com/gorilla/websocket" @@ -11,6 +12,8 @@ import ( "reichard.io/conduit/types" ) +type ConnBuilder func() (conn io.ReadWriteCloser, err error) + func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel { return &Tunnel{ name: name, @@ -19,22 +22,40 @@ func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel { } } -func NewClientTunnel(name, target string, wsConn *websocket.Conn) *Tunnel { - return &Tunnel{ - name: name, - wsConn: wsConn, - streams: make(map[string]io.ReadWriteCloser), - connBuilder: func() (io.ReadWriteCloser, error) { - return net.Dial("tcp", target) - }, +func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, error) { + targetURL, err := url.Parse(target) + if err != nil { + return nil, err } + + var connBuilder ConnBuilder + switch targetURL.Scheme { + case "http", "https": + log.Infof("creating HTTP tunnel: %s -> %s", name, target) + connBuilder, err = HTTPConnectionBuilder(targetURL) + if err != nil { + return nil, err + } + default: + log.Infof("creating TCP tunnel: %s -> %s", name, target) + connBuilder = func() (conn io.ReadWriteCloser, err error) { + return net.Dial("tcp", target) + } + } + + return &Tunnel{ + name: name, + wsConn: wsConn, + streams: make(map[string]io.ReadWriteCloser), + connBuilder: connBuilder, + }, nil } type Tunnel struct { name string wsConn *websocket.Conn streams map[string]io.ReadWriteCloser - connBuilder func() (io.ReadWriteCloser, error) + connBuilder ConnBuilder wsMu, streamsMu sync.Mutex }