From 20c1388cf46e1f56f813ac419827fb30b81ec670 Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Tue, 23 Sep 2025 09:24:09 -0400 Subject: [PATCH] chore: better source tracking --- client/client.go | 2 +- server/server.go | 25 ++++++++++++++++--------- tunnel/tunnel.go | 28 ++++++++++++++++++---------- types/message.go | 7 ++++--- 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/client/client.go b/client/client.go index c9258f1..b8cad0a 100644 --- a/client/client.go +++ b/client/client.go @@ -41,5 +41,5 @@ func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) { return nil, fmt.Errorf("failed to connect: %v", err) } - return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn) + return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverURL, serverConn) } diff --git a/server/server.go b/server/server.go index 57abd20..5d5ec46 100644 --- a/server/server.go +++ b/server/server.go @@ -135,24 +135,31 @@ func (s *Server) handleRawConnection(conn net.Conn) { } // Extract Subdomain - subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") - if strings.Count(subdomain, ".") != 0 { + tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") + if strings.Count(tunnelName, ".") != 0 { w.WriteHeader(http.StatusBadRequest) _, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host) return } + // Get True Host + remoteHost := conn.RemoteAddr().String() + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + remoteHost = xff + } + r.RemoteAddr = remoteHost + // Handle Control Endpoints - if subdomain == "" { + if tunnelName == "" { s.handleAsHTTP(w, r) return } // Handle Tunnels - conduitTunnel, exists := s.tunnels.Get(subdomain) + conduitTunnel, exists := s.tunnels.Get(tunnelName) if !exists { w.WriteHeader(http.StatusNotFound) - _, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain) + _, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName) return } @@ -165,8 +172,8 @@ func (s *Server) handleRawConnection(conn net.Conn) { return } - log.Infof("relaying %s to tunnel", subdomain) - _ = conduitTunnel.StartStream(streamID) + log.Infof("tunnel %q connection from %s", tunnelName, r.RemoteAddr) + _ = conduitTunnel.StartStream(streamID, r.RemoteAddr) } func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { @@ -215,7 +222,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { // Create Tunnel conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn) s.tunnels.Set(tunnelName, conduitTunnel) - log.Infof("tunnel established: %s", tunnelName) + log.Infof("tunnel %q created from %s", tunnelName, r.RemoteAddr) // Start Tunnel - This is blocking conduitTunnel.Start() @@ -223,5 +230,5 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { // Cleanup Tunnel s.tunnels.Delete(tunnelName) _ = wsConn.Close() - log.Infof("tunnel closed: %s", tunnelName) + log.Infof("tunnel %q closed from %s", tunnelName, r.RemoteAddr) } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 829d48b..cf54ab8 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -23,22 +23,28 @@ func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel { } } -func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, error) { +func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) { + // Get Target URL targetURL, err := url.Parse(target) if err != nil { return nil, err } + // Derive Conduit URL + conduitURL := *serverURL + conduitURL.Host = name + "." + conduitURL.Host + + // Get Connection Builder var connBuilder ConnBuilder switch targetURL.Scheme { case "http", "https": - log.Infof("creating HTTP tunnel: %s -> %s", name, target) + log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target) connBuilder, err = HTTPConnectionBuilder(targetURL) if err != nil { return nil, err } default: - log.Infof("creating TCP tunnel: %s -> %s", name, target) + log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target) connBuilder = func() (conn io.ReadWriteCloser, err error) { return net.Dial("tcp", target) } @@ -109,7 +115,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error { return err } - go t.StartStream(streamID) + go t.StartStream(streamID, "") return nil } @@ -121,7 +127,7 @@ func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error { return nil } -func (t *Tunnel) StartStream(streamID string) error { +func (t *Tunnel) StartStream(streamID string, sourceAddr string) error { // Get Stream conn, found := t.streams.Get(streamID) if !found { @@ -131,8 +137,9 @@ func (t *Tunnel) StartStream(streamID string) error { // Close Stream defer func() { _ = t.sendWS(&types.Message{ - Type: types.MessageTypeClose, - StreamID: streamID, + Type: types.MessageTypeClose, + StreamID: streamID, + SourceAddr: sourceAddr, }) t.CloseStream(streamID) @@ -147,9 +154,10 @@ func (t *Tunnel) StartStream(streamID string) error { } if err := t.sendWS(&types.Message{ - Type: types.MessageTypeData, - Data: buffer[:n], - StreamID: streamID, + Type: types.MessageTypeData, + StreamID: streamID, + Data: buffer[:n], + SourceAddr: sourceAddr, }); err != nil { return err } diff --git a/types/message.go b/types/message.go index 3d54160..a3993f9 100644 --- a/types/message.go +++ b/types/message.go @@ -8,7 +8,8 @@ const ( ) type Message struct { - Type MessageType `json:"type"` - StreamID string `json:"stream_id"` - Data []byte `json:"data,omitempty"` + Type MessageType `json:"type"` + StreamID string `json:"stream_id"` + SourceAddr string `json:"source_addr"` + Data []byte `json:"data,omitempty"` }