chore: better source tracking
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Evan Reichard 2025-09-23 09:24:09 -04:00
parent 0333680a2b
commit 20c1388cf4
4 changed files with 39 additions and 23 deletions

View File

@ -41,5 +41,5 @@ func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) {
return nil, fmt.Errorf("failed to connect: %v", err)
}
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn)
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverURL, serverConn)
}

View File

@ -135,24 +135,31 @@ func (s *Server) handleRawConnection(conn net.Conn) {
}
// Extract Subdomain
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
if strings.Count(subdomain, ".") != 0 {
tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
if strings.Count(tunnelName, ".") != 0 {
w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
return
}
// Get True Host
remoteHost := conn.RemoteAddr().String()
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
remoteHost = xff
}
r.RemoteAddr = remoteHost
// Handle Control Endpoints
if subdomain == "" {
if tunnelName == "" {
s.handleAsHTTP(w, r)
return
}
// Handle Tunnels
conduitTunnel, exists := s.tunnels.Get(subdomain)
conduitTunnel, exists := s.tunnels.Get(tunnelName)
if !exists {
w.WriteHeader(http.StatusNotFound)
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain)
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName)
return
}
@ -165,8 +172,8 @@ func (s *Server) handleRawConnection(conn net.Conn) {
return
}
log.Infof("relaying %s to tunnel", subdomain)
_ = conduitTunnel.StartStream(streamID)
log.Infof("tunnel %q connection from %s", tunnelName, r.RemoteAddr)
_ = conduitTunnel.StartStream(streamID, r.RemoteAddr)
}
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
@ -215,7 +222,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Create Tunnel
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
s.tunnels.Set(tunnelName, conduitTunnel)
log.Infof("tunnel established: %s", tunnelName)
log.Infof("tunnel %q created from %s", tunnelName, r.RemoteAddr)
// Start Tunnel - This is blocking
conduitTunnel.Start()
@ -223,5 +230,5 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Cleanup Tunnel
s.tunnels.Delete(tunnelName)
_ = wsConn.Close()
log.Infof("tunnel closed: %s", tunnelName)
log.Infof("tunnel %q closed from %s", tunnelName, r.RemoteAddr)
}

View File

@ -23,22 +23,28 @@ func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
}
}
func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, error) {
func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) {
// Get Target URL
targetURL, err := url.Parse(target)
if err != nil {
return nil, err
}
// Derive Conduit URL
conduitURL := *serverURL
conduitURL.Host = name + "." + conduitURL.Host
// Get Connection Builder
var connBuilder ConnBuilder
switch targetURL.Scheme {
case "http", "https":
log.Infof("creating HTTP tunnel: %s -> %s", name, target)
log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target)
connBuilder, err = HTTPConnectionBuilder(targetURL)
if err != nil {
return nil, err
}
default:
log.Infof("creating TCP tunnel: %s -> %s", name, target)
log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target)
connBuilder = func() (conn io.ReadWriteCloser, err error) {
return net.Dial("tcp", target)
}
@ -109,7 +115,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
return err
}
go t.StartStream(streamID)
go t.StartStream(streamID, "")
return nil
}
@ -121,7 +127,7 @@ func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
return nil
}
func (t *Tunnel) StartStream(streamID string) error {
func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
// Get Stream
conn, found := t.streams.Get(streamID)
if !found {
@ -131,8 +137,9 @@ func (t *Tunnel) StartStream(streamID string) error {
// Close Stream
defer func() {
_ = t.sendWS(&types.Message{
Type: types.MessageTypeClose,
StreamID: streamID,
Type: types.MessageTypeClose,
StreamID: streamID,
SourceAddr: sourceAddr,
})
t.CloseStream(streamID)
@ -147,9 +154,10 @@ func (t *Tunnel) StartStream(streamID string) error {
}
if err := t.sendWS(&types.Message{
Type: types.MessageTypeData,
Data: buffer[:n],
StreamID: streamID,
Type: types.MessageTypeData,
StreamID: streamID,
Data: buffer[:n],
SourceAddr: sourceAddr,
}); err != nil {
return err
}

View File

@ -8,7 +8,8 @@ const (
)
type Message struct {
Type MessageType `json:"type"`
StreamID string `json:"stream_id"`
Data []byte `json:"data,omitempty"`
Type MessageType `json:"type"`
StreamID string `json:"stream_id"`
SourceAddr string `json:"source_addr"`
Data []byte `json:"data,omitempty"`
}