From f5741ef60b4d1b8362af7d863f97d28b53e97f11 Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Sun, 21 Sep 2025 13:14:45 -0400 Subject: [PATCH] wip 1 --- ...{writer.go => raw_http_response_writer.go} | 12 +- server/reconstructed_conn.go | 30 ++++ server/server.go | 151 +++--------------- tunnel/tunnel.go | 107 +++++++++++++ 4 files changed, 169 insertions(+), 131 deletions(-) rename server/{writer.go => raw_http_response_writer.go} (66%) create mode 100644 server/reconstructed_conn.go create mode 100644 tunnel/tunnel.go diff --git a/server/writer.go b/server/raw_http_response_writer.go similarity index 66% rename from server/writer.go rename to server/raw_http_response_writer.go index 601cc40..79babcc 100644 --- a/server/writer.go +++ b/server/raw_http_response_writer.go @@ -7,25 +7,25 @@ import ( "net/http" ) -var _ http.ResponseWriter = (*connResponseWriter)(nil) +var _ http.ResponseWriter = (*rawHTTPResponseWriter)(nil) -type connResponseWriter struct { +type rawHTTPResponseWriter struct { conn net.Conn header http.Header } -func (f *connResponseWriter) Header() http.Header { +func (f *rawHTTPResponseWriter) Header() http.Header { if f.header == nil { f.header = make(http.Header) } return f.header } -func (f *connResponseWriter) Write(data []byte) (int, error) { +func (f *rawHTTPResponseWriter) Write(data []byte) (int, error) { return f.conn.Write(data) } -func (f *connResponseWriter) WriteHeader(statusCode int) { +func (f *rawHTTPResponseWriter) WriteHeader(statusCode int) { // Write Status status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) _, _ = f.conn.Write([]byte(status)) @@ -41,7 +41,7 @@ func (f *connResponseWriter) WriteHeader(statusCode int) { _, _ = f.conn.Write([]byte("\r\n")) } -func (f *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (f *rawHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Return Raw Connection & ReadWriter rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) return f.conn, rw, nil diff --git a/server/reconstructed_conn.go b/server/reconstructed_conn.go new file mode 100644 index 0000000..5993c69 --- /dev/null +++ b/server/reconstructed_conn.go @@ -0,0 +1,30 @@ +package server + +import ( + "bytes" + "io" + "net" +) + +var _ io.ReadWriteCloser = (*reconstructedConn)(nil) + +// reconstructedConn wraps a net.Conn and overrides Read to handle captured data. +type reconstructedConn struct { + net.Conn + reader io.Reader +} + +// Read reads from the reconstructed reader (captured data + original conn). +func (rc *reconstructedConn) Read(p []byte) (n int, err error) { + return rc.reader.Read(p) +} + +// newReconstructedConn creates a reconstructed connection that replays captured data +// before reading from the original connection. +func newReconstructedConn(conn net.Conn, capturedData *bytes.Buffer) net.Conn { + allReader := io.MultiReader(capturedData, conn) + return &reconstructedConn{ + Conn: conn, + reader: allReader, + } +} diff --git a/server/server.go b/server/server.go index 197ef4a..6f465b8 100644 --- a/server/server.go +++ b/server/server.go @@ -17,7 +17,7 @@ import ( "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "reichard.io/conduit/config" - "reichard.io/conduit/types" + "reichard.io/conduit/tunnel" ) type InfoResponse struct { @@ -30,19 +30,13 @@ type TunnelInfo struct { Target string `json:"target"` } -type TunnelConnection struct { - *websocket.Conn - name string - streams map[string]chan []byte -} - type Server struct { host string cfg *config.ServerConfig mu sync.RWMutex upgrader websocket.Upgrader - tunnels map[string]*TunnelConnection + tunnels map[string]*tunnel.Tunnel } func NewServer(cfg *config.ServerConfig) (*Server, error) { @@ -56,7 +50,7 @@ func NewServer(cfg *config.ServerConfig) (*Server, error) { return &Server{ cfg: cfg, host: serverURL.Host, - tunnels: make(map[string]*TunnelConnection), + tunnels: make(map[string]*tunnel.Tunnel), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -94,7 +88,7 @@ func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) { for t, c := range s.tunnels { allTunnels = append(allTunnels, TunnelInfo{ Name: t, - Target: c.RemoteAddr().String(), + Target: c.Source(), }) } s.mu.RUnlock() @@ -114,63 +108,6 @@ func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write(d) } -func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, dataReader io.Reader) { - defer clientConn.Close() - - // Create Identifiers - streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) - responseChan := make(chan []byte, 100) - - // Register Stream - s.mu.Lock() - if tunnelConn.streams == nil { - tunnelConn.streams = make(map[string]chan []byte) - } - tunnelConn.streams[streamID] = responseChan - s.mu.Unlock() - - // Clean Up - defer func() { - s.mu.Lock() - delete(tunnelConn.streams, streamID) - close(responseChan) - s.mu.Unlock() - - // Send Close - closeMsg := types.Message{ - Type: types.MessageTypeClose, - StreamID: streamID, - } - _ = tunnelConn.WriteJSON(closeMsg) - }() - - // Read & Send Chunks - go func() { - buffer := make([]byte, 4096) - for { - n, err := dataReader.Read(buffer) - if err != nil { - return - } - - if err := tunnelConn.WriteJSON(types.Message{ - Type: types.MessageTypeData, - StreamID: streamID, - Data: buffer[:n], - }); err != nil { - return - } - } - }() - - // Return Response Data - for data := range responseChan { - if _, err := clientConn.Write(data); err != nil { - break - } - } -} - func (s *Server) handleRawConnection(conn net.Conn) { defer conn.Close() @@ -183,7 +120,7 @@ func (s *Server) handleRawConnection(conn net.Conn) { bufReader := bufio.NewReader(teeReader) // Create HTTP Request & Writer - w := &connResponseWriter{conn: conn} + w := &rawHTTPResponseWriter{conn: conn} r, err := http.ReadRequest(bufReader) if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -216,13 +153,17 @@ func (s *Server) handleRawConnection(conn net.Conn) { s.mu.RLock() tunnelConn, exists := s.tunnels[subdomain] s.mu.RUnlock() - if exists { - log.Infof("relaying %s to tunnel", subdomain) - - // Reconstruct Data & Proxy Connection - allReader := io.MultiReader(&capturedData, r.Body) - s.proxyRawConnection(conn, tunnelConn, allReader) + if !exists { + w.WriteHeader(http.StatusNotFound) + _, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain) + return } + + // Initialize New Stream + log.Infof("relaying %s to tunnel", subdomain) + reconstructedConn := newReconstructedConn(conn, &capturedData) + streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) + tunnelConn.NewStream(streamID, reconstructedConn) } func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { @@ -245,40 +186,6 @@ func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { - for { - var msg types.Message - err := tunnel.ReadJSON(&msg) - if err != nil { - return - } - - if msg.StreamID == "" { - log.Infof("tunnel %s missing streamID", tunnel.name) - continue - } - - switch msg.Type { - case types.MessageTypeClose: - return - case types.MessageTypeData: - s.mu.RLock() - streamChan, exists := tunnel.streams[msg.StreamID] - if !exists { - log.Infof("stream %s does not exist", msg.StreamID) - s.mu.RUnlock() - continue - } - - select { - case streamChan <- msg.Data: - case <-time.After(time.Second): - log.Warnf("stream %s channel full, dropping data", msg.StreamID) - } - s.mu.RUnlock() - } - } -} func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { // Get Tunnel Name tunnelName := r.URL.Query().Get("tunnelName") @@ -302,26 +209,20 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { return } - // Create & Cache TunnelConnection - tunnel := &TunnelConnection{ - Conn: wsConn, - name: tunnelName, - streams: make(map[string]chan []byte), - } + // Create Tunnel + conduitTunnel := tunnel.NewTunnel(tunnelName, wsConn) s.mu.Lock() - s.tunnels[tunnelName] = tunnel + s.tunnels[tunnelName] = conduitTunnel s.mu.Unlock() log.Infof("tunnel established: %s", tunnelName) - // Keep connection alive and handle cleanup - defer func() { - s.mu.Lock() - delete(s.tunnels, tunnelName) - s.mu.Unlock() - _ = wsConn.Close() - log.Infof("tunnel closed: %s", tunnelName) - }() + // Start Tunnel - This is blocking + conduitTunnel.Start() - // Handle tunnel messages - s.handleTunnelMessages(tunnel) + // Cleanup Tunnel + s.mu.Lock() + delete(s.tunnels, tunnelName) + s.mu.Unlock() + _ = wsConn.Close() + log.Infof("tunnel closed: %s", tunnelName) } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go new file mode 100644 index 0000000..e2eeebb --- /dev/null +++ b/tunnel/tunnel.go @@ -0,0 +1,107 @@ +package tunnel + +import ( + "io" + "sync" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" + "reichard.io/conduit/types" +) + +func NewTunnel(name string, wsConn *websocket.Conn) *Tunnel { + return &Tunnel{ + name: name, + wsConn: wsConn, + streams: make(map[string]io.ReadWriteCloser), + } +} + +type Tunnel struct { + name string + wsConn *websocket.Conn + streams map[string]io.ReadWriteCloser + + wsMu, streamsMu sync.Mutex +} + +// Start starts the tunnel and is the primary loop that handles all websocket messages. +// Messages are relayed to the local stream. +func (t *Tunnel) Start() { + for { + var msg types.Message + err := t.wsConn.ReadJSON(&msg) + if err != nil { + return + } + + if msg.StreamID == "" { + log.Warnf("tunnel %s missing streamID", t.name) + continue + } + + switch msg.Type { + case types.MessageTypeClose: + t.CloseStream(msg.StreamID) + case types.MessageTypeData: + t.WriteStream(msg.StreamID, msg.Data) + } + } +} + +func (t *Tunnel) NewStream(streamID string, localConn io.ReadWriteCloser) { + t.streamsMu.Lock() + t.streams[streamID] = localConn + t.streamsMu.Unlock() + + defer t.CloseStream(streamID) + buffer := make([]byte, 4096) + for { + n, err := localConn.Read(buffer) + if err != nil { + return + } + + if err := t.sendWS(&types.Message{ + Type: types.MessageTypeData, + Data: buffer[:n], + StreamID: streamID, + }); err != nil { + return + } + } +} + +func (t *Tunnel) WriteStream(streamID string, data []byte) { + t.streamsMu.Lock() + defer t.streamsMu.Unlock() + if localConn, ok := t.streams[streamID]; ok { + _, _ = localConn.Write(data) + } else { + log.Infof("stream %s does not exist", streamID) + } +} + +func (t *Tunnel) CloseStream(streamID string) { + _ = t.sendWS(&types.Message{ + Type: types.MessageTypeClose, + StreamID: streamID, + }) + + t.streamsMu.Lock() + defer t.streamsMu.Unlock() + if localConn, ok := t.streams[streamID]; ok { + delete(t.streams, streamID) + _ = localConn.Close() + } +} + +func (t *Tunnel) Source() string { + return t.wsConn.RemoteAddr().String() +} + +func (t *Tunnel) sendWS(msg *types.Message) error { + t.wsMu.Lock() + defer t.wsMu.Unlock() + return t.wsConn.WriteJSON(msg) +}