chore: better source tracking
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Evan Reichard 2025-09-23 09:24:09 -04:00
parent 0333680a2b
commit 20c1388cf4
4 changed files with 39 additions and 23 deletions

View File

@ -41,5 +41,5 @@ func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) {
return nil, fmt.Errorf("failed to connect: %v", err) return nil, fmt.Errorf("failed to connect: %v", err)
} }
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn) return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverURL, serverConn)
} }

View File

@ -135,24 +135,31 @@ func (s *Server) handleRawConnection(conn net.Conn) {
} }
// Extract Subdomain // Extract Subdomain
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
if strings.Count(subdomain, ".") != 0 { if strings.Count(tunnelName, ".") != 0 {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host) _, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
return return
} }
// Get True Host
remoteHost := conn.RemoteAddr().String()
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
remoteHost = xff
}
r.RemoteAddr = remoteHost
// Handle Control Endpoints // Handle Control Endpoints
if subdomain == "" { if tunnelName == "" {
s.handleAsHTTP(w, r) s.handleAsHTTP(w, r)
return return
} }
// Handle Tunnels // Handle Tunnels
conduitTunnel, exists := s.tunnels.Get(subdomain) conduitTunnel, exists := s.tunnels.Get(tunnelName)
if !exists { if !exists {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain) _, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName)
return return
} }
@ -165,8 +172,8 @@ func (s *Server) handleRawConnection(conn net.Conn) {
return return
} }
log.Infof("relaying %s to tunnel", subdomain) log.Infof("tunnel %q connection from %s", tunnelName, r.RemoteAddr)
_ = conduitTunnel.StartStream(streamID) _ = conduitTunnel.StartStream(streamID, r.RemoteAddr)
} }
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
@ -215,7 +222,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Create Tunnel // Create Tunnel
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn) conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
s.tunnels.Set(tunnelName, conduitTunnel) s.tunnels.Set(tunnelName, conduitTunnel)
log.Infof("tunnel established: %s", tunnelName) log.Infof("tunnel %q created from %s", tunnelName, r.RemoteAddr)
// Start Tunnel - This is blocking // Start Tunnel - This is blocking
conduitTunnel.Start() conduitTunnel.Start()
@ -223,5 +230,5 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Cleanup Tunnel // Cleanup Tunnel
s.tunnels.Delete(tunnelName) s.tunnels.Delete(tunnelName)
_ = wsConn.Close() _ = wsConn.Close()
log.Infof("tunnel closed: %s", tunnelName) log.Infof("tunnel %q closed from %s", tunnelName, r.RemoteAddr)
} }

View File

@ -23,22 +23,28 @@ func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
} }
} }
func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, error) { func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) {
// Get Target URL
targetURL, err := url.Parse(target) targetURL, err := url.Parse(target)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Derive Conduit URL
conduitURL := *serverURL
conduitURL.Host = name + "." + conduitURL.Host
// Get Connection Builder
var connBuilder ConnBuilder var connBuilder ConnBuilder
switch targetURL.Scheme { switch targetURL.Scheme {
case "http", "https": case "http", "https":
log.Infof("creating HTTP tunnel: %s -> %s", name, target) log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target)
connBuilder, err = HTTPConnectionBuilder(targetURL) connBuilder, err = HTTPConnectionBuilder(targetURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
default: default:
log.Infof("creating TCP tunnel: %s -> %s", name, target) log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target)
connBuilder = func() (conn io.ReadWriteCloser, err error) { connBuilder = func() (conn io.ReadWriteCloser, err error) {
return net.Dial("tcp", target) return net.Dial("tcp", target)
} }
@ -109,7 +115,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
return err return err
} }
go t.StartStream(streamID) go t.StartStream(streamID, "")
return nil return nil
} }
@ -121,7 +127,7 @@ func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
return nil return nil
} }
func (t *Tunnel) StartStream(streamID string) error { func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
// Get Stream // Get Stream
conn, found := t.streams.Get(streamID) conn, found := t.streams.Get(streamID)
if !found { if !found {
@ -131,8 +137,9 @@ func (t *Tunnel) StartStream(streamID string) error {
// Close Stream // Close Stream
defer func() { defer func() {
_ = t.sendWS(&types.Message{ _ = t.sendWS(&types.Message{
Type: types.MessageTypeClose, Type: types.MessageTypeClose,
StreamID: streamID, StreamID: streamID,
SourceAddr: sourceAddr,
}) })
t.CloseStream(streamID) t.CloseStream(streamID)
@ -147,9 +154,10 @@ func (t *Tunnel) StartStream(streamID string) error {
} }
if err := t.sendWS(&types.Message{ if err := t.sendWS(&types.Message{
Type: types.MessageTypeData, Type: types.MessageTypeData,
Data: buffer[:n], StreamID: streamID,
StreamID: streamID, Data: buffer[:n],
SourceAddr: sourceAddr,
}); err != nil { }); err != nil {
return err return err
} }

View File

@ -8,7 +8,8 @@ const (
) )
type Message struct { type Message struct {
Type MessageType `json:"type"` Type MessageType `json:"type"`
StreamID string `json:"stream_id"` StreamID string `json:"stream_id"`
Data []byte `json:"data,omitempty"` SourceAddr string `json:"source_addr"`
Data []byte `json:"data,omitempty"`
} }