conduit/server/server.go
2025-09-19 17:22:36 -04:00

337 lines
7.2 KiB
Go

package server
import (
"bufio"
"bytes"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
// TunnelConnection is one registered tunnel: the WebSocket to the tunnel
// client plus the per-stream channels used to hand tunnel responses back to
// the proxied client connections.
type TunnelConnection struct {
	*websocket.Conn
	vhost   string
	streams map[string]chan []byte // StreamID -> data channel; accessed under Server.mu in this file

	// writeMu serializes WriteJSON. A websocket connection supports only one
	// concurrent writer, while every proxied stream writes to the same tunnel
	// from its own goroutines.
	writeMu sync.Mutex
}

// WriteJSON shadows the embedded (*websocket.Conn).WriteJSON so that every
// write issued through a TunnelConnection is serialized by writeMu.
func (t *TunnelConnection) WriteJSON(v interface{}) error {
	t.writeMu.Lock()
	defer t.writeMu.Unlock()
	return t.Conn.WriteJSON(v)
}
// TunnelMessage is the JSON envelope exchanged over a tunnel WebSocket.
type TunnelMessage struct {
	// Type is the message kind; this file sends "data" and "close" and
	// routes incoming "data".
	Type string `json:"type"`
	// StreamID names the proxied connection the message belongs to.
	StreamID string `json:"stream_id"`
	// Data is the raw payload; left out of the JSON when empty.
	Data []byte `json:"data,omitempty"`
}
// Server accepts raw TCP connections and either relays them to a registered
// tunnel or serves the control endpoints itself.
type Server struct {
	// mu guards tunnels and, in this file, each tunnel's streams map.
	mu sync.RWMutex
	// tunnels maps a vhost to its active tunnel connection.
	tunnels map[string]*TunnelConnection
	// upgrader performs the WebSocket handshake for tunnel registration.
	upgrader websocket.Upgrader
}
// NewServer returns a Server with an empty tunnel table and a WebSocket
// upgrader whose origin check accepts every request.
func NewServer() *Server {
	allowAnyOrigin := func(*http.Request) bool { return true }

	srv := new(Server)
	srv.tunnels = map[string]*TunnelConnection{}
	srv.upgrader = websocket.Upgrader{CheckOrigin: allowAnyOrigin}
	return srv
}
// Start listens on addr (TCP) and serves every accepted connection on its own
// goroutine. It returns an error only when the listener cannot be created;
// otherwise it runs until the process exits.
//
// Accept failures are retried with a capped exponential backoff so that a
// persistent error (for example file-descriptor exhaustion) cannot turn the
// accept loop into a busy spin that floods the log.
func (s *Server) Start(addr string) error {
	// Raw TCP listener instead of http.ListenAndServe
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	defer listener.Close()
	log.Infof("Conduit server listening on %s", addr)

	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var delay time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if delay == 0 {
				delay = minAcceptDelay
			} else {
				delay *= 2
			}
			if delay > maxAcceptDelay {
				delay = maxAcceptDelay
			}
			log.Errorf("Error accepting connection: %v (retrying in %v)", err, delay)
			time.Sleep(delay)
			continue
		}
		delay = 0 // a successful accept resets the backoff
		go s.handleRawConnection(conn)
	}
}
// extractSubdomain returns the left-most DNS label of host, which is the key
// used to look up a tunnel. host may carry a ":port" suffix and may be a
// bracketed IPv6 literal.
//
// It returns "" for an empty host and for IP literals (an address has no
// subdomain). A single-label host such as "localhost" is returned unchanged.
func (s *Server) extractSubdomain(host string) string {
	host = strings.TrimSpace(host)

	// SplitHostPort understands "[::1]:8080"; cutting at the first ':' does not.
	if h, _, err := net.SplitHostPort(host); err == nil {
		host = h
	} else {
		// No port present: only strip the brackets of a bare IPv6 literal.
		host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
	}

	if host == "" || net.ParseIP(host) != nil {
		return ""
	}
	if dot := strings.IndexByte(host, '.'); dot >= 0 {
		return host[:dot]
	}
	return host
}
// handleStatus writes a minimal HTTP 200 response to conn whose JSON body
// reports how many tunnels are currently registered.
func (s *Server) handleStatus(conn net.Conn) {
	s.mu.RLock()
	active := len(s.tunnels)
	s.mu.RUnlock()

	// Render the body first so its length can go into Content-Length.
	payload := fmt.Sprintf(`{"tunnels": %d}`, active)
	head := "HTTP/1.1 200 OK\r\n" +
		"Content-Type: application/json\r\n" +
		fmt.Sprintf("Content-Length: %d\r\n\r\n", len(payload))

	conn.Write([]byte(head + payload))
}
// extractHostFromHTTP returns the trimmed value of the Host header found in
// data, the leading bytes of a raw HTTP request, or "" when none is present.
//
// Only the header section is inspected: the request line is skipped and the
// scan stops at the blank line that ends the headers, so body content that
// happens to start with "host:" can never be mistaken for the header. Both
// CRLF and bare LF line endings are accepted, and the header name is matched
// case-insensitively.
func (s *Server) extractHostFromHTTP(data []byte) string {
	const name = "host:"

	seenRequestLine := false
	for _, line := range strings.Split(string(data), "\n") {
		line = strings.TrimSuffix(line, "\r")

		if !seenRequestLine {
			// Tolerate blank lines before the request line, then skip the
			// request line itself.
			if line != "" {
				seenRequestLine = true
			}
			continue
		}
		if line == "" {
			break // end of headers
		}
		if len(line) >= len(name) && strings.EqualFold(line[:len(name)], name) {
			return strings.TrimSpace(line[len(name):])
		}
	}
	return ""
}
// proxyRawConnection relays one client connection through tunnelConn as a
// logical stream. initialData (the bytes already read from the client) is
// forwarded first; afterwards client bytes are wrapped in "data" messages and
// tunnel responses for the stream are written back to the client.
//
// The function returns, closing clientConn and sending a "close" message to
// the tunnel, as soon as either direction fails. That includes the client
// disconnecting: without that signal the relay would wait on the response
// channel forever and leak the connection. As a consequence a client that
// half-closes its write side stops receiving further responses.
func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, initialData []byte) {
	defer clientConn.Close()

	// Channel on which handleTunnelMessages delivers this stream's responses.
	responseChan := make(chan []byte, 100)

	// Register the stream BEFORE anything is sent to the tunnel so that a
	// fast reply cannot arrive while the stream is still unknown and be
	// dropped. The ID is time-based; a numeric suffix is appended in the rare
	// case that another live stream of this tunnel already holds it.
	base := fmt.Sprintf("stream_%d", time.Now().UnixNano())
	streamID := base
	s.mu.Lock()
	if tunnelConn.streams == nil {
		tunnelConn.streams = make(map[string]chan []byte)
	}
	for attempt := 1; ; attempt++ {
		if _, taken := tunnelConn.streams[streamID]; !taken {
			break
		}
		streamID = fmt.Sprintf("%s_%d", base, attempt)
	}
	tunnelConn.streams[streamID] = responseChan
	s.mu.Unlock()

	// Clean up when done
	defer func() {
		// Unregister and close under the write lock: senders look the
		// channel up and send while holding the read lock, so they can never
		// send on the closed channel.
		s.mu.Lock()
		delete(tunnelConn.streams, streamID)
		close(responseChan)
		s.mu.Unlock()

		// Best effort: the tunnel may already be gone.
		closeMsg := TunnelMessage{
			Type:     "close",
			StreamID: streamID,
		}
		if err := tunnelConn.WriteJSON(closeMsg); err != nil {
			log.Debugf("Error sending close for %s: %v", streamID, err)
		}
	}()

	// Send initial data with stream ID
	initial := TunnelMessage{
		Type:     "data",
		StreamID: streamID,
		Data:     initialData,
	}
	if err := tunnelConn.WriteJSON(initial); err != nil {
		log.Errorf("Error sending initial data: %v", err)
		return
	}

	// Handle client -> tunnel. clientDone is closed when this direction ends
	// (client read error/EOF or tunnel write failure).
	clientDone := make(chan struct{})
	go func() {
		defer close(clientDone)
		buffer := make([]byte, 4096)
		for {
			n, err := clientConn.Read(buffer)
			if n > 0 {
				// A Read may return data together with an error; forward the
				// data before acting on the error.
				msg := TunnelMessage{
					Type:     "data",
					StreamID: streamID,
					Data:     buffer[:n],
				}
				if werr := tunnelConn.WriteJSON(msg); werr != nil {
					return
				}
			}
			if err != nil {
				return
			}
		}
	}()

	// Handle tunnel -> client (read from response channel). responseChan is
	// only closed by the deferred cleanup above, so it cannot yield a
	// closed-channel zero value here.
	for {
		select {
		case data := <-responseChan:
			if _, err := clientConn.Write(data); err != nil {
				return
			}
		case <-clientDone:
			return
		}
	}
}
// handleRawConnection serves one accepted TCP connection. It reads the start
// of the request, derives the subdomain from the Host header and, when a
// tunnel is registered under that subdomain, relays the connection through
// it. Everything else is treated as a request to the control endpoints.
//
// The request head is read until the blank line ending the headers is seen
// (or the buffer is full), because a single Read may deliver only part of the
// headers. The read is bounded by a deadline so an idle or stalled client
// cannot hold the goroutine and socket indefinitely; such connections are
// closed without a response.
func (s *Server) handleRawConnection(conn net.Conn) {
	defer conn.Close()

	const (
		maxHeaderBytes    = 16 << 10
		headerReadTimeout = 10 * time.Second
	)
	headersComplete := func(b []byte) bool {
		return bytes.Contains(b, []byte("\r\n\r\n")) || bytes.Contains(b, []byte("\n\n"))
	}

	if err := conn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
		return
	}

	// Read enough to get the Host header
	buffer := make([]byte, maxHeaderBytes)
	total := 0
	for total < len(buffer) {
		n, err := conn.Read(buffer[total:])
		total += n
		if headersComplete(buffer[:total]) {
			break
		}
		if err != nil {
			return
		}
	}

	// The deadline only covers the request head; relayed streams and tunnel
	// WebSockets are long-lived. A failure to clear it is ignored on purpose:
	// the connection would then simply time out later.
	_ = conn.SetReadDeadline(time.Time{})

	head := buffer[:total]

	// Extract host
	host := s.extractHostFromHTTP(head)
	subdomain := s.extractSubdomain(host)

	// If we have a registered tunnel for this subdomain, relay it
	if subdomain != "" {
		s.mu.RLock()
		tunnelConn, exists := s.tunnels[subdomain]
		s.mu.RUnlock()
		if exists {
			log.Infof("Relaying %s to tunnel", subdomain)
			s.proxyRawConnection(conn, tunnelConn, head)
			return
		}
	}

	// Otherwise, handle as control server (recreate HTTP request and use net/http)
	s.handleAsHTTP(conn, head)
}
// handleAsHTTP parses initialData as an HTTP request and dispatches it to the
// control endpoints. Unparseable input gets a bare 400 response and unknown
// paths a bare 404. Only initialData is parsed; nothing further is read from
// conn here.
func (s *Server) handleAsHTTP(conn net.Conn, initialData []byte) {
	// Rebuild an *http.Request from the bytes already taken off the wire.
	req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(initialData)))
	if err != nil {
		conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
		return
	}

	// Control endpoints.
	path := req.URL.Path
	if path == "/_conduit/tunnel" {
		s.handleTunnelUpgrade(conn, req)
		return
	}
	if path == "/_conduit/status" {
		s.handleStatus(conn)
		return
	}
	conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n"))
}
// handleTunnelMessages is the read loop of one tunnel. It decodes
// TunnelMessages from the WebSocket until a read fails and routes every
// "data" message to the channel of the stream it names. Messages for unknown
// streams and messages of other types are ignored.
//
// Delivery first tries a non-blocking send, so the common case allocates no
// timer. Only when the stream's channel is full does it wait, for at most one
// second, before dropping the message.
func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
	for {
		var msg TunnelMessage
		if err := tunnel.ReadJSON(&msg); err != nil {
			log.Debugf("Tunnel %s read loop ended: %v", tunnel.vhost, err)
			return
		}

		// Route message to appropriate stream
		if msg.Type != "data" || msg.StreamID == "" {
			continue
		}

		// The read lock is held across the send on purpose: stream cleanup
		// closes the channel under the write lock, so holding the read lock
		// rules out a send on a closed channel.
		s.mu.RLock()
		if streamChan, exists := tunnel.streams[msg.StreamID]; exists {
			select {
			case streamChan <- msg.Data:
			default:
				// Channel is full: give the consumer up to a second.
				timer := time.NewTimer(time.Second)
				select {
				case streamChan <- msg.Data:
				case <-timer.C:
					log.Infof("Stream %s channel full, dropping data", msg.StreamID)
				}
				timer.Stop()
			}
		}
		s.mu.RUnlock()
	}
}
// handleTunnelUpgrade registers a new tunnel. It requires a "vhost" query
// parameter (answering 400 without it), upgrades conn to a WebSocket, stores
// the tunnel under that vhost and then blocks in the tunnel's read loop until
// the connection ends.
//
// A later registration for the same vhost replaces the map entry. On exit the
// entry is therefore removed only if it still refers to this tunnel, so a
// stale connection going away cannot unregister its replacement.
func (s *Server) handleTunnelUpgrade(conn net.Conn, req *http.Request) {
	vhost := req.URL.Query().Get("vhost")
	if vhost == "" {
		conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing vhost parameter"))
		return
	}

	// Create a fake ResponseWriter that writes to our raw connection
	fakeWriter := &fakeResponseWriter{conn: conn}

	// Use the upgrader
	wsConn, err := s.upgrader.Upgrade(fakeWriter, req, nil)
	if err != nil {
		log.Errorf("WebSocket upgrade failed: %v", err)
		return
	}

	// Create TunnelConnection
	tunnel := &TunnelConnection{
		Conn:    wsConn,
		vhost:   vhost,
		streams: make(map[string]chan []byte),
	}
	s.mu.Lock()
	s.tunnels[vhost] = tunnel
	s.mu.Unlock()
	log.Infof("Tunnel established: %s", vhost)

	// Keep connection alive and handle cleanup
	defer func() {
		s.mu.Lock()
		if s.tunnels[vhost] == tunnel {
			delete(s.tunnels, vhost)
		}
		s.mu.Unlock()
		wsConn.Close()
		log.Infof("Tunnel closed: %s", vhost)
	}()

	// Handle tunnel messages
	s.handleTunnelMessages(tunnel)
}
// fakeResponseWriter adapts a raw net.Conn to http.ResponseWriter and
// http.Hijacker so that code written against net/http (here the WebSocket
// upgrader) can answer on a connection that was accepted outside net/http.
type fakeResponseWriter struct {
	conn   net.Conn
	header http.Header

	// wroteHeader records that the status line and headers were sent; they
	// must go out exactly once.
	wroteHeader bool
	// err is the first error hit while sending the header block; later
	// Write calls report it instead of writing a body after a broken head.
	err error
}

