diff --git a/cmd/link.go b/cmd/link.go index 64de815..dada57b 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "net" + "net/url" "sync" "github.com/gorilla/websocket" @@ -19,43 +20,88 @@ type TunnelMessage struct { var serverAddr string var linkCmd = &cobra.Command{ - Use: "link ", + Use: "link ", Short: "Create a tunnel link", Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { - vhostLoc := args[0] - hostPort := args[1] + tunnelName := args[0] + tunnelTarget := args[1] - fmt.Printf("Creating TCP tunnel: %s -> %s\n", vhostLoc, hostPort) + // Create Tunnel + tunnel, err := NewTunnel(tunnelName, tunnelTarget, serverAddr) + if err != nil { + log.Fatal("Failed to start tunnel:", err) + } - if err := startTCPTunnel(vhostLoc, hostPort); err != nil { + // Start Tunnel + log.Infof("Creating TCP tunnel: %s -> %s\n", tunnelName, tunnelTarget) + if err := tunnel.Start(); err != nil { log.Fatal("Failed to start tunnel:", err) } }, } func init() { - linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "localhost:8080", "Conduit server address") + linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "http://localhost:8080", "Conduit server address") } -func startTCPTunnel(vhost, hostPort string) error { - wsURL := fmt.Sprintf("ws://%s/_conduit/tunnel?vhost=%s", serverAddr, vhost) - conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) +type TunnelConfig struct { + // The conduit server address, e.g. https://conduit.example.com + ServerAddress string `default:"http://localhost:8080"` +} + +func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) { + // Parse Server URL + serverURL, err := url.Parse(serverAddress) if err != nil { - return fmt.Errorf("failed to connect: %v", err) + return nil, err } - defer conn.Close() - fmt.Printf("TCP Tunnel active! %s.example.com -> %s\n", vhost, hostPort) + // Parse Scheme + var wsScheme string + switch serverURL.Scheme { + case "https": + wsScheme = "wss" + case "http": + wsScheme = "ws" + default: + return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme) + } - // Track active connections - connections := make(map[string]net.Conn) - var connMutex sync.RWMutex + // Connect Server WS + wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?vhost=%s", wsScheme, serverURL.Host, tunnelName) + serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to connect: %v", err) + } - // Handle messages from server + return &Tunnel{ + name: tunnelName, + target: tunnelTarget, + serverConn: serverConn, + localConns: make(map[string]net.Conn), + }, nil + +} + +type Tunnel struct { + name string + target string + + serverConn *websocket.Conn + localConns map[string]net.Conn + mu sync.RWMutex +} + +func (t *Tunnel) Start() error { + log.Infof("TCP Tunnel active! %s.example.com -> %s\n", t.name, t.target) + defer t.serverConn.Close() + + // Handle Messages for { + // Read Message var msg TunnelMessage - err := conn.ReadJSON(&msg) + err := t.serverConn.ReadJSON(&msg) if err != nil { log.Errorf("Error reading from tunnel: %v", err) break @@ -63,69 +109,81 @@ func startTCPTunnel(vhost, hostPort string) error { switch msg.Type { case "data": - connMutex.RLock() - localConn, exists := connections[msg.StreamID] - connMutex.RUnlock() - - if !exists { - // New connection - localConn, err = net.Dial("tcp", hostPort) - if err != nil { - log.Errorf("Failed to connect to %s: %v", hostPort, err) - continue - } - - connMutex.Lock() - connections[msg.StreamID] = localConn - connMutex.Unlock() - - // Start reading from local connection - go func(streamID string, lConn net.Conn) { - defer func() { - connMutex.Lock() - delete(connections, streamID) - connMutex.Unlock() - lConn.Close() - }() - - buffer := make([]byte, 4096) - for { - n, err := lConn.Read(buffer) - if err != nil { - break - } - - response := TunnelMessage{ - Type: "data", - StreamID: streamID, - Data: buffer[:n], - } - - if err := conn.WriteJSON(response); err != nil { - break - } - } - }(msg.StreamID, localConn) + localConn, err := t.getLocalConn(msg.StreamID) + if err != nil { + log.Errorf("Failed to get local connection: %v", err) + continue } // Write data to local connection if _, err := localConn.Write(msg.Data); err != nil { log.Errorf("Error writing to local connection: %v", err) localConn.Close() - connMutex.Lock() - delete(connections, msg.StreamID) - connMutex.Unlock() + t.mu.Lock() + delete(t.localConns, msg.StreamID) + t.mu.Unlock() } case "close": - connMutex.Lock() - if localConn, exists := connections[msg.StreamID]; exists { + t.mu.Lock() + if localConn, exists := t.localConns[msg.StreamID]; exists { localConn.Close() - delete(connections, msg.StreamID) + delete(t.localConns, msg.StreamID) } - connMutex.Unlock() + t.mu.Unlock() } } return nil } + +func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) { + // Get Cached Connection + t.mu.RLock() + localConn, exists := t.localConns[streamID] + t.mu.RUnlock() + if exists { + return localConn, nil + } + + // Initiate Connection & Cache + localConn, err := net.Dial("tcp", t.target) + if err != nil { + log.Errorf("Failed to connect to %s: %v", t.target, err) + return nil, err + } + t.mu.Lock() + t.localConns[streamID] = localConn + t.mu.Unlock() + + // Start Response Relay & Return Connection + go t.startResponseRelay(streamID, localConn) + return localConn, nil +} + +func (t *Tunnel) startResponseRelay(streamID string, lConn net.Conn) { + defer func() { + t.mu.Lock() + delete(t.localConns, streamID) + t.mu.Unlock() + lConn.Close() + }() + + buffer := make([]byte, 4096) + for { + n, err := lConn.Read(buffer) + if err != nil { + break + } + + response := TunnelMessage{ + Type: "data", + StreamID: streamID, + Data: buffer[:n], + } + + if err := t.serverConn.WriteJSON(response); err != nil { + break + } + } +} diff --git a/server/server.go b/server/server.go index 698bc3d..620f637 100644 --- a/server/server.go +++ b/server/server.go @@ -171,7 +171,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne } }() - // Handle tunnel -> client (read from response channel) + // Return Client Response Data for data := range responseChan { if _, err := clientConn.Write(data); err != nil { break