// Package server implements the Conduit tunnel server. Tunnel clients
// register a virtual host over a WebSocket connection; ordinary HTTP requests
// whose Host header starts with that virtual host are serialized to JSON,
// forwarded over the WebSocket, and answered with the client's JSON response.
package server

import (
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

const (
	// responseTimeout bounds how long a forwarded request waits for the
	// tunnel client to answer before the caller receives a 504.
	responseTimeout = 30 * time.Second

	// readHeaderTimeout bounds how long a client may take to send its
	// request headers.
	readHeaderTimeout = 10 * time.Second
)

// TunnelConnection is a registered tunnel: one WebSocket connection plus the
// table of requests that are waiting for a response from it.
type TunnelConnection struct {
	conn      *websocket.Conn
	vhost     string
	responses map[string]chan *TunnelResponse
	mu        sync.RWMutex // guards responses and nextID

	// writeMu serializes writes on conn. Many HTTP handler goroutines
	// forward requests through the same tunnel, and the connection must only
	// be written to by one goroutine at a time.
	writeMu sync.Mutex

	// nextID is a per-tunnel sequence number mixed into request IDs so that
	// two requests created in the same nanosecond cannot collide.
	nextID uint64
}

// TunnelRequest is the JSON message sent to the tunnel client for each
// forwarded HTTP request.
type TunnelRequest struct {
	ID      string            `json:"id"`
	Method  string            `json:"method"`
	URL     string            `json:"url"`
	Headers map[string]string `json:"headers"`
	Body    []byte            `json:"body"`
}

// TunnelResponse is the JSON message the tunnel client sends back. ID must
// echo the ID of the TunnelRequest it answers.
type TunnelResponse struct {
	ID         string            `json:"id"`
	StatusCode int               `json:"status_code"`
	Headers    map[string]string `json:"headers"`
	Body       []byte            `json:"body"`
}

// Server tracks the registered tunnels, keyed by virtual host, and routes
// incoming HTTP requests to them.
type Server struct {
	tunnels  map[string]*TunnelConnection
	mu       sync.RWMutex // guards tunnels
	upgrader websocket.Upgrader
}

// NewServer returns a Server with no registered tunnels.
func NewServer() *Server {
	return &Server{
		tunnels: make(map[string]*TunnelConnection),
		upgrader: websocket.Upgrader{
			// NOTE(review): every Origin is accepted for the WebSocket
			// upgrade; confirm this is intended for the deployment.
			CheckOrigin: func(r *http.Request) bool { return true },
		},
	}
}

// Start registers the proxy, tunnel and status handlers and serves HTTP on
// addr. It blocks until the listener fails and returns that error.
//
// The handlers live on a private mux rather than http.DefaultServeMux, so
// calling Start more than once does not panic on duplicate registration.
func (s *Server) Start(addr string) error {
	mux := http.NewServeMux()
	mux.HandleFunc("/", s.handleRequest)
	mux.HandleFunc("/_conduit/tunnel", s.handleTunnel)
	mux.HandleFunc("/_conduit/status", s.handleStatus)

	srv := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	log.Printf("Conduit server listening on %s", addr)
	return srv.ListenAndServe()
}

// handleRequest forwards an ordinary HTTP request to the tunnel selected by
// the request's Host header and relays the tunnel's response.
//
// It answers 400 when no subdomain can be extracted or the body cannot be
// read, 404 when no tunnel is registered for the subdomain, 503 when the
// request cannot be written to the tunnel, 502 when the tunnel returns an
// unusable status code, and 504 when no response arrives within
// responseTimeout.
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
	subdomain := s.extractSubdomain(r.Host)
	if subdomain == "" {
		http.Error(w, "Invalid host", http.StatusBadRequest)
		return
	}

	s.mu.RLock()
	tunnel, exists := s.tunnels[subdomain]
	s.mu.RUnlock()
	if !exists {
		http.Error(w, "Tunnel not found", http.StatusNotFound)
		return
	}

	defer r.Body.Close()
	body, err := io.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "Failed to read request body", http.StatusBadRequest)
		return
	}

	// The wire format carries one value per header name, so only the first
	// value of a multi-valued header is forwarded.
	headers := make(map[string]string, len(r.Header))
	for k, v := range r.Header {
		if len(v) > 0 {
			headers[k] = v[0]
		}
	}

	// Register the response channel before sending so the reader loop can
	// never see a response for an ID it does not know. The channel has a
	// buffer of one so the reader loop never blocks on a handler that has
	// already given up.
	respChan := make(chan *TunnelResponse, 1)
	tunnel.mu.Lock()
	tunnel.nextID++
	reqID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), tunnel.nextID)
	tunnel.responses[reqID] = respChan
	tunnel.mu.Unlock()

	defer func() {
		tunnel.mu.Lock()
		delete(tunnel.responses, reqID)
		tunnel.mu.Unlock()
	}()

	tunnelReq := &TunnelRequest{
		ID:      reqID,
		Method:  r.Method,
		URL:     r.URL.String(),
		Headers: headers,
		Body:    body,
	}

	tunnel.writeMu.Lock()
	err = tunnel.conn.WriteJSON(tunnelReq)
	tunnel.writeMu.Unlock()
	if err != nil {
		log.Printf("Error sending request to tunnel: %v", err)
		http.Error(w, "Tunnel communication error", http.StatusServiceUnavailable)
		return
	}

	timeout := time.NewTimer(responseTimeout)
	defer timeout.Stop()

	select {
	case resp := <-respChan:
		writeTunnelResponse(w, resp)
	case <-timeout.C:
		http.Error(w, "Tunnel timeout", http.StatusGatewayTimeout)
	case <-r.Context().Done():
		// The caller went away; there is nobody left to answer.
	}
}

