package server

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	log "github.com/sirupsen/logrus"

	"reichard.io/conduit/types"
)

const (
	// maxPeekBytes bounds how much of a new connection is buffered while
	// sniffing its request head for the Host header.
	maxPeekBytes = 8192

	// headerReadTimeout bounds how long a client may take to send its request
	// head, so idle connections cannot pin a goroutine forever.
	headerReadTimeout = 10 * time.Second

	// streamBufferSize is the number of tunnel chunks queued per stream before
	// the tunnel reader has to wait for the client to catch up.
	streamBufferSize = 100

	// streamSendTimeout is how long the tunnel reader waits for room in a full
	// stream buffer before giving up on that stream.
	streamSendTimeout = time.Second

	// proxyChunkSize is the read size used when forwarding client bytes into
	// the tunnel.
	proxyChunkSize = 4096
)

// TunnelConnection is a registered tunnel client. Every proxied client
// connection is multiplexed over the embedded websocket as a stream that is
// identified by a StreamID.
type TunnelConnection struct {
	*websocket.Conn
	name    string
	streams map[string]chan []byte // StreamID -> data channel

	streamsMu  sync.RWMutex // guards streams, nextStream and closed
	nextStream uint64       // sequence number of the most recently opened stream
	closed     bool         // set on teardown; no further streams may be opened

	// writeMu serializes writes: gorilla/websocket supports at most one
	// concurrent writer, but every stream forwards from its own goroutine.
	writeMu sync.Mutex
}

// send writes msg to the tunnel as JSON, serialized with all other writers.
func (t *TunnelConnection) send(msg types.Message) error {
	t.writeMu.Lock()
	defer t.writeMu.Unlock()
	return t.WriteJSON(msg)
}

// openStream registers a new stream and returns its ID and data channel.
// ok is false when the tunnel has already been torn down.
func (t *TunnelConnection) openStream() (streamID string, ch chan []byte, ok bool) {
	t.streamsMu.Lock()
	defer t.streamsMu.Unlock()

	if t.closed {
		return "", nil, false
	}
	if t.streams == nil {
		t.streams = make(map[string]chan []byte)
	}

	// A per-tunnel counter guarantees uniqueness; wall-clock IDs can collide
	// between connections accepted at (nearly) the same instant.
	t.nextStream++
	streamID = fmt.Sprintf("stream_%d", t.nextStream)
	ch = make(chan []byte, streamBufferSize)
	t.streams[streamID] = ch
	return streamID, ch, true
}

// closeStream unregisters the stream and closes its channel, which ends the
// proxy loop reading from it. It is safe to call repeatedly and reports
// whether this call was the one that closed the stream.
func (t *TunnelConnection) closeStream(streamID string) bool {
	t.streamsMu.Lock()
	defer t.streamsMu.Unlock()

	ch, ok := t.streams[streamID]
	if !ok {
		return false
	}
	delete(t.streams, streamID)
	close(ch)
	return true
}

// closeAllStreams closes every open stream and refuses new ones. It is
// called when the tunnel goes away so that no proxy waits on it forever.
func (t *TunnelConnection) closeAllStreams() {
	t.streamsMu.Lock()
	defer t.streamsMu.Unlock()

	t.closed = true
	for id, ch := range t.streams {
		close(ch)
		delete(t.streams, id)
	}
}

// deliver queues data for the given stream, waiting at most timeout for room
// in its buffer. It reports whether the stream exists and, if so, whether the
// data was queued. The read lock is held across the wait so that closeStream
// cannot close the channel while a send on it is pending.
func (t *TunnelConnection) deliver(streamID string, data []byte, timeout time.Duration) (exists, queued bool) {
	t.streamsMu.RLock()
	defer t.streamsMu.RUnlock()

	ch, ok := t.streams[streamID]
	if !ok {
		return false, false
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case ch <- data:
		return true, true
	case <-timer.C:
		return true, false
	}
}

// Server accepts raw TCP connections and either relays them to a registered
// tunnel (selected by the first label of the Host header) or serves the
// /_conduit control endpoints.
type Server struct {
	tunnels  map[string]*TunnelConnection
	upgrader websocket.Upgrader
	mu       sync.RWMutex // guards tunnels
}

// NewServer returns a Server with no registered tunnels.
func NewServer() *Server {
	return &Server{
		tunnels: make(map[string]*TunnelConnection),
		upgrader: websocket.Upgrader{
			// NOTE(review): every Origin is accepted, so any web page can
			// attempt a tunnel upgrade; confirm this is intended.
			CheckOrigin: func(r *http.Request) bool { return true },
		},
	}
}

