diff --git a/client/client.go b/client/client.go index a9786d2..1e0aeec 100644 --- a/client/client.go +++ b/client/client.go @@ -2,17 +2,15 @@ package client import ( "fmt" - "net" "net/url" - "sync" "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "reichard.io/conduit/config" - "reichard.io/conduit/types" + "reichard.io/conduit/tunnel" ) -func NewTunnel(cfg *config.ClientConfig) (*Tunnel, error) { +func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) { // Parse Server URL serverURL, err := url.Parse(cfg.ServerAddress) if err != nil { @@ -43,117 +41,5 @@ func NewTunnel(cfg *config.ClientConfig) (*Tunnel, error) { return nil, fmt.Errorf("failed to connect: %v", err) } - return &Tunnel{ - name: cfg.TunnelName, - target: cfg.TunnelTarget, - serverURL: serverURL, - serverConn: serverConn, - localConns: make(map[string]net.Conn), - }, nil - -} - -type Tunnel struct { - name string - target string - serverURL *url.URL - - serverConn *websocket.Conn - localConns map[string]net.Conn - mu sync.RWMutex -} - -func (t *Tunnel) Start() error { - log.Infof("starting tunnel: %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target) - defer t.serverConn.Close() - - // Handle Messages - for { - // Read Message - var msg types.Message - err := t.serverConn.ReadJSON(&msg) - if err != nil { - log.Errorf("error reading from tunnel: %v", err) - break - } - - switch msg.Type { - case types.MessageTypeData: - localConn, err := t.getLocalConn(msg.StreamID) - if err != nil { - log.Errorf("failed to get local connection: %v", err) - continue - } - - // Write data to local connection - if _, err := localConn.Write(msg.Data); err != nil { - log.Errorf("error writing to local connection: %v", err) - localConn.Close() - t.mu.Lock() - delete(t.localConns, msg.StreamID) - t.mu.Unlock() - } - - case types.MessageTypeClose: - t.mu.Lock() - if localConn, exists := t.localConns[msg.StreamID]; exists { - localConn.Close() - delete(t.localConns, msg.StreamID) - } - t.mu.Unlock() - } - } - - return nil -} - -func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) { - // Get Cached Connection - t.mu.RLock() - localConn, exists := t.localConns[streamID] - t.mu.RUnlock() - if exists { - return localConn, nil - } - - // Initiate Connection & Cache - localConn, err := net.Dial("tcp", t.target) - if err != nil { - log.Errorf("failed to connect to %s: %v", t.target, err) - return nil, err - } - t.mu.Lock() - t.localConns[streamID] = localConn - t.mu.Unlock() - - // Start Response Relay & Return Connection - go t.startResponseRelay(streamID, localConn) - return localConn, nil -} - -func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) { - defer func() { - t.mu.Lock() - delete(t.localConns, streamID) - t.mu.Unlock() - localConn.Close() - }() - - buffer := make([]byte, 4096) - for { - n, err := localConn.Read(buffer) - if err != nil { - break - } - - response := types.Message{ - Type: types.MessageTypeData, - StreamID: streamID, - Data: buffer[:n], - } - - if err := t.serverConn.WriteJSON(response); err != nil { - break - } - } + return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn), nil } diff --git a/cmd/tunnel.go b/cmd/tunnel.go index 73b64e6..48ace87 100644 --- a/cmd/tunnel.go +++ b/cmd/tunnel.go @@ -25,9 +25,7 @@ var tunnelCmd = &cobra.Command{ // Start Tunnel log.Infof("creating TCP tunnel: %s -> %s", cfg.TunnelName, cfg.TunnelTarget) - if err := tunnel.Start(); err != nil { - log.Fatal("failed to start tunnel:", err) - } + tunnel.Start() }, } diff --git a/server/server.go b/server/server.go index 6f465b8..b22bd3e 100644 --- a/server/server.go +++ b/server/server.go @@ -151,7 +151,7 @@ func (s *Server) handleRawConnection(conn net.Conn) { // Handle Tunnels s.mu.RLock() - tunnelConn, exists := s.tunnels[subdomain] + conduitTunnel, exists := s.tunnels[subdomain] s.mu.RUnlock() if !exists { w.WriteHeader(http.StatusNotFound) @@ -159,11 +159,17 @@ func (s *Server) handleRawConnection(conn net.Conn) { return } - // Initialize New Stream - log.Infof("relaying %s to tunnel", subdomain) + // Add & Start Stream reconstructedConn := newReconstructedConn(conn, &capturedData) streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) - tunnelConn.NewStream(streamID, reconstructedConn) + if err := conduitTunnel.AddStream(streamID, reconstructedConn); err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = fmt.Fprintf(w, "failed to add stream: %v", err) + return + } + + log.Infof("relaying %s to tunnel", subdomain) + _ = conduitTunnel.StartStream(streamID) } func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { @@ -210,7 +216,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { } // Create Tunnel - conduitTunnel := tunnel.NewTunnel(tunnelName, wsConn) + conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn) s.mu.Lock() s.tunnels[tunnelName] = conduitTunnel s.mu.Unlock() diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index e2eeebb..011b54a 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -1,7 +1,9 @@ package tunnel import ( + "fmt" "io" + "net" "sync" "github.com/gorilla/websocket" @@ -9,7 +11,7 @@ import ( "reichard.io/conduit/types" ) -func NewTunnel(name string, wsConn *websocket.Conn) *Tunnel { +func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel { return &Tunnel{ name: name, wsConn: wsConn, @@ -17,16 +19,26 @@ func NewTunnel(name string, wsConn *websocket.Conn) *Tunnel { } } +func NewClientTunnel(name, target string, wsConn *websocket.Conn) *Tunnel { + return &Tunnel{ + name: name, + wsConn: wsConn, + streams: make(map[string]io.ReadWriteCloser), + connBuilder: func() (io.ReadWriteCloser, error) { + return net.Dial("tcp", target) + }, + } +} + type Tunnel struct { - name string - wsConn *websocket.Conn - streams map[string]io.ReadWriteCloser + name string + wsConn *websocket.Conn + streams map[string]io.ReadWriteCloser + connBuilder func() (io.ReadWriteCloser, error) wsMu, streamsMu sync.Mutex } -// Start starts the tunnel and is the primary loop that handles all websocket messages. -// Messages are relayed to the local stream. func (t *Tunnel) Start() { for { var msg types.Message @@ -35,31 +47,71 @@ func (t *Tunnel) Start() { return } + // Validate Stream if msg.StreamID == "" { log.Warnf("tunnel %s missing streamID", t.name) continue } + // Ensure Stream + if err := t.initStreamConnection(msg.StreamID); err != nil { + log.WithError(err).Errorf("failed to initialize stream %s connection", t.name) + continue + } + + // Handle Messages switch msg.Type { case types.MessageTypeClose: - t.CloseStream(msg.StreamID) + _ = t.CloseStream(msg.StreamID) case types.MessageTypeData: - t.WriteStream(msg.StreamID, msg.Data) + _ = t.WriteStream(msg.StreamID, msg.Data) } } } -func (t *Tunnel) NewStream(streamID string, localConn io.ReadWriteCloser) { - t.streamsMu.Lock() - t.streams[streamID] = localConn - t.streamsMu.Unlock() +func (t *Tunnel) initStreamConnection(streamID string) error { + if _, found := t.getStream(streamID); found { + return nil + } + conn, err := t.connBuilder() + if err != nil { + return err + } + + if err := t.AddStream(streamID, conn); err != nil { + return err + } + + go t.StartStream(streamID) + return nil +} + +func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error { + t.streamsMu.Lock() + defer t.streamsMu.Unlock() + + if _, found := t.streams[streamID]; found { + return fmt.Errorf("stream %s already exists", streamID) + } + t.streams[streamID] = conn + return nil +} + +func (t *Tunnel) StartStream(streamID string) error { + // Get Stream + conn, found := t.getStream(streamID) + if !found { + return fmt.Errorf("stream %s does not exist", streamID) + } + + // Start Stream defer t.CloseStream(streamID) buffer := make([]byte, 4096) for { - n, err := localConn.Read(buffer) + n, err := conn.Read(buffer) if err != nil { - return + return err } if err := t.sendWS(&types.Message{ @@ -67,22 +119,23 @@ func (t *Tunnel) NewStream(streamID string, localConn io.ReadWriteCloser) { Data: buffer[:n], StreamID: streamID, }); err != nil { - return + return err } } } -func (t *Tunnel) WriteStream(streamID string, data []byte) { - t.streamsMu.Lock() - defer t.streamsMu.Unlock() - if localConn, ok := t.streams[streamID]; ok { - _, _ = localConn.Write(data) - } else { - log.Infof("stream %s does not exist", streamID) +func (t *Tunnel) WriteStream(streamID string, data []byte) error { + // Get Stream + conn, found := t.getStream(streamID) + if !found { + return fmt.Errorf("stream %s does not exist", streamID) } + + _, err := conn.Write(data) + return err } -func (t *Tunnel) CloseStream(streamID string) { +func (t *Tunnel) CloseStream(streamID string) error { _ = t.sendWS(&types.Message{ Type: types.MessageTypeClose, StreamID: streamID, @@ -90,10 +143,11 @@ func (t *Tunnel) CloseStream(streamID string) { t.streamsMu.Lock() defer t.streamsMu.Unlock() - if localConn, ok := t.streams[streamID]; ok { + if conn, ok := t.streams[streamID]; ok { delete(t.streams, streamID) - _ = localConn.Close() + return conn.Close() } + return nil } func (t *Tunnel) Source() string { @@ -105,3 +159,13 @@ func (t *Tunnel) sendWS(msg *types.Message) error { defer t.wsMu.Unlock() return t.wsConn.WriteJSON(msg) } + +func (t *Tunnel) getStream(streamID string) (io.ReadWriteCloser, bool) { + t.streamsMu.Lock() + defer t.streamsMu.Unlock() + + if conn, ok := t.streams[streamID]; ok { + return conn, true + } + return nil, false +}