// writeTunnelResponse copies a tunnel response onto w. A status code that
// net/http would reject is reported as 502 instead of being passed to
// WriteHeader, which panics on such values.
func writeTunnelResponse(w http.ResponseWriter, resp *TunnelResponse) {
	if resp.StatusCode < 100 || resp.StatusCode > 999 {
		log.Printf("Invalid status code %d in tunnel response %s", resp.StatusCode, resp.ID)
		http.Error(w, "Invalid tunnel response", http.StatusBadGateway)
		return
	}

	for k, v := range resp.Headers {
		w.Header().Set(k, v)
	}
	w.WriteHeader(resp.StatusCode)
	if _, err := w.Write(resp.Body); err != nil {
		log.Printf("Error writing response for request %s: %v", resp.ID, err)
	}
}

// handleTunnel upgrades the request to a WebSocket, registers it as the
// tunnel for the "vhost" query parameter, and then serves the connection
// until it fails. A later registration for the same vhost replaces the
// earlier one in the routing table.
func (s *Server) handleTunnel(w http.ResponseWriter, r *http.Request) {
	vhost := r.URL.Query().Get("vhost")
	if vhost == "" {
		http.Error(w, "Missing vhost parameter", http.StatusBadRequest)
		return
	}

	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket upgrade failed: %v", err)
		return
	}

	tunnel := &TunnelConnection{
		conn:      conn,
		vhost:     vhost,
		responses: make(map[string]chan *TunnelResponse),
	}

	s.mu.Lock()
	_, replaced := s.tunnels[vhost]
	s.tunnels[vhost] = tunnel
	s.mu.Unlock()

	if replaced {
		log.Printf("Tunnel replaced: %s", vhost)
	}
	log.Printf("Tunnel established: %s", vhost)

	s.handleTunnelConnection(tunnel)
}

// handleTunnelConnection reads responses from the tunnel client and hands
// each one to the handler waiting on its request ID. It returns when a read
// fails, after unregistering the tunnel and closing the connection.
func (s *Server) handleTunnelConnection(tunnel *TunnelConnection) {
	defer func() {
		s.mu.Lock()
		// Only unregister if this tunnel still owns the vhost; a newer
		// registration for the same vhost must not be removed by the
		// teardown of the one it replaced.
		if s.tunnels[tunnel.vhost] == tunnel {
			delete(s.tunnels, tunnel.vhost)
		}
		s.mu.Unlock()

		tunnel.conn.Close()
		log.Printf("Tunnel closed: %s", tunnel.vhost)
	}()

	for {
		// A fresh value per message: the pointer is handed to another
		// goroutine and must not be overwritten by the next read.
		resp := new(TunnelResponse)
		if err := tunnel.conn.ReadJSON(resp); err != nil {
			log.Printf("Error reading response from tunnel %s: %v", tunnel.vhost, err)
			return
		}

		tunnel.mu.RLock()
		respChan, exists := tunnel.responses[resp.ID]
		tunnel.mu.RUnlock()

		if !exists {
			log.Printf("No response channel found for request %s", resp.ID)
			continue
		}

		// Non-blocking send: a duplicate response for the same ID must not
		// stall the reader loop.
		select {
		case respChan <- resp:
		default:
			log.Printf("Response channel full for request %s", resp.ID)
		}
	}
}

// handleStatus writes a JSON document with the number of registered tunnels
// and their virtual hosts.
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	type tunnelStatus struct {
		VHost string `json:"vhost"`
	}

	// Snapshot under the lock and write afterwards, so a slow status client
	// cannot block tunnel registration.
	s.mu.RLock()
	active := make([]tunnelStatus, 0, len(s.tunnels))
	for vhost := range s.tunnels {
		active = append(active, tunnelStatus{VHost: vhost})
	}
	s.mu.RUnlock()

	status := struct {
		Tunnels int            `json:"tunnels"`
		Active  []tunnelStatus `json:"active"`
	}{
		Tunnels: len(active),
		Active:  active,
	}

	// vhost values come from clients, so they are encoded rather than
	// spliced into the output.
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(status); err != nil {
		log.Printf("Error writing status response: %v", err)
	}
}

// extractSubdomain returns the first dot-separated label of host after
// dropping everything from the first ':' onward. A host without dots is
// returned whole, and an empty host yields "".
func (s *Server) extractSubdomain(host string) string {
	if idx := strings.Index(host, ":"); idx != -1 {
		host = host[:idx]
	}
	if idx := strings.Index(host, "."); idx != -1 {
		return host[:idx]
	}
	return host
}