This commit is contained in:
2025-09-20 18:12:56 -04:00
parent 2fba07f4b3
commit 08e1191ba3
3 changed files with 78 additions and 63 deletions

View File

@@ -3,10 +3,12 @@ package server
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
@@ -24,22 +26,32 @@ type TunnelConnection struct {
}
type Server struct {
tunnels map[string]*TunnelConnection
host string
cfg *config.ServerConfig
mu sync.RWMutex
upgrader websocket.Upgrader
cfg *config.ServerConfig
mu sync.RWMutex
tunnels map[string]*TunnelConnection
}
func NewServer(cfg *config.ServerConfig) *Server {
func NewServer(cfg *config.ServerConfig) (*Server, error) {
serverURL, err := url.Parse(cfg.ServerAddress)
if err != nil {
return nil, fmt.Errorf("failed to parse server address: %v", err)
} else if serverURL.Host == "" {
return nil, errors.New("invalid server address")
}
return &Server{
cfg: cfg,
host: serverURL.Host,
tunnels: make(map[string]*TunnelConnection),
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
}
}, nil
}
func (s *Server) Start() error {
@@ -64,29 +76,6 @@ func (s *Server) Start() error {
}
}
func (s *Server) extractSubdomain(peakReader io.Reader) string {
// Read Request
req, err := http.ReadRequest(bufio.NewReader(peakReader))
if err != nil {
return ""
}
defer req.Body.Close()
// Extract Host
host := req.Host
if idx := strings.Index(host, ":"); idx != -1 {
host = host[:idx]
}
// Extract Subdomain
parts := strings.Split(host, ".")
if len(parts) > 1 {
return parts[0]
}
return ""
}
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
s.mu.RLock()
count := len(s.tunnels)
@@ -156,48 +145,61 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
}
}
func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) {
peek := make([]byte, 8192)
n, err := conn.Read(peek)
if err != nil {
return nil, nil, err
}
peekedData := peek[:n]
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
return bytes.NewReader(peekedData), combinedReader, nil
}
func (s *Server) handleRawConnection(conn net.Conn) {
defer conn.Close()
// Detect Tunnel
peakReader, allReader, _ := s.peekData(conn)
if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
s.mu.RLock()
tunnelConn, exists := s.tunnels[subdomain]
s.mu.RUnlock()
// Capture Consumed Data - When determining where to route the request, we
// have to read the host headers. This requires reading from the buffer, so
// if we later decide to tunnel the TCP connection we need to reconstruct the
// data from the buffer.
var capturedData bytes.Buffer
teeReader := io.TeeReader(conn, &capturedData)
bufReader := bufio.NewReader(teeReader)
if exists {
log.Infof("relaying %s to tunnel", subdomain)
s.proxyRawConnection(conn, tunnelConn, allReader)
}
return
}
// Control Endpoints
s.handleAsHTTP(conn, allReader)
}
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
// Create HTTP Request & Writer
w := &connResponseWriter{conn: conn}
r, err := http.ReadRequest(bufio.NewReader(allReader))
r, err := http.ReadRequest(bufReader)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
defer r.Body.Close()
// Validate Host
if !strings.Contains(r.Host, s.host) {
w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "unknown host: %s", r.Host)
return
}
// Extract Subdomain
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
if strings.Count(subdomain, ".") != 0 {
w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
return
}
// Handle Control Endpoints
if subdomain == "" {
s.handleAsHTTP(w, r)
return
}
// Handle Tunnels
s.mu.RLock()
tunnelConn, exists := s.tunnels[subdomain]
s.mu.RUnlock()
if exists {
log.Infof("relaying %s to tunnel", subdomain)
// Reconstruct Data & Proxy Connection
allReader := io.MultiReader(&capturedData, r.Body)
s.proxyRawConnection(conn, tunnelConn, allReader)
}
}
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
// Authorize Control Endpoints
apiKey := r.URL.Query().Get("apiKey")
if apiKey != s.cfg.APIKey {