// Start listens for TCP connections on addr and serves each one on its own
// goroutine. It blocks, returning only if the listener cannot be created or
// has been closed.
func (s *Server) Start(addr string) error {
	// A raw TCP listener (rather than http.ListenAndServe) lets each
	// connection be inspected and then either relayed byte-for-byte to a
	// tunnel or handled as a control request.
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	defer listener.Close()

	log.Infof("Conduit server listening on %s", addr)

	var backoff time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return err
			}
			// Persistent failures (e.g. file-descriptor exhaustion) would
			// otherwise spin this loop at full speed; back off the way
			// net/http.Server does.
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff *= 2
			}
			if backoff > time.Second {
				backoff = time.Second
			}
			log.Errorf("Error accepting connection: %v", err)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go s.handleRawConnection(conn)
	}
}

// extractSubdomain parses an HTTP request head from peekReader and returns
// the first dot-separated label of its Host (port stripped). It returns ""
// when no request can be parsed.
func (s *Server) extractSubdomain(peekReader io.Reader) string {
	if peekReader == nil {
		return ""
	}

	req, err := http.ReadRequest(bufio.NewReader(peekReader))
	if err != nil {
		return ""
	}
	defer req.Body.Close()

	// SplitHostPort also copes with bracketed IPv6 literals; hosts without a
	// port make it fail, in which case the Host is used as-is.
	host := req.Host
	if h, _, err := net.SplitHostPort(host); err == nil {
		host = h
	}

	subdomain, _, _ := strings.Cut(host, ".")
	return subdomain
}

// getStatus writes a JSON body reporting the number of registered tunnels.
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
	s.mu.RLock()
	count := len(s.tunnels)
	s.mu.RUnlock()

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(200)
	response := fmt.Sprintf(`{"tunnels": %d}`, count)
	_, _ = w.Write([]byte(response))
}

// proxyRawConnection relays clientConn over tunnelConn as a new stream:
// bytes read from dataReader are forwarded as Data messages, and data the
// tunnel delivers for the stream is written back to clientConn. It returns
// when the client stops sending, a write fails, or the tunnel ends the stream.
func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, dataReader io.Reader) {
	defer clientConn.Close()

	streamID, responseChan, ok := tunnelConn.openStream()
	if !ok {
		// The tunnel is being torn down and cannot carry new streams.
		_, _ = clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
		return
	}

	// Clean up. Only notify the tunnel when this side ended the stream; if
	// the tunnel closed it (Close message, overflow or teardown) it already
	// knows, or can no longer be reached.
	defer func() {
		if tunnelConn.closeStream(streamID) {
			_ = tunnelConn.send(types.Message{
				Type:     types.MessageTypeClose,
				StreamID: streamID,
			})
		}
	}()

	// Read & send chunks. clientDone is closed once the client stops sending
	// (EOF, error, or our own Close of clientConn) or the tunnel write fails.
	// NOTE(review): a client that half-closes its write side while still
	// awaiting a response will have the stream ended early.
	clientDone := make(chan struct{})
	go func() {
		defer close(clientDone)
		buffer := make([]byte, proxyChunkSize)
		for {
			n, err := dataReader.Read(buffer)
			// Forward any bytes before looking at err, per the io.Reader contract.
			if n > 0 {
				if sendErr := tunnelConn.send(types.Message{
					Type:     types.MessageTypeData,
					StreamID: streamID,
					Data:     buffer[:n],
				}); sendErr != nil {
					return
				}
			}
			if err != nil {
				return
			}
		}
	}()

	// Return response data until either side is finished.
	for {
		select {
		case data, ok := <-responseChan:
			if !ok {
				return
			}
			if _, err := clientConn.Write(data); err != nil {
				return
			}
		case <-clientDone:
			return
		}
	}
}

// peekData limits how much we read, as we only need the request head to
// determine the host and decide whether to proxy. It reads until the end of
// the head is seen, maxPeekBytes are buffered, or the read fails. It returns
// a reader over the buffered bytes plus a reader that replays them followed
// by the rest of conn. The error is non-nil only when nothing was read.
func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
	buf := make([]byte, 0, maxPeekBytes)
	for len(buf) < cap(buf) && !hasHeaderTerminator(buf) {
		n, err := conn.Read(buf[len(buf):cap(buf)])
		buf = buf[:len(buf)+n]
		if err != nil {
			if len(buf) == 0 {
				return nil, nil, err
			}
			// The client stopped sending mid-head. Replay only what arrived so
			// later parsing fails fast instead of blocking on conn again.
			return bytes.NewReader(buf), bytes.NewReader(buf), nil
		}
	}

	combinedReader := io.MultiReader(bytes.NewReader(buf), conn)
	return bytes.NewReader(buf), combinedReader, nil
}

// hasHeaderTerminator reports whether b contains the blank line that ends an
// HTTP request head (CRLF or bare-LF form).
func hasHeaderTerminator(b []byte) bool {
	return bytes.Contains(b, []byte("\r\n\r\n")) || bytes.Contains(b, []byte("\n\n"))
}

