diff --git a/cmd/link.go b/cmd/link.go index dada57b..5fd622d 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -9,14 +9,9 @@ import ( "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "reichard.io/conduit/types" ) -type TunnelMessage struct { - Type string `json:"type"` - StreamID string `json:"stream_id"` - Data []byte `json:"data,omitempty"` -} - var serverAddr string var linkCmd = &cobra.Command{ @@ -69,7 +64,7 @@ func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) } // Connect Server WS - wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?vhost=%s", wsScheme, serverURL.Host, tunnelName) + wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s", wsScheme, serverURL.Host, tunnelName) serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { return nil, fmt.Errorf("failed to connect: %v", err) @@ -100,7 +95,7 @@ func (t *Tunnel) Start() error { // Handle Messages for { // Read Message - var msg TunnelMessage + var msg types.Message err := t.serverConn.ReadJSON(&msg) if err != nil { log.Errorf("Error reading from tunnel: %v", err) @@ -108,7 +103,7 @@ func (t *Tunnel) Start() error { } switch msg.Type { - case "data": + case types.MessageTypeData: localConn, err := t.getLocalConn(msg.StreamID) if err != nil { log.Errorf("Failed to get local connection: %v", err) @@ -124,7 +119,7 @@ func (t *Tunnel) Start() error { t.mu.Unlock() } - case "close": + case types.MessageTypeClose: t.mu.Lock() if localConn, exists := t.localConns[msg.StreamID]; exists { localConn.Close() @@ -176,8 +171,8 @@ func (t *Tunnel) startResponseRelay(streamID string, lConn net.Conn) { break } - response := TunnelMessage{ - Type: "data", + response := types.Message{ + Type: types.MessageTypeData, StreamID: streamID, Data: buffer[:n], } diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..d912156 --- /dev/null +++ b/config/config.go @@ -0,0 +1 @@ +package config diff --git a/server/server.go b/server/server.go index 620f637..d066646 100644 --- a/server/server.go +++ b/server/server.go @@ -12,20 +12,15 @@ import ( "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" + "reichard.io/conduit/types" ) type TunnelConnection struct { *websocket.Conn - vhost string + name string streams map[string]chan []byte // StreamID -> data channel } -type TunnelMessage struct { - Type string `json:"type"` - StreamID string `json:"stream_id"` - Data []byte `json:"data,omitempty"` -} - type Server struct { tunnels map[string]*TunnelConnection mu sync.RWMutex @@ -77,7 +72,7 @@ func (s *Server) extractSubdomain(host string) string { return "" } -func (s *Server) handleStatus(conn net.Conn) { +func (s *Server) getStatus(conn net.Conn) { s.mu.RLock() count := len(s.tunnels) s.mu.RUnlock() @@ -113,7 +108,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) // Send initial data with stream ID - msg := TunnelMessage{ + msg := types.Message{ Type: "data", StreamID: streamID, Data: initialData, @@ -143,7 +138,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne s.mu.Unlock() // Send close message - closeMsg := TunnelMessage{ + closeMsg := types.Message{ Type: "close", StreamID: streamID, } @@ -159,7 +154,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne return } - msg := TunnelMessage{ + msg := types.Message{ Type: "data", StreamID: streamID, Data: buffer[:n], @@ -215,25 +210,24 @@ func (s *Server) handleAsHTTP(conn net.Conn, initialData []byte) { reader := bufio.NewReader(bytes.NewReader(initialData)) req, err := http.ReadRequest(reader) if err != nil { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) + _, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) return } // Handle Control Endpoints switch req.URL.Path { case "/_conduit/tunnel": - s.handleTunnelUpgrade(conn, req) - return + s.createTunnel(conn, req) case "/_conduit/status": - s.handleStatus(conn) + s.getStatus(conn) default: - conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n")) + _, _ = conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n")) } } func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { for { - var msg TunnelMessage + var msg types.Message err := tunnel.ReadJSON(&msg) if err != nil { break @@ -253,84 +247,85 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { } } } -func (s *Server) handleTunnelUpgrade(conn net.Conn, req *http.Request) { - vhost := req.URL.Query().Get("vhost") - if vhost == "" { - conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing vhost parameter")) +func (s *Server) createTunnel(conn net.Conn, req *http.Request) { + // Get Tunnel Name + tunnelName := req.URL.Query().Get("tunnelName") + if tunnelName == "" { + conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing tunnelName parameter")) return } - // Create a fake ResponseWriter that writes to our raw connection - fakeWriter := &fakeResponseWriter{conn: conn} + // Validate Unique + if _, exists := s.tunnels[tunnelName]; exists { + conn.Write([]byte("HTTP/1.1 409 Conflict\r\n\r\nTunnel already registered")) + return + } - // Use the upgrader - wsConn, err := s.upgrader.Upgrade(fakeWriter, req, nil) + // Upgrade Connection + wsConn, err := s.upgrader.Upgrade(&rawResponseWriter{conn: conn}, req, nil) if err != nil { log.Errorf("WebSocket upgrade failed: %v", err) return } - // Create TunnelConnection + // Create & Cache TunnelConnection tunnel := &TunnelConnection{ Conn: wsConn, - vhost: vhost, + name: tunnelName, streams: make(map[string]chan []byte), } - s.mu.Lock() - s.tunnels[vhost] = tunnel + s.tunnels[tunnelName] = tunnel s.mu.Unlock() - - log.Infof("Tunnel established: %s", vhost) + log.Infof("Tunnel established: %s", tunnelName) // Keep connection alive and handle cleanup defer func() { s.mu.Lock() - delete(s.tunnels, vhost) + delete(s.tunnels, tunnelName) s.mu.Unlock() wsConn.Close() - log.Infof("Tunnel closed: %s", vhost) + log.Infof("Tunnel closed: %s", tunnelName) }() // Handle tunnel messages s.handleTunnelMessages(tunnel) } -type fakeResponseWriter struct { +type rawResponseWriter struct { conn net.Conn header http.Header } -func (f *fakeResponseWriter) Header() http.Header { +func (f *rawResponseWriter) Header() http.Header { if f.header == nil { f.header = make(http.Header) } return f.header } -func (f *fakeResponseWriter) Write(data []byte) (int, error) { +func (f *rawResponseWriter) Write(data []byte) (int, error) { return f.conn.Write(data) } -func (f *fakeResponseWriter) WriteHeader(statusCode int) { - // Write HTTP status line +func (f *rawResponseWriter) WriteHeader(statusCode int) { + // Write Status status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) - f.conn.Write([]byte(status)) + _, _ = f.conn.Write([]byte(status)) - // Write headers + // Write Headers for key, values := range f.header { for _, value := range values { - f.conn.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value))) + _, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value) } } - // End headers - f.conn.Write([]byte("\r\n")) + // End Headers + _, _ = f.conn.Write([]byte("\r\n")) } -// Implement http.Hijacker -func (f *fakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - // Return the raw connection and create a ReadWriter for it +func (f *rawResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Return Raw Connection & ReadWriter rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) return f.conn, rw, nil } diff --git a/server/tunnel.go b/server/tunnel.go deleted file mode 100644 index 439dceb..0000000 --- a/server/tunnel.go +++ /dev/null @@ -1,45 +0,0 @@ -package server - -import ( - "bufio" - "fmt" - "net" - "net/http" -) - -type tunnelWriter struct { - conn net.Conn - header http.Header -} - -func (f *tunnelWriter) Header() http.Header { - if f.header == nil { - f.header = make(http.Header) - } - return f.header -} - -func (f *tunnelWriter) Write(data []byte) (int, error) { - return f.conn.Write(data) -} - -func (f *tunnelWriter) WriteHeader(statusCode int) { - // Write HTTP status line - status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) - f.conn.Write([]byte(status)) - - // Write headers - for key, values := range f.header { - for _, value := range values { - f.conn.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value))) - } - } - - // End headers - f.conn.Write([]byte("\r\n")) -} - -func (f *tunnelWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) - return f.conn, rw, nil -} diff --git a/types/message.go b/types/message.go new file mode 100644 index 0000000..3d54160 --- /dev/null +++ b/types/message.go @@ -0,0 +1,14 @@ +package types + +type MessageType string + +const ( + MessageTypeData MessageType = "data" + MessageTypeClose MessageType = "close" +) + +type Message struct { + Type MessageType `json:"type"` + StreamID string `json:"stream_id"` + Data []byte `json:"data,omitempty"` +}