package server import ( "bufio" "bytes" "fmt" "net" "net/http" "strings" "sync" "time" "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "reichard.io/conduit/types" ) type TunnelConnection struct { *websocket.Conn name string streams map[string]chan []byte // StreamID -> data channel } type Server struct { tunnels map[string]*TunnelConnection mu sync.RWMutex upgrader websocket.Upgrader } func NewServer() *Server { return &Server{ tunnels: make(map[string]*TunnelConnection), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, }, } } func (s *Server) Start(addr string) error { // Raw TCP listener instead of http.ListenAndServe listener, err := net.Listen("tcp", addr) if err != nil { return err } defer listener.Close() log.Infof("Conduit server listening on %s", addr) for { conn, err := listener.Accept() if err != nil { log.Printf("Error accepting connection: %v", err) continue } go s.handleRawConnection(conn) } } func (s *Server) extractSubdomain(host string) string { if idx := strings.Index(host, ":"); idx != -1 { host = host[:idx] } parts := strings.Split(host, ".") if len(parts) >= 1 { return parts[0] } return "" } func (s *Server) getStatus(conn net.Conn) { s.mu.RLock() count := len(s.tunnels) s.mu.RUnlock() response := fmt.Sprintf( "HTTP/1.1 200 OK\r\n"+ "Content-Type: application/json\r\n"+ "Content-Length: %d\r\n\r\n"+ `{"tunnels": %d}`, len(fmt.Sprintf(`{"tunnels": %d}`, count)), count) conn.Write([]byte(response)) } func (s *Server) extractHostFromHTTP(data []byte) string { // Simple HTTP header parsing lines := strings.Split(string(data), "\r\n") for _, line := range lines { if strings.HasPrefix(strings.ToLower(line), "host:") { parts := strings.SplitN(line, ":", 2) if len(parts) == 2 { return strings.TrimSpace(parts[1]) } } } return "" } func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, initialData []byte) { defer clientConn.Close() // Generate a unique stream ID for this connection streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) // Send initial data with stream ID msg := types.Message{ Type: "data", StreamID: streamID, Data: initialData, } if err := tunnelConn.WriteJSON(msg); err != nil { log.Errorf("Error sending initial data: %v", err) return } // Create a channel for this stream's responses responseChan := make(chan []byte, 100) // Register this stream s.mu.Lock() if tunnelConn.streams == nil { tunnelConn.streams = make(map[string]chan []byte) } tunnelConn.streams[streamID] = responseChan s.mu.Unlock() // Clean up when done defer func() { s.mu.Lock() delete(tunnelConn.streams, streamID) close(responseChan) s.mu.Unlock() // Send close message closeMsg := types.Message{ Type: "close", StreamID: streamID, } tunnelConn.WriteJSON(closeMsg) }() // Handle client -> tunnel go func() { buffer := make([]byte, 4096) for { n, err := clientConn.Read(buffer) if err != nil { return } msg := types.Message{ Type: "data", StreamID: streamID, Data: buffer[:n], } if err := tunnelConn.WriteJSON(msg); err != nil { return } } }() // Return Client Response Data for data := range responseChan { if _, err := clientConn.Write(data); err != nil { break } } } func (s *Server) handleRawConnection(conn net.Conn) { defer conn.Close() // Read enough to get the Host header buffer := make([]byte, 4096) n, err := conn.Read(buffer) if err != nil { return } // Extract host host := s.extractHostFromHTTP(buffer[:n]) subdomain := s.extractSubdomain(host) // If we have a registered tunnel for this subdomain, relay it if subdomain != "" { s.mu.RLock() tunnelConn, exists := s.tunnels[subdomain] s.mu.RUnlock() if exists { log.Infof("Relaying %s to tunnel", subdomain) s.proxyRawConnection(conn, tunnelConn, buffer[:n]) return } } // Otherwise, handle as control server (recreate HTTP request and use net/http) s.handleAsHTTP(conn, buffer[:n]) } func (s *Server) handleAsHTTP(conn net.Conn, initialData []byte) { // Create a fake HTTP request from the raw data reader := bufio.NewReader(bytes.NewReader(initialData)) req, err := http.ReadRequest(reader) if err != nil { _, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) return } // Handle Control Endpoints switch req.URL.Path { case "/_conduit/tunnel": s.createTunnel(conn, req) case "/_conduit/status": s.getStatus(conn) default: _, _ = conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n")) } } func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { for { var msg types.Message err := tunnel.ReadJSON(&msg) if err != nil { break } // Route message to appropriate stream if msg.Type == "data" && msg.StreamID != "" { s.mu.RLock() if streamChan, exists := tunnel.streams[msg.StreamID]; exists { select { case streamChan <- msg.Data: case <-time.After(time.Second): log.Infof("Stream %s channel full, dropping data", msg.StreamID) } } s.mu.RUnlock() } } } func (s *Server) createTunnel(conn net.Conn, req *http.Request) { // Get Tunnel Name tunnelName := req.URL.Query().Get("tunnelName") if tunnelName == "" { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing tunnelName parameter")) return } // Validate Unique if _, exists := s.tunnels[tunnelName]; exists { conn.Write([]byte("HTTP/1.1 409 Conflict\r\n\r\nTunnel already registered")) return } // Upgrade Connection wsConn, err := s.upgrader.Upgrade(&rawResponseWriter{conn: conn}, req, nil) if err != nil { log.Errorf("WebSocket upgrade failed: %v", err) return } // Create & Cache TunnelConnection tunnel := &TunnelConnection{ Conn: wsConn, name: tunnelName, streams: make(map[string]chan []byte), } s.mu.Lock() s.tunnels[tunnelName] = tunnel s.mu.Unlock() log.Infof("Tunnel established: %s", tunnelName) // Keep connection alive and handle cleanup defer func() { s.mu.Lock() delete(s.tunnels, tunnelName) s.mu.Unlock() wsConn.Close() log.Infof("Tunnel closed: %s", tunnelName) }() // Handle tunnel messages s.handleTunnelMessages(tunnel) } type rawResponseWriter struct { conn net.Conn header http.Header } func (f *rawResponseWriter) Header() http.Header { if f.header == nil { f.header = make(http.Header) } return f.header } func (f *rawResponseWriter) Write(data []byte) (int, error) { return f.conn.Write(data) } func (f *rawResponseWriter) WriteHeader(statusCode int) { // Write Status status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) _, _ = f.conn.Write([]byte(status)) // Write Headers for key, values := range f.header { for _, value := range values { _, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value) } } // End Headers _, _ = f.conn.Write([]byte("\r\n")) } func (f *rawResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Return Raw Connection & ReadWriter rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) return f.conn, rw, nil }