// handleRawConnection routes a freshly accepted connection: requests whose
// Host subdomain matches a registered tunnel are relayed to it, everything
// else is served as a control request.
func (s *Server) handleRawConnection(conn net.Conn) {
	defer conn.Close()

	_ = conn.SetReadDeadline(time.Now().Add(headerReadTimeout))
	peekReader, allReader, err := s.peekData(conn)
	_ = conn.SetReadDeadline(time.Time{})
	if err != nil {
		// Nothing was received (client closed, reset, or timed out).
		return
	}

	if subdomain := s.extractSubdomain(peekReader); subdomain != "" {
		s.mu.RLock()
		tunnelConn, exists := s.tunnels[subdomain]
		s.mu.RUnlock()

		if exists {
			log.Infof("Relaying %s to tunnel", subdomain)
			s.proxyRawConnection(conn, tunnelConn, allReader)
			return
		}
	}

	// Otherwise, handle as control server (recreate HTTP request and use net/http)
	s.handleAsHTTP(conn, allReader)
}

// handleAsHTTP parses one HTTP request from allReader and dispatches it to
// the control endpoints, replying 400 on a malformed request and 404 on an
// unknown path. connResponseWriter is presumably declared elsewhere in this
// package (not visible here).
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
	// Create HTTP Request & Writer
	r, err := http.ReadRequest(bufio.NewReader(allReader))
	if err != nil {
		_, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
		return
	}
	w := &connResponseWriter{conn: conn}

	// Handle Control Endpoints
	switch r.URL.Path {
	case "/_conduit/tunnel":
		s.createTunnel(w, r)
	case "/_conduit/status":
		s.getStatus(w, r)
	default:
		w.WriteHeader(http.StatusNotFound)
	}
}

// handleTunnelMessages reads messages from tunnel until the websocket read
// fails, routing Data to the matching stream and ending streams on Close.
func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
	for {
		var msg types.Message
		if err := tunnel.ReadJSON(&msg); err != nil {
			return
		}

		if msg.StreamID == "" {
			log.Infof("Tunnel %s missing streamID", tunnel.name)
			continue
		}

		switch msg.Type {
		case types.MessageTypeClose:
			// Close carries a StreamID, so it ends that stream only rather
			// than tearing down every stream on the tunnel.
			tunnel.closeStream(msg.StreamID)
		case types.MessageTypeData:
			exists, queued := tunnel.deliver(msg.StreamID, msg.Data, streamSendTimeout)
			switch {
			case !exists:
				log.Infof("Stream %s does not exist", msg.StreamID)
			case !queued:
				// Dropping a chunk would silently corrupt the proxied byte
				// stream; fail the stream and tell the tunnel instead.
				log.Warnf("Stream %s channel full, closing stream", msg.StreamID)
				if tunnel.closeStream(msg.StreamID) {
					_ = tunnel.send(types.Message{
						Type:     types.MessageTypeClose,
						StreamID: msg.StreamID,
					})
				}
			}
		}
	}
}

// createTunnel upgrades the request to a websocket and registers it under the
// tunnelName query parameter for as long as the connection stays open.
func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
	// Get Tunnel Name
	tunnelName := r.URL.Query().Get("tunnelName")
	if tunnelName == "" {
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte("Missing tunnelName parameter"))
		return
	}

	// Validate Unique (fast rejection before upgrading; re-checked below)
	s.mu.RLock()
	_, exists := s.tunnels[tunnelName]
	s.mu.RUnlock()
	if exists {
		w.WriteHeader(http.StatusConflict)
		_, _ = w.Write([]byte("Tunnel already registered"))
		return
	}

	// Upgrade Connection
	wsConn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Errorf("WebSocket upgrade failed: %v", err)
		return
	}

	// Create & Cache TunnelConnection
	tunnel := &TunnelConnection{
		Conn:    wsConn,
		name:    tunnelName,
		streams: make(map[string]chan []byte),
	}

	// Another registration may have claimed the name while we were upgrading;
	// check and insert under one write lock so the name is never overwritten.
	s.mu.Lock()
	if _, taken := s.tunnels[tunnelName]; taken {
		s.mu.Unlock()
		log.Infof("Tunnel already registered: %s", tunnelName)
		_ = wsConn.WriteMessage(websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "Tunnel already registered"))
		_ = wsConn.Close()
		return
	}
	s.tunnels[tunnelName] = tunnel
	s.mu.Unlock()

	log.Infof("Tunnel established: %s", tunnelName)

	// Keep connection alive and handle cleanup
	defer func() {
		s.mu.Lock()
		if s.tunnels[tunnelName] == tunnel {
			delete(s.tunnels, tunnelName)
		}
		s.mu.Unlock()

		// Release every proxy still waiting on this tunnel.
		tunnel.closeAllStreams()
		_ = wsConn.Close()
		log.Infof("Tunnel closed: %s", tunnelName)
	}()

	// Handle tunnel messages
	s.handleTunnelMessages(tunnel)
}