conduit/server/server.go
2025-09-19 23:53:41 -04:00

291 lines
6.1 KiB
Go

package server
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"reichard.io/conduit/types"
)
// TunnelConnection is one registered tunnel: the WebSocket connection to the
// tunnel client plus the set of proxied streams currently multiplexed over it.
type TunnelConnection struct {
// Embedded WebSocket connection; its ReadJSON/WriteJSON methods carry the
// types.Message frames exchanged with the tunnel client.
*websocket.Conn
// name is the tunnel name the connection was registered under.
name string
// streams is read and written while holding Server.mu.
streams map[string]chan []byte // StreamID -> data channel
}
// Server accepts raw TCP connections and either relays them to a registered
// tunnel or serves the control endpoints itself.
type Server struct {
// tunnels maps a tunnel name (matched against the request's subdomain) to
// its live connection.
tunnels map[string]*TunnelConnection
// upgrader turns control requests into WebSocket tunnel connections.
upgrader websocket.Upgrader
// mu guards tunnels and each TunnelConnection's streams map.
mu sync.RWMutex
}
// NewServer returns a Server with an empty tunnel registry and a WebSocket
// upgrader whose origin check accepts every request.
func NewServer() *Server {
	// The origin check is deliberately permissive: any Origin header passes.
	allowAnyOrigin := func(*http.Request) bool { return true }

	srv := new(Server)
	srv.tunnels = map[string]*TunnelConnection{}
	srv.upgrader = websocket.Upgrader{CheckOrigin: allowAnyOrigin}
	return srv
}
// Start listens on addr and serves connections until the process exits.
//
// It uses a raw TCP listener rather than http.ListenAndServe so that each
// connection can be inspected and, if it targets a tunnel, relayed byte for
// byte. Every accepted connection is handled on its own goroutine.
//
// It returns an error only when the listener cannot be created. Accept
// failures are logged and retried with a capped exponential backoff so that a
// persistent failure (for example file descriptor exhaustion) does not spin
// the loop at full CPU.
func (s *Server) Start(addr string) error {
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	defer listener.Close()
	log.Infof("Conduit server listening on %s", addr)

	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration

	for {
		conn, err := listener.Accept()
		if err != nil {
			// Grow the delay on consecutive failures, up to maxBackoff.
			switch {
			case backoff == 0:
				backoff = minBackoff
			case backoff*2 > maxBackoff:
				backoff = maxBackoff
			default:
				backoff *= 2
			}
			log.Errorf("Error accepting connection: %v", err)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go s.handleRawConnection(conn)
	}
}
// extractSubdomain parses an HTTP request from peakReader and returns the
// first dot-separated label of its Host (e.g. "app" for "app.example.com:8080").
//
// It returns "" when the request cannot be parsed, when the host is empty, or
// when the host is an IP literal, since an address has no subdomain.
func (s *Server) extractSubdomain(peakReader io.Reader) string {
	// Read Request
	req, err := http.ReadRequest(bufio.NewReader(peakReader))
	if err != nil {
		return ""
	}
	defer req.Body.Close()

	// Extract Host. SplitHostPort understands bracketed IPv6 ("[::1]:80");
	// it fails when there is no port, in which case the host is used as is.
	host := req.Host
	if h, _, splitErr := net.SplitHostPort(host); splitErr == nil {
		host = h
	}
	// A bracketed IPv6 literal without a port still carries its brackets.
	host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")

	// IP addresses have no subdomain ("127.0.0.1" must not yield "127").
	if net.ParseIP(host) != nil {
		return ""
	}

	// Extract Subdomain
	if idx := strings.IndexByte(host, '.'); idx != -1 {
		return host[:idx]
	}
	return host
}
// getStatus answers with a 200 JSON document reporting how many tunnels are
// currently registered. The request itself is not inspected.
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
	// Snapshot the count in a closure so the read lock is held only for the
	// map length lookup.
	tunnelCount := func() int {
		s.mu.RLock()
		defer s.mu.RUnlock()
		return len(s.tunnels)
	}()

	body := []byte(fmt.Sprintf(`{"tunnels": %d}`, tunnelCount))

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(body)
}
// proxyRawConnection relays one client connection through tunnelConn as a
// single stream.
//
// Bytes read from dataReader are sent to the tunnel as data messages tagged
// with a fresh stream ID; payloads delivered to the stream's channel are
// written back to clientConn. The function returns when the client side stops
// being readable, when forwarding to the tunnel fails, or when writing to the
// client fails. On return it unregisters the stream, sends a close message to
// the tunnel and closes clientConn.
//
// NOTE(review): gorilla/websocket allows only one concurrent writer per
// connection. WriteJSON is called here from the reader goroutine, from the
// deferred cleanup, and from every other stream sharing this tunnel, with no
// serialization. Fixing that needs a write lock on TunnelConnection.
func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, dataReader io.Reader) {
	defer clientConn.Close()

	// Create Identifiers
	streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
	responseChan := make(chan []byte, 100)

	// Register Stream
	s.mu.Lock()
	if tunnelConn.streams == nil {
		tunnelConn.streams = make(map[string]chan []byte)
	}
	tunnelConn.streams[streamID] = responseChan
	s.mu.Unlock()

	// Clean Up
	defer func() {
		// Unregister and close under the write lock: senders look the channel
		// up and send while holding the read lock, so they can never send on
		// a closed channel.
		s.mu.Lock()
		delete(tunnelConn.streams, streamID)
		close(responseChan)
		s.mu.Unlock()

		// Send Close. Best effort: the tunnel may already be gone.
		closeMsg := types.Message{
			Type:     types.MessageTypeClose,
			StreamID: streamID,
		}
		_ = tunnelConn.WriteJSON(closeMsg)
	}()

	// Read & Send Chunks. clientDone is closed when this goroutine exits so
	// the response loop below does not wait forever on a dead client.
	clientDone := make(chan struct{})
	go func() {
		defer close(clientDone)

		buffer := make([]byte, 4096)
		for {
			n, err := dataReader.Read(buffer)
			// A Read may return data together with an error; forward the
			// data before acting on the error.
			if n > 0 {
				if writeErr := tunnelConn.WriteJSON(types.Message{
					Type:     types.MessageTypeData,
					StreamID: streamID,
					Data:     buffer[:n],
				}); writeErr != nil {
					return
				}
			}
			if err != nil {
				return
			}
		}
	}()

	// Return Response Data
	for {
		select {
		case data, ok := <-responseChan:
			if !ok {
				return
			}
			if _, err := clientConn.Write(data); err != nil {
				return
			}
		case <-clientDone:
			// The client can no longer be read (or the tunnel write failed).
			// Trade-off: a client that half-closes its write side and still
			// expects a response is treated as gone.
			return
		}
	}
}
// peekData reads the start of the connection so the caller can determine the
// requested host and decide whether to proxy, without losing those bytes.
//
// It keeps reading until the end of the HTTP header block has been seen, the
// 8192-byte peek buffer is full, or the connection reports an error. A single
// Read is not enough because the request head may arrive in several TCP
// segments.
//
// It returns two readers: the first yields only the peeked bytes, the second
// yields the peeked bytes followed by the rest of conn. An error is returned
// only when nothing at all could be read.
func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
	const maxPeek = 8192

	peek := make([]byte, maxPeek)
	total := 0
	for total < maxPeek {
		n, err := conn.Read(peek[total:])
		total += n
		if err != nil {
			if total == 0 {
				return nil, nil, err
			}
			// Keep what was received; the error will resurface on the next
			// read from conn.
			break
		}

		// Stop once the blank line that ends the header block has arrived.
		// Bare-LF line endings are accepted too, as the HTTP parser tolerates them.
		seen := peek[:total]
		if bytes.Contains(seen, []byte("\r\n\r\n")) || bytes.Contains(seen, []byte("\n\n")) {
			break
		}
	}

	peekedData := peek[:total]
	combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
	return bytes.NewReader(peekedData), combinedReader, nil
}
// handleRawConnection is the entry point for every accepted connection. It
// peeks at the request to find the target subdomain: when a tunnel with that
// name is registered the connection is relayed to it, otherwise the
// connection is served as a control request. conn is always closed on return.
func (s *Server) handleRawConnection(conn net.Conn) {
	defer conn.Close()

	peakReader, allReader, err := s.peekData(conn)
	if err != nil {
		// Nothing could be read (e.g. the client connected and hung up).
		// The readers are nil in this case and must not be used.
		log.Debugf("Unable to read from %s: %v", conn.RemoteAddr(), err)
		return
	}

	if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
		s.mu.RLock()
		tunnelConn, exists := s.tunnels[subdomain]
		s.mu.RUnlock()

		if exists {
			log.Infof("Relaying %s to tunnel", subdomain)
			s.proxyRawConnection(conn, tunnelConn, allReader)
			return
		}
	}

	// Otherwise, handle as control server (recreate HTTP request and use net/http)
	s.handleAsHTTP(conn, allReader)
}
// handleAsHTTP serves a control request. It parses a single HTTP request from
// allReader and routes it by path: "/_conduit/tunnel" registers a tunnel,
// "/_conduit/status" reports status, and anything else gets a 404. A request
// that cannot be parsed is answered with a bare 400 written directly to conn.
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
	req, parseErr := http.ReadRequest(bufio.NewReader(allReader))
	if parseErr != nil {
		_, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
		return
	}

	// Responses are written straight onto the raw connection.
	rw := &connResponseWriter{conn: conn}

	// Control endpoints, keyed by exact request path.
	routes := map[string]http.HandlerFunc{
		"/_conduit/tunnel": s.createTunnel,
		"/_conduit/status": s.getStatus,
	}
	if handle, known := routes[req.URL.Path]; known {
		handle(rw, req)
		return
	}
	rw.WriteHeader(http.StatusNotFound)
}
// handleTunnelMessages reads messages from the tunnel until a read fails or a
// close message arrives. Data messages are delivered to the channel of the
// stream they name; messages without a stream ID, or naming an unknown
// stream, are logged and skipped. Message types other than close and data are
// ignored.
func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
	for {
		var msg types.Message
		if err := tunnel.ReadJSON(&msg); err != nil {
			return
		}

		if msg.StreamID == "" {
			log.Infof("Tunnel %s missing streamID", tunnel.name)
			continue
		}

		switch msg.Type {
		case types.MessageTypeClose:
			// NOTE(review): the message names a single stream, yet returning
			// here stops processing for the whole tunnel. Confirm whether
			// only that stream was meant to be closed.
			return
		case types.MessageTypeData:
			// The read lock is held across the send on purpose: the stream's
			// owner closes the channel under the write lock, so holding the
			// read lock guarantees the channel is still open.
			s.mu.RLock()
			streamChan, exists := tunnel.streams[msg.StreamID]
			if !exists {
				log.Infof("Stream %s does not exist", msg.StreamID)
				s.mu.RUnlock()
				continue
			}

			select {
			case streamChan <- msg.Data:
				// Fast path: buffer had room, no timer needed.
			default:
				// Buffer is full: wait up to one second, then drop. The
				// timer is stopped explicitly so it is released as soon as
				// the send succeeds instead of lingering until it fires.
				timeout := time.NewTimer(time.Second)
				select {
				case streamChan <- msg.Data:
				case <-timeout.C:
					log.Infof("Stream %s channel full, dropping data", msg.StreamID)
				}
				timeout.Stop()
			}
			s.mu.RUnlock()
		}
	}
}
// createTunnel registers a new tunnel under the "tunnelName" query parameter
// and upgrades the request to a WebSocket connection.
//
// It responds 400 when the parameter is missing and 409 when the name is
// already registered. After a successful upgrade it blocks, processing tunnel
// messages, until the tunnel ends; the tunnel is then unregistered and the
// WebSocket closed.
func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
	// Get Tunnel Name
	tunnelName := r.URL.Query().Get("tunnelName")
	if tunnelName == "" {
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte("Missing tunnelName parameter"))
		return
	}

	// Validate Unique. The map is shared with other connection goroutines,
	// so it must be read under the lock. Checking before the upgrade lets a
	// duplicate still receive a plain HTTP 409.
	s.mu.RLock()
	_, exists := s.tunnels[tunnelName]
	s.mu.RUnlock()
	if exists {
		w.WriteHeader(http.StatusConflict)
		_, _ = w.Write([]byte("Tunnel already registered"))
		return
	}

	// Upgrade Connection
	wsConn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Errorf("WebSocket upgrade failed: %v", err)
		return
	}

	// Create & Cache TunnelConnection
	tunnel := &TunnelConnection{
		Conn:    wsConn,
		name:    tunnelName,
		streams: make(map[string]chan []byte),
	}

	// Re-check under the write lock: another client may have claimed the
	// name while this request was being upgraded.
	s.mu.Lock()
	if _, taken := s.tunnels[tunnelName]; taken {
		s.mu.Unlock()
		_ = wsConn.Close()
		log.Infof("Tunnel %s was registered concurrently, rejecting duplicate", tunnelName)
		return
	}
	s.tunnels[tunnelName] = tunnel
	s.mu.Unlock()
	log.Infof("Tunnel established: %s", tunnelName)

	// Keep connection alive and handle cleanup
	defer func() {
		s.mu.Lock()
		delete(s.tunnels, tunnelName)
		s.mu.Unlock()
		_ = wsConn.Close()
		log.Infof("Tunnel closed: %s", tunnelName)
	}()

	// Handle tunnel messages
	s.handleTunnelMessages(tunnel)
}