// Header returns the response header map, creating it on first use.
func (f *fakeResponseWriter) Header() http.Header {
	if f.header == nil {
		f.header = make(http.Header)
	}
	return f.header
}

// Write sends data on the connection. As the http.ResponseWriter contract
// requires, an implicit "200 OK" header block is sent first when WriteHeader
// has not been called yet.
func (f *fakeResponseWriter) Write(data []byte) (int, error) {
	if !f.wroteHeader {
		f.WriteHeader(http.StatusOK)
	}
	if f.err != nil {
		return 0, f.err
	}
	return f.conn.Write(data)
}

// WriteHeader sends the status line followed by the headers set so far. Only
// the first call has an effect. The whole block is assembled in memory and
// written with a single call, with header keys in sorted order.
func (f *fakeResponseWriter) WriteHeader(statusCode int) {
	if f.wroteHeader {
		return
	}
	f.wroteHeader = true

	var head bytes.Buffer
	// Write HTTP status line
	fmt.Fprintf(&head, "HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
	// Write headers; writing into a bytes.Buffer cannot fail.
	_ = f.header.Write(&head)
	// End headers
	head.WriteString("\r\n")

	if _, err := f.conn.Write(head.Bytes()); err != nil {
		f.err = err
	}
}

// Hijack implements http.Hijacker. It hands back the raw connection together
// with a fresh buffered reader/writer wrapped around it.
func (f *fakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
	return f.conn, rw, nil
}