package cmd import ( "fmt" "net" "net/url" "sync" "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "reichard.io/conduit/types" ) var serverAddr string var linkCmd = &cobra.Command{ Use: "link ", Short: "Create a tunnel link", Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { tunnelName := args[0] tunnelTarget := args[1] // Create Tunnel tunnel, err := NewTunnel(tunnelName, tunnelTarget, serverAddr) if err != nil { log.Fatal("Failed to start tunnel:", err) } // Start Tunnel log.Infof("Creating TCP tunnel: %s -> %s\n", tunnelName, tunnelTarget) if err := tunnel.Start(); err != nil { log.Fatal("Failed to start tunnel:", err) } }, } func init() { linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "http://localhost:8080", "Conduit server address") } type TunnelConfig struct { // The conduit server address, e.g. https://conduit.example.com ServerAddress string `default:"http://localhost:8080"` } func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) { // Parse Server URL serverURL, err := url.Parse(serverAddress) if err != nil { return nil, err } // Parse Scheme var wsScheme string switch serverURL.Scheme { case "https": wsScheme = "wss" case "http": wsScheme = "ws" default: return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme) } // Connect Server WS wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s", wsScheme, serverURL.Host, tunnelName) serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { return nil, fmt.Errorf("failed to connect: %v", err) } return &Tunnel{ name: tunnelName, target: tunnelTarget, serverConn: serverConn, localConns: make(map[string]net.Conn), }, nil } type Tunnel struct { name string target string serverConn *websocket.Conn localConns map[string]net.Conn mu sync.RWMutex } func (t *Tunnel) Start() error { log.Infof("TCP Tunnel active! %s.example.com -> %s\n", t.name, t.target) defer t.serverConn.Close() // Handle Messages for { // Read Message var msg types.Message err := t.serverConn.ReadJSON(&msg) if err != nil { log.Errorf("Error reading from tunnel: %v", err) break } switch msg.Type { case types.MessageTypeData: localConn, err := t.getLocalConn(msg.StreamID) if err != nil { log.Errorf("Failed to get local connection: %v", err) continue } // Write data to local connection if _, err := localConn.Write(msg.Data); err != nil { log.Errorf("Error writing to local connection: %v", err) localConn.Close() t.mu.Lock() delete(t.localConns, msg.StreamID) t.mu.Unlock() } case types.MessageTypeClose: t.mu.Lock() if localConn, exists := t.localConns[msg.StreamID]; exists { localConn.Close() delete(t.localConns, msg.StreamID) } t.mu.Unlock() } } return nil } func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) { // Get Cached Connection t.mu.RLock() localConn, exists := t.localConns[streamID] t.mu.RUnlock() if exists { return localConn, nil } // Initiate Connection & Cache localConn, err := net.Dial("tcp", t.target) if err != nil { log.Errorf("Failed to connect to %s: %v", t.target, err) return nil, err } t.mu.Lock() t.localConns[streamID] = localConn t.mu.Unlock() // Start Response Relay & Return Connection go t.startResponseRelay(streamID, localConn) return localConn, nil } func (t *Tunnel) startResponseRelay(streamID string, lConn net.Conn) { defer func() { t.mu.Lock() delete(t.localConns, streamID) t.mu.Unlock() lConn.Close() }() buffer := make([]byte, 4096) for { n, err := lConn.Read(buffer) if err != nil { break } response := types.Message{ Type: types.MessageTypeData, StreamID: streamID, Data: buffer[:n], } if err := t.serverConn.WriteJSON(response); err != nil { break } } }