This commit is contained in:
Evan Reichard 2025-09-20 18:12:56 -04:00
parent 2fba07f4b3
commit 08e1191ba3
3 changed files with 78 additions and 63 deletions

View File

@ -19,8 +19,13 @@ var serveCmd = &cobra.Command{
log.Fatal("failed to get server config:", err) log.Fatal("failed to get server config:", err)
} }
// Create Server
srv, err := server.NewServer(cfg)
if err != nil {
log.Fatal("failed to create server:", err)
}
// Start Server // Start Server
srv := server.NewServer(cfg)
if err := srv.Start(); err != nil { if err := srv.Start(); err != nil {
log.Fatal("failed to start server:", err) log.Fatal("failed to start server:", err)
} }

View File

@ -1,7 +1,9 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"net/url"
"os" "os"
"reflect" "reflect"
"strings" "strings"
@ -17,13 +19,19 @@ type ConfigDef struct {
} }
type BaseConfig struct { type BaseConfig struct {
ServerAddress string `json:"address" description:"Conduit server address" default:"http://localhost:8080"` ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"`
APIKey string `json:"api_key" description:"API Key for the conduit API"` APIKey string `json:"api_key" description:"API Key for the conduit API"`
} }
func (c *BaseConfig) Validate() error { func (c *BaseConfig) Validate() error {
if c.APIKey == "" { if c.APIKey == "" {
return fmt.Errorf("api_key is required") return errors.New("api_key is required")
}
if c.ServerAddress == "" {
return errors.New("server is required")
}
if _, err := url.Parse(c.ServerAddress); err != nil {
return fmt.Errorf("server is invalid: %w", err)
} }
return nil return nil
} }
@ -78,7 +86,7 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
cfg := &ClientConfig{ cfg := &ClientConfig{
BaseConfig: BaseConfig{ BaseConfig: BaseConfig{
ServerAddress: cfgValues["address"], ServerAddress: cfgValues["server"],
APIKey: cfgValues["api_key"], APIKey: cfgValues["api_key"],
}, },
TunnelName: cfgValues["name"], TunnelName: cfgValues["name"],

View File

@ -3,10 +3,12 @@ package server
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -24,22 +26,32 @@ type TunnelConnection struct {
} }
type Server struct { type Server struct {
tunnels map[string]*TunnelConnection host string
upgrader websocket.Upgrader
cfg *config.ServerConfig cfg *config.ServerConfig
mu sync.RWMutex mu sync.RWMutex
upgrader websocket.Upgrader
tunnels map[string]*TunnelConnection
}
func NewServer(cfg *config.ServerConfig) (*Server, error) {
serverURL, err := url.Parse(cfg.ServerAddress)
if err != nil {
return nil, fmt.Errorf("failed to parse server address: %v", err)
} else if serverURL.Host == "" {
return nil, errors.New("invalid server address")
} }
func NewServer(cfg *config.ServerConfig) *Server {
return &Server{ return &Server{
cfg: cfg, cfg: cfg,
host: serverURL.Host,
tunnels: make(map[string]*TunnelConnection), tunnels: make(map[string]*TunnelConnection),
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true return true
}, },
}, },
} }, nil
} }
func (s *Server) Start() error { func (s *Server) Start() error {
@ -64,29 +76,6 @@ func (s *Server) Start() error {
} }
} }
func (s *Server) extractSubdomain(peakReader io.Reader) string {
// Read Request
req, err := http.ReadRequest(bufio.NewReader(peakReader))
if err != nil {
return ""
}
defer req.Body.Close()
// Extract Host
host := req.Host
if idx := strings.Index(host, ":"); idx != -1 {
host = host[:idx]
}
// Extract Subdomain
parts := strings.Split(host, ".")
if len(parts) > 1 {
return parts[0]
}
return ""
}
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) { func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
s.mu.RLock() s.mu.RLock()
count := len(s.tunnels) count := len(s.tunnels)
@ -156,48 +145,61 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
} }
} }
func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) {
peek := make([]byte, 8192)
n, err := conn.Read(peek)
if err != nil {
return nil, nil, err
}
peekedData := peek[:n]
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
return bytes.NewReader(peekedData), combinedReader, nil
}
func (s *Server) handleRawConnection(conn net.Conn) { func (s *Server) handleRawConnection(conn net.Conn) {
defer conn.Close() defer conn.Close()
// Detect Tunnel // Capture Consumed Data - When determining where to route the request, we
peakReader, allReader, _ := s.peekData(conn) // have to read the host headers. This requires reading from the buffer, so
if subdomain := s.extractSubdomain(peakReader); subdomain != "" { // if we later decide to tunnel the TCP connection we need to reconstruct the
s.mu.RLock() // data from the buffer.
tunnelConn, exists := s.tunnels[subdomain] var capturedData bytes.Buffer
s.mu.RUnlock() teeReader := io.TeeReader(conn, &capturedData)
bufReader := bufio.NewReader(teeReader)
if exists {
log.Infof("relaying %s to tunnel", subdomain)
s.proxyRawConnection(conn, tunnelConn, allReader)
}
return
}
// Control Endpoints
s.handleAsHTTP(conn, allReader)
}
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
// Create HTTP Request & Writer // Create HTTP Request & Writer
w := &connResponseWriter{conn: conn} w := &connResponseWriter{conn: conn}
r, err := http.ReadRequest(bufio.NewReader(allReader)) r, err := http.ReadRequest(bufReader)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
} }
defer r.Body.Close()
// Validate Host
if !strings.Contains(r.Host, s.host) {
w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "unknown host: %s", r.Host)
return
}
// Extract Subdomain
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
if strings.Count(subdomain, ".") != 0 {
w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
return
}
// Handle Control Endpoints
if subdomain == "" {
s.handleAsHTTP(w, r)
return
}
// Handle Tunnels
s.mu.RLock()
tunnelConn, exists := s.tunnels[subdomain]
s.mu.RUnlock()
if exists {
log.Infof("relaying %s to tunnel", subdomain)
// Reconstruct Data & Proxy Connection
allReader := io.MultiReader(&capturedData, r.Body)
s.proxyRawConnection(conn, tunnelConn, allReader)
}
}
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
// Authorize Control Endpoints // Authorize Control Endpoints
apiKey := r.URL.Query().Get("apiKey") apiKey := r.URL.Query().Get("apiKey")
if apiKey != s.cfg.APIKey { if apiKey != s.cfg.APIKey {