diff --git a/.golangci.toml b/.golangci.toml new file mode 100644 index 0000000..bc546a4 --- /dev/null +++ b/.golangci.toml @@ -0,0 +1,6 @@ +#:schema https://golangci-lint.run/jsonschema/golangci.jsonschema.json +version = "2" + +[[linters.exclusions.rules]] +linters = [ "errcheck" ] +source = "^\\s*defer\\s+" diff --git a/cmd/link.go b/cmd/link.go index 7d547e4..bd00dbb 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -1,44 +1,34 @@ package cmd import ( - "bytes" "fmt" - "io" "log" - "net/http" + "net" + "sync" "github.com/gorilla/websocket" "github.com/spf13/cobra" ) -type TunnelRequest struct { - ID string `json:"id"` - Method string `json:"method"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` - Body []byte `json:"body"` -} - -type TunnelResponse struct { - ID string `json:"id"` - StatusCode int `json:"status_code"` - Headers map[string]string `json:"headers"` - Body []byte `json:"body"` +type TunnelMessage struct { + Type string `json:"type"` + StreamID string `json:"stream_id"` + Data []byte `json:"data,omitempty"` } var serverAddr string var linkCmd = &cobra.Command{ - Use: "link ", + Use: "link ", Short: "Create a tunnel link", Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { vhostLoc := args[0] - localPort := args[1] + hostPort := args[1] - fmt.Printf("Creating tunnel: %s -> localhost:%s\n", vhostLoc, localPort) + fmt.Printf("Creating TCP tunnel: %s -> %s\n", vhostLoc, hostPort) - if err := startTunnel(vhostLoc, localPort); err != nil { + if err := startTCPTunnel(vhostLoc, hostPort); err != nil { log.Fatal("Failed to start tunnel:", err) } }, @@ -48,8 +38,7 @@ func init() { linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "localhost:8080", "Conduit server address") } -func startTunnel(vhost, localPort string) error { - // Connect to WebSocket +func startTCPTunnel(vhost, hostPort string) error { wsURL := fmt.Sprintf("ws://%s/_conduit/tunnel?vhost=%s", serverAddr, vhost) conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { @@ -57,80 +46,86 @@ func startTunnel(vhost, localPort string) error { } defer conn.Close() - fmt.Printf("Tunnel active! %s.example.com -> localhost:%s\n", vhost, localPort) + fmt.Printf("TCP Tunnel active! %s.example.com -> %s\n", vhost, hostPort) - // Handle incoming requests + // Track active connections + connections := make(map[string]net.Conn) + var connMutex sync.RWMutex + + // Handle messages from server for { - var req TunnelRequest - if err := conn.ReadJSON(&req); err != nil { - log.Printf("Error reading request: %v", err) + var msg TunnelMessage + err := conn.ReadJSON(&msg) + if err != nil { + log.Printf("Error reading from tunnel: %v", err) break } - go handleTunnelRequest(conn, &req, localPort) + switch msg.Type { + case "data": + connMutex.RLock() + localConn, exists := connections[msg.StreamID] + connMutex.RUnlock() + + if !exists { + // New connection + localConn, err = net.Dial("tcp", hostPort) + if err != nil { + log.Printf("Failed to connect to %s: %v", hostPort, err) + continue + } + + connMutex.Lock() + connections[msg.StreamID] = localConn + connMutex.Unlock() + + // Start reading from local connection + go func(streamID string, lConn net.Conn) { + defer func() { + connMutex.Lock() + delete(connections, streamID) + connMutex.Unlock() + lConn.Close() + }() + + buffer := make([]byte, 4096) + for { + n, err := lConn.Read(buffer) + if err != nil { + break + } + + response := TunnelMessage{ + Type: "data", + StreamID: streamID, + Data: buffer[:n], + } + + if err := conn.WriteJSON(response); err != nil { + break + } + } + }(msg.StreamID, localConn) + } + + // Write data to local connection + if _, err := localConn.Write(msg.Data); err != nil { + log.Printf("Error writing to local connection: %v", err) + localConn.Close() + connMutex.Lock() + delete(connections, msg.StreamID) + connMutex.Unlock() + } + + case "close": + connMutex.Lock() + if localConn, exists := connections[msg.StreamID]; exists { + localConn.Close() + delete(connections, msg.StreamID) + } + connMutex.Unlock() + } } return nil } - -func handleTunnelRequest(conn *websocket.Conn, req *TunnelRequest, localPort string) { - // Make request to local service - localURL := fmt.Sprintf("http://localhost:%s%s", localPort, req.URL) - - httpReq, err := http.NewRequest(req.Method, localURL, bytes.NewReader(req.Body)) - if err != nil { - sendErrorResponse(conn, req.ID, 500, "Failed to create request") - return - } - - // Set headers - for k, v := range req.Headers { - httpReq.Header.Set(k, v) - } - - // Make the request - client := &http.Client{} - resp, err := client.Do(httpReq) - if err != nil { - sendErrorResponse(conn, req.ID, 502, "Failed to reach local service") - return - } - defer resp.Body.Close() - - // Read response body - body, err := io.ReadAll(resp.Body) - if err != nil { - sendErrorResponse(conn, req.ID, 500, "Failed to read response") - return - } - - // Convert response headers - headers := make(map[string]string) - for k, v := range resp.Header { - if len(v) > 0 { - headers[k] = v[0] - } - } - - // Send response back - tunnelResp := &TunnelResponse{ - ID: req.ID, - StatusCode: resp.StatusCode, - Headers: headers, - Body: body, - } - - if err := conn.WriteJSON(tunnelResp); err != nil { - log.Printf("Error sending response: %v", err) - } -} - -func sendErrorResponse(conn *websocket.Conn, reqID string, statusCode int, message string) { - resp := &TunnelResponse{ - ID: reqID, - StatusCode: statusCode, - Headers: map[string]string{"Content-Type": "text/plain"}, - Body: []byte(message), - } - conn.WriteJSON(resp) -} diff --git a/server/server.go b/server/server.go index 89856ab..8dfa19b 100644 --- a/server/server.go +++ b/server/server.go @@ -1,9 +1,11 @@ package server import ( + "bufio" + "bytes" "fmt" - "io" "log" + "net" "net/http" "strings" "sync" @@ -13,25 +15,15 @@ import ( ) type TunnelConnection struct { - conn *websocket.Conn - vhost string - responses map[string]chan *TunnelResponse - mu sync.RWMutex + *websocket.Conn + vhost string + streams map[string]chan []byte // StreamID -> data channel } -type TunnelRequest struct { - ID string `json:"id"` - Method string `json:"method"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` - Body []byte `json:"body"` -} - -type TunnelResponse struct { - ID string `json:"id"` - StatusCode int `json:"status_code"` - Headers map[string]string `json:"headers"` - Body []byte `json:"body"` +type TunnelMessage struct { + Type string `json:"type"` + StreamID string `json:"stream_id"` + Data []byte `json:"data,omitempty"` } type Server struct { @@ -52,177 +44,26 @@ func NewServer() *Server { } func (s *Server) Start(addr string) error { - http.HandleFunc("/", s.handleRequest) - http.HandleFunc("/_conduit/tunnel", s.handleTunnel) - http.HandleFunc("/_conduit/status", s.handleStatus) + // Raw TCP listener instead of http.ListenAndServe + listener, err := net.Listen("tcp", addr) + if err != nil { + return err + } + defer listener.Close() log.Printf("Conduit server listening on %s", addr) - return http.ListenAndServe(addr, nil) -} -func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { - host := r.Host - subdomain := s.extractSubdomain(host) - - if subdomain == "" { - http.Error(w, "Invalid host", http.StatusBadRequest) - return - } - - s.mu.RLock() - tunnel, exists := s.tunnels[subdomain] - s.mu.RUnlock() - - if !exists { - http.Error(w, "Tunnel not found", http.StatusNotFound) - return - } - - // Create unique request ID - reqID := fmt.Sprintf("%d", time.Now().UnixNano()) - - // Read request body - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // Convert headers - headers := make(map[string]string) - for k, v := range r.Header { - if len(v) > 0 { - headers[k] = v[0] - } - } - - // Create response channel for this request - respChan := make(chan *TunnelResponse, 1) - tunnel.mu.Lock() - tunnel.responses[reqID] = respChan - tunnel.mu.Unlock() - - // Clean up response channel when done - defer func() { - tunnel.mu.Lock() - delete(tunnel.responses, reqID) - tunnel.mu.Unlock() - }() - - tunnelReq := &TunnelRequest{ - ID: reqID, - Method: r.Method, - URL: r.URL.String(), - Headers: headers, - Body: body, - } - - // Send request to tunnel - if err := tunnel.conn.WriteJSON(tunnelReq); err != nil { - log.Printf("Error sending request to tunnel: %v", err) - http.Error(w, "Tunnel communication error", http.StatusServiceUnavailable) - return - } - - // Wait for response - select { - case resp := <-respChan: - // Write response headers - for k, v := range resp.Headers { - w.Header().Set(k, v) - } - w.WriteHeader(resp.StatusCode) - w.Write(resp.Body) - - case <-time.After(30 * time.Second): - http.Error(w, "Tunnel timeout", http.StatusGatewayTimeout) - } -} - -func (s *Server) handleTunnel(w http.ResponseWriter, r *http.Request) { - vhost := r.URL.Query().Get("vhost") - if vhost == "" { - http.Error(w, "Missing vhost parameter", http.StatusBadRequest) - return - } - - conn, err := s.upgrader.Upgrade(w, r, nil) - if err != nil { - log.Printf("WebSocket upgrade failed: %v", err) - return - } - - tunnel := &TunnelConnection{ - conn: conn, - vhost: vhost, - responses: make(map[string]chan *TunnelResponse), - } - - s.mu.Lock() - s.tunnels[vhost] = tunnel - s.mu.Unlock() - - log.Printf("Tunnel established: %s", vhost) - - // Handle tunnel communication - s.handleTunnelConnection(tunnel) -} - -func (s *Server) handleTunnelConnection(tunnel *TunnelConnection) { - defer func() { - s.mu.Lock() - delete(s.tunnels, tunnel.vhost) - s.mu.Unlock() - tunnel.conn.Close() - log.Printf("Tunnel closed: %s", tunnel.vhost) - }() - - // Handle incoming responses from client for { - var resp TunnelResponse - if err := tunnel.conn.ReadJSON(&resp); err != nil { - log.Printf("Error reading response from tunnel %s: %v", tunnel.vhost, err) - return + conn, err := listener.Accept() + if err != nil { + log.Printf("Error accepting connection: %v", err) + continue } - // Find the response channel for this request - tunnel.mu.RLock() - respChan, exists := tunnel.responses[resp.ID] - tunnel.mu.RUnlock() - - if exists { - select { - case respChan <- &resp: - // Response delivered - default: - log.Printf("Response channel full for request %s", resp.ID) - } - } else { - log.Printf("No response channel found for request %s", resp.ID) - } + go s.handleRawConnection(conn) } } -func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { - s.mu.RLock() - defer s.mu.RUnlock() - - w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"tunnels": %d, "active": [`, len(s.tunnels)) - - first := true - for vhost := range s.tunnels { - if !first { - fmt.Fprint(w, ",") - } - fmt.Fprintf(w, `{"vhost":"%s"}`, vhost) - first = false - } - - fmt.Fprint(w, "]}") -} - func (s *Server) extractSubdomain(host string) string { if idx := strings.Index(host, ":"); idx != -1 { host = host[:idx] @@ -235,3 +76,259 @@ func (s *Server) extractSubdomain(host string) string { return "" } + +func (s *Server) handleStatus(conn net.Conn) { + s.mu.RLock() + count := len(s.tunnels) + s.mu.RUnlock() + + response := fmt.Sprintf( + "HTTP/1.1 200 OK\r\n"+ + "Content-Type: application/json\r\n"+ + "Content-Length: %d\r\n\r\n"+ + `{"tunnels": %d}`, + len(fmt.Sprintf(`{"tunnels": %d}`, count)), count) + + conn.Write([]byte(response)) +} + +func (s *Server) extractHostFromHTTP(data []byte) string { + // Simple HTTP header parsing + lines := strings.Split(string(data), "\r\n") + for _, line := range lines { + if strings.HasPrefix(strings.ToLower(line), "host:") { + parts := strings.SplitN(line, ":", 2) + if len(parts) == 2 { + return strings.TrimSpace(parts[1]) + } + } + } + return "" +} + +func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, initialData []byte) { + defer clientConn.Close() + + // Generate a unique stream ID for this connection + streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) + + // Send initial data with stream ID + msg := TunnelMessage{ + Type: "data", + StreamID: streamID, + Data: initialData, + } + + if err := tunnelConn.WriteJSON(msg); err != nil { + log.Printf("Error sending initial data: %v", err) + return + } + + // Create a channel for this stream's responses + responseChan := make(chan []byte, 100) + + // Register this stream + s.mu.Lock() + if tunnelConn.streams == nil { + tunnelConn.streams = make(map[string]chan []byte) + } + tunnelConn.streams[streamID] = responseChan + s.mu.Unlock() + + // Clean up when done + defer func() { + s.mu.Lock() + delete(tunnelConn.streams, streamID) + close(responseChan) + s.mu.Unlock() + + // Send close message + closeMsg := TunnelMessage{ + Type: "close", + StreamID: streamID, + } + tunnelConn.WriteJSON(closeMsg) + }() + + // Handle client -> tunnel + go func() { + buffer := make([]byte, 4096) + for { + n, err := clientConn.Read(buffer) + if err != nil { + return + } + + msg := TunnelMessage{ + Type: "data", + StreamID: streamID, + Data: buffer[:n], + } + + if err := tunnelConn.WriteJSON(msg); err != nil { + return + } + } + }() + + // Handle tunnel -> client (read from response channel) + for data := range responseChan { + if _, err := clientConn.Write(data); err != nil { + break + } + } +} + +func (s *Server) handleRawConnection(conn net.Conn) { + defer conn.Close() + + // Read enough to get the Host header + buffer := make([]byte, 4096) + n, err := conn.Read(buffer) + if err != nil { + return + } + + // Extract host + host := s.extractHostFromHTTP(buffer[:n]) + subdomain := s.extractSubdomain(host) + + // If we have a registered tunnel for this subdomain, relay it + if subdomain != "" { + s.mu.RLock() + tunnelConn, exists := s.tunnels[subdomain] + s.mu.RUnlock() + + if exists { + log.Printf("Relaying %s to tunnel", subdomain) + s.proxyRawConnection(conn, tunnelConn, buffer[:n]) + return + } + } + + // Otherwise, handle as control server (recreate HTTP request and use net/http) + s.handleAsHTTP(conn, buffer[:n]) +} + +func (s *Server) handleAsHTTP(conn net.Conn, initialData []byte) { + // Create a fake HTTP request from the raw data + reader := bufio.NewReader(bytes.NewReader(initialData)) + req, err := http.ReadRequest(reader) + if err != nil { + conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) + return + } + + // Handle control endpoints + if req.URL.Path == "/_conduit/tunnel" { + s.handleTunnelUpgrade(conn, req) + } else if req.URL.Path == "/_conduit/status" { + s.handleStatus(conn) + } else { + conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n")) + } +} + +func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { + for { + var msg TunnelMessage + err := tunnel.ReadJSON(&msg) + if err != nil { + break + } + + // Route message to appropriate stream + if msg.Type == "data" && msg.StreamID != "" { + s.mu.RLock() + if streamChan, exists := tunnel.streams[msg.StreamID]; exists { + select { + case streamChan <- msg.Data: + case <-time.After(time.Second): + log.Printf("Stream %s channel full, dropping data", msg.StreamID) + } + } + s.mu.RUnlock() + } + } +} +func (s *Server) handleTunnelUpgrade(conn net.Conn, req *http.Request) { + vhost := req.URL.Query().Get("vhost") + if vhost == "" { + conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing vhost parameter")) + return + } + + // Create a fake ResponseWriter that writes to our raw connection + fakeWriter := &fakeResponseWriter{conn: conn} + + // Use the upgrader + wsConn, err := s.upgrader.Upgrade(fakeWriter, req, nil) + if err != nil { + log.Printf("WebSocket upgrade failed: %v", err) + return + } + + // Create TunnelConnection + tunnel := &TunnelConnection{ + Conn: wsConn, + vhost: vhost, + streams: make(map[string]chan []byte), + } + + s.mu.Lock() + s.tunnels[vhost] = tunnel + s.mu.Unlock() + + log.Printf("Tunnel established: %s", vhost) + + // Keep connection alive and handle cleanup + defer func() { + s.mu.Lock() + delete(s.tunnels, vhost) + s.mu.Unlock() + wsConn.Close() + log.Printf("Tunnel closed: %s", vhost) + }() + + // Handle tunnel messages + s.handleTunnelMessages(tunnel) +} + +type fakeResponseWriter struct { + conn net.Conn + header http.Header +} + +func (f *fakeResponseWriter) Header() http.Header { + if f.header == nil { + f.header = make(http.Header) + } + return f.header +} + +func (f *fakeResponseWriter) Write(data []byte) (int, error) { + return f.conn.Write(data) +} + +func (f *fakeResponseWriter) WriteHeader(statusCode int) { + // Write HTTP status line + status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) + f.conn.Write([]byte(status)) + + // Write headers + for key, values := range f.header { + for _, value := range values { + f.conn.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value))) + } + } + + // End headers + f.conn.Write([]byte("\r\n")) +} + +// Implement http.Hijacker +func (f *fakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Return the raw connection and create a ReadWriter for it + rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) + return f.conn, rw, nil +} diff --git a/server/tunnel.go b/server/tunnel.go new file mode 100644 index 0000000..439dceb --- /dev/null +++ b/server/tunnel.go @@ -0,0 +1,45 @@ +package server + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +type tunnelWriter struct { + conn net.Conn + header http.Header +} + +func (f *tunnelWriter) Header() http.Header { + if f.header == nil { + f.header = make(http.Header) + } + return f.header +} + +func (f *tunnelWriter) Write(data []byte) (int, error) { + return f.conn.Write(data) +} + +func (f *tunnelWriter) WriteHeader(statusCode int) { + // Write HTTP status line + status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) + f.conn.Write([]byte(status)) + + // Write headers + for key, values := range f.header { + for _, value := range values { + f.conn.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value))) + } + } + + // End headers + f.conn.Write([]byte("\r\n")) +} + +func (f *tunnelWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) + return f.conn, rw, nil +}