diff --git a/cmd/link.go b/cmd/link.go index 5fd622d..e3f75c2 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -40,11 +40,6 @@ func init() { linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "http://localhost:8080", "Conduit server address") } -type TunnelConfig struct { - // The conduit server address, e.g. https://conduit.example.com - ServerAddress string `default:"http://localhost:8080"` -} - func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) { // Parse Server URL serverURL, err := url.Parse(serverAddress) @@ -73,6 +68,7 @@ func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) return &Tunnel{ name: tunnelName, target: tunnelTarget, + serverURL: serverURL, serverConn: serverConn, localConns: make(map[string]net.Conn), }, nil @@ -80,8 +76,9 @@ func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) } type Tunnel struct { - name string - target string + name string + target string + serverURL *url.URL serverConn *websocket.Conn localConns map[string]net.Conn @@ -89,7 +86,7 @@ type Tunnel struct { } func (t *Tunnel) Start() error { - log.Infof("TCP Tunnel active! %s.example.com -> %s\n", t.name, t.target) + log.Infof("TCP Tunnel active! %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target) defer t.serverConn.Close() // Handle Messages @@ -156,17 +153,17 @@ func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) { return localConn, nil } -func (t *Tunnel) startResponseRelay(streamID string, lConn net.Conn) { +func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) { defer func() { t.mu.Lock() delete(t.localConns, streamID) t.mu.Unlock() - lConn.Close() + localConn.Close() }() buffer := make([]byte, 4096) for { - n, err := lConn.Read(buffer) + n, err := localConn.Read(buffer) if err != nil { break } diff --git a/flake.lock b/flake.lock index cbf5f41..fa2eda9 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1754292888, - "narHash": "sha256-1ziydHSiDuSnaiPzCQh1mRFBsM2d2yRX9I+5OPGEmIE=", + "lastModified": 1758216857, + "narHash": "sha256-h1BW2y7CY4LI9w61R02wPaOYfmYo82FyRqHIwukQ6SY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "ce01daebf8489ba97bd1609d185ea276efdeb121", + "rev": "d2ed99647a4b195f0bcc440f76edfa10aeb3b743", "type": "github" }, "original": { diff --git a/server/server.go b/server/server.go index d066646..9039ec9 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "fmt" + "io" "net" "net/http" "strings" @@ -23,8 +24,8 @@ type TunnelConnection struct { type Server struct { tunnels map[string]*TunnelConnection - mu sync.RWMutex upgrader websocket.Upgrader + mu sync.RWMutex } func NewServer() *Server { @@ -59,11 +60,21 @@ func (s *Server) Start(addr string) error { } } -func (s *Server) extractSubdomain(host string) string { +func (s *Server) extractSubdomain(peakReader io.Reader) string { + // Read Request + req, err := http.ReadRequest(bufio.NewReader(peakReader)) + if err != nil { + return "" + } + defer req.Body.Close() + + // Extract Host + host := req.Host if idx := strings.Index(host, ":"); idx != -1 { host = host[:idx] } + // Extract Subdomain parts := strings.Split(host, ".") if len(parts) >= 1 { return parts[0] @@ -72,57 +83,26 @@ func (s *Server) extractSubdomain(host string) string { return "" } -func (s *Server) getStatus(conn net.Conn) { +func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) { s.mu.RLock() count := len(s.tunnels) s.mu.RUnlock() - response := fmt.Sprintf( - "HTTP/1.1 200 OK\r\n"+ - "Content-Type: application/json\r\n"+ - "Content-Length: %d\r\n\r\n"+ - `{"tunnels": %d}`, - len(fmt.Sprintf(`{"tunnels": %d}`, count)), count) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) - conn.Write([]byte(response)) + response := fmt.Sprintf(`{"tunnels": %d}`, count) + _, _ = w.Write([]byte(response)) } -func (s *Server) extractHostFromHTTP(data []byte) string { - // Simple HTTP header parsing - lines := strings.Split(string(data), "\r\n") - for _, line := range lines { - if strings.HasPrefix(strings.ToLower(line), "host:") { - parts := strings.SplitN(line, ":", 2) - if len(parts) == 2 { - return strings.TrimSpace(parts[1]) - } - } - } - return "" -} - -func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, initialData []byte) { +func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, dataReader io.Reader) { defer clientConn.Close() - // Generate a unique stream ID for this connection + // Create Identifiers streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) - - // Send initial data with stream ID - msg := types.Message{ - Type: "data", - StreamID: streamID, - Data: initialData, - } - - if err := tunnelConn.WriteJSON(msg); err != nil { - log.Errorf("Error sending initial data: %v", err) - return - } - - // Create a channel for this stream's responses responseChan := make(chan []byte, 100) - // Register this stream + // Register Stream s.mu.Lock() if tunnelConn.streams == nil { tunnelConn.streams = make(map[string]chan []byte) @@ -130,43 +110,41 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne tunnelConn.streams[streamID] = responseChan s.mu.Unlock() - // Clean up when done + // Clean Up defer func() { s.mu.Lock() delete(tunnelConn.streams, streamID) close(responseChan) s.mu.Unlock() - // Send close message + // Send Close closeMsg := types.Message{ - Type: "close", + Type: types.MessageTypeClose, StreamID: streamID, } - tunnelConn.WriteJSON(closeMsg) + _ = tunnelConn.WriteJSON(closeMsg) }() - // Handle client -> tunnel + // Read & Send Chunks go func() { buffer := make([]byte, 4096) for { - n, err := clientConn.Read(buffer) + n, err := dataReader.Read(buffer) if err != nil { return } - msg := types.Message{ - Type: "data", + if err := tunnelConn.WriteJSON(types.Message{ + Type: types.MessageTypeData, StreamID: streamID, Data: buffer[:n], - } - - if err := tunnelConn.WriteJSON(msg); err != nil { + }); err != nil { return } } }() - // Return Client Response Data + // Return Response Data for data := range responseChan { if _, err := clientConn.Write(data); err != nil { break @@ -174,54 +152,59 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne } } +// peakData limits how much we read as we only need to determine +// the host to figure out whether we should proxy or not. +func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) { + peek := make([]byte, 8192) + n, err := conn.Read(peek) + if err != nil { + return nil, nil, err + } + + peekedData := peek[:n] + combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn) + + return bytes.NewReader(peekedData), combinedReader, nil +} + func (s *Server) handleRawConnection(conn net.Conn) { defer conn.Close() - // Read enough to get the Host header - buffer := make([]byte, 4096) - n, err := conn.Read(buffer) - if err != nil { - return - } - - // Extract host - host := s.extractHostFromHTTP(buffer[:n]) - subdomain := s.extractSubdomain(host) - - // If we have a registered tunnel for this subdomain, relay it - if subdomain != "" { + peakReader, allReader, _ := s.peekData(conn) + if subdomain := s.extractSubdomain(peakReader); subdomain != "" { s.mu.RLock() tunnelConn, exists := s.tunnels[subdomain] s.mu.RUnlock() if exists { log.Infof("Relaying %s to tunnel", subdomain) - s.proxyRawConnection(conn, tunnelConn, buffer[:n]) + + s.proxyRawConnection(conn, tunnelConn, allReader) return } } // Otherwise, handle as control server (recreate HTTP request and use net/http) - s.handleAsHTTP(conn, buffer[:n]) + s.handleAsHTTP(conn, allReader) } -func (s *Server) handleAsHTTP(conn net.Conn, initialData []byte) { - // Create a fake HTTP request from the raw data - reader := bufio.NewReader(bytes.NewReader(initialData)) - req, err := http.ReadRequest(reader) +func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) { + // Create HTTP Request & Writer + r, err := http.ReadRequest(bufio.NewReader(allReader)) if err != nil { _, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) return } + w := &connResponseWriter{conn: conn} // Handle Control Endpoints - switch req.URL.Path { + switch r.URL.Path { case "/_conduit/tunnel": - s.createTunnel(conn, req) + s.createTunnel(w, r) case "/_conduit/status": - s.getStatus(conn) + s.getStatus(w, r) default: - _, _ = conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n")) + w.WriteHeader(http.StatusNotFound) } } @@ -230,39 +213,53 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { var msg types.Message err := tunnel.ReadJSON(&msg) if err != nil { - break + return } - // Route message to appropriate stream - if msg.Type == "data" && msg.StreamID != "" { + if msg.StreamID == "" { + log.Infof("Tunnel %s missing streamID", tunnel.name) + continue + } + + switch msg.Type { + case types.MessageTypeClose: + return + case types.MessageTypeData: s.mu.RLock() - if streamChan, exists := tunnel.streams[msg.StreamID]; exists { - select { - case streamChan <- msg.Data: - case <-time.After(time.Second): - log.Infof("Stream %s channel full, dropping data", msg.StreamID) - } + streamChan, exists := tunnel.streams[msg.StreamID] + if !exists { + log.Infof("Stream %s does not exist", msg.StreamID) + s.mu.RUnlock() + continue + } + + select { + case streamChan <- msg.Data: + case <-time.After(time.Second): + log.Infof("Stream %s channel full, dropping data", msg.StreamID) } s.mu.RUnlock() } } } -func (s *Server) createTunnel(conn net.Conn, req *http.Request) { +func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { // Get Tunnel Name - tunnelName := req.URL.Query().Get("tunnelName") + tunnelName := r.URL.Query().Get("tunnelName") if tunnelName == "" { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing tunnelName parameter")) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("Missing tunnelName parameter")) return } // Validate Unique if _, exists := s.tunnels[tunnelName]; exists { - conn.Write([]byte("HTTP/1.1 409 Conflict\r\n\r\nTunnel already registered")) + w.WriteHeader(http.StatusConflict) + _, _ = w.Write([]byte("Tunnel already registered")) return } // Upgrade Connection - wsConn, err := s.upgrader.Upgrade(&rawResponseWriter{conn: conn}, req, nil) + wsConn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { log.Errorf("WebSocket upgrade failed: %v", err) return @@ -284,48 +281,10 @@ func (s *Server) createTunnel(conn net.Conn, req *http.Request) { s.mu.Lock() delete(s.tunnels, tunnelName) s.mu.Unlock() - wsConn.Close() + _ = wsConn.Close() log.Infof("Tunnel closed: %s", tunnelName) }() // Handle tunnel messages s.handleTunnelMessages(tunnel) } - -type rawResponseWriter struct { - conn net.Conn - header http.Header -} - -func (f *rawResponseWriter) Header() http.Header { - if f.header == nil { - f.header = make(http.Header) - } - return f.header -} - -func (f *rawResponseWriter) Write(data []byte) (int, error) { - return f.conn.Write(data) -} - -func (f *rawResponseWriter) WriteHeader(statusCode int) { - // Write Status - status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) - _, _ = f.conn.Write([]byte(status)) - - // Write Headers - for key, values := range f.header { - for _, value := range values { - _, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value) - } - } - - // End Headers - _, _ = f.conn.Write([]byte("\r\n")) -} - -func (f *rawResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - // Return Raw Connection & ReadWriter - rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) - return f.conn, rw, nil -} diff --git a/server/writer.go b/server/writer.go new file mode 100644 index 0000000..601cc40 --- /dev/null +++ b/server/writer.go @@ -0,0 +1,48 @@ +package server + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +var _ http.ResponseWriter = (*connResponseWriter)(nil) + +type connResponseWriter struct { + conn net.Conn + header http.Header +} + +func (f *connResponseWriter) Header() http.Header { + if f.header == nil { + f.header = make(http.Header) + } + return f.header +} + +func (f *connResponseWriter) Write(data []byte) (int, error) { + return f.conn.Write(data) +} + +func (f *connResponseWriter) WriteHeader(statusCode int) { + // Write Status + status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) + _, _ = f.conn.Write([]byte(status)) + + // Write Headers + for key, values := range f.header { + for _, value := range values { + _, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value) + } + } + + // End Headers + _, _ = f.conn.Write([]byte("\r\n")) +} + +func (f *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Return Raw Connection & ReadWriter + rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) + return f.conn, rw, nil +}