This commit is contained in:
Evan Reichard 2025-09-20 18:12:56 -04:00
parent 2fba07f4b3
commit 08e1191ba3
3 changed files with 78 additions and 63 deletions

View File

@ -19,8 +19,13 @@ var serveCmd = &cobra.Command{
log.Fatal("failed to get server config:", err)
}
// Create Server
srv, err := server.NewServer(cfg)
if err != nil {
log.Fatal("failed to create server:", err)
}
// Start Server
srv := server.NewServer(cfg)
if err := srv.Start(); err != nil {
log.Fatal("failed to start server:", err)
}

View File

@ -1,7 +1,9 @@
package config
import (
"errors"
"fmt"
"net/url"
"os"
"reflect"
"strings"
@ -17,13 +19,19 @@ type ConfigDef struct {
}
type BaseConfig struct {
ServerAddress string `json:"address" description:"Conduit server address" default:"http://localhost:8080"`
ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"`
APIKey string `json:"api_key" description:"API Key for the conduit API"`
}
func (c *BaseConfig) Validate() error {
if c.APIKey == "" {
return fmt.Errorf("api_key is required")
return errors.New("api_key is required")
}
if c.ServerAddress == "" {
return errors.New("server is required")
}
if _, err := url.Parse(c.ServerAddress); err != nil {
return fmt.Errorf("server is invalid: %w", err)
}
return nil
}
@ -78,7 +86,7 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
cfg := &ClientConfig{
BaseConfig: BaseConfig{
ServerAddress: cfgValues["address"],
ServerAddress: cfgValues["server"],
APIKey: cfgValues["api_key"],
},
TunnelName: cfgValues["name"],

View File

@ -3,10 +3,12 @@ package server
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
@ -24,22 +26,32 @@ type TunnelConnection struct {
}
type Server struct {
tunnels map[string]*TunnelConnection
upgrader websocket.Upgrader
host string
cfg *config.ServerConfig
mu sync.RWMutex
upgrader websocket.Upgrader
tunnels map[string]*TunnelConnection
}
func NewServer(cfg *config.ServerConfig) *Server {
func NewServer(cfg *config.ServerConfig) (*Server, error) {
serverURL, err := url.Parse(cfg.ServerAddress)
if err != nil {
return nil, fmt.Errorf("failed to parse server address: %v", err)
} else if serverURL.Host == "" {
return nil, errors.New("invalid server address")
}
return &Server{
cfg: cfg,
host: serverURL.Host,
tunnels: make(map[string]*TunnelConnection),
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
}
}, nil
}
func (s *Server) Start() error {
@ -64,29 +76,6 @@ func (s *Server) Start() error {
}
}
func (s *Server) extractSubdomain(peakReader io.Reader) string {
// Read Request
req, err := http.ReadRequest(bufio.NewReader(peakReader))
if err != nil {
return ""
}
defer req.Body.Close()
// Extract Host
host := req.Host
if idx := strings.Index(host, ":"); idx != -1 {
host = host[:idx]
}
// Extract Subdomain
parts := strings.Split(host, ".")
if len(parts) > 1 {
return parts[0]
}
return ""
}
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
s.mu.RLock()
count := len(s.tunnels)
@ -156,48 +145,61 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
}
}
func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) {
peek := make([]byte, 8192)
n, err := conn.Read(peek)
if err != nil {
return nil, nil, err
}
peekedData := peek[:n]
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
return bytes.NewReader(peekedData), combinedReader, nil
}
func (s *Server) handleRawConnection(conn net.Conn) {
defer conn.Close()
// Detect Tunnel
peakReader, allReader, _ := s.peekData(conn)
if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
s.mu.RLock()
tunnelConn, exists := s.tunnels[subdomain]
s.mu.RUnlock()
// Capture Consumed Data - When determining where to route the request, we
// have to read the host headers. This requires reading from the buffer, so
// if we later decide to tunnel the TCP connection we need to reconstruct the
// data from the buffer.
var capturedData bytes.Buffer
teeReader := io.TeeReader(conn, &capturedData)
bufReader := bufio.NewReader(teeReader)
if exists {
log.Infof("relaying %s to tunnel", subdomain)
s.proxyRawConnection(conn, tunnelConn, allReader)
}
return
}
// Control Endpoints
s.handleAsHTTP(conn, allReader)
}
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
// Create HTTP Request & Writer
w := &connResponseWriter{conn: conn}
r, err := http.ReadRequest(bufio.NewReader(allReader))
r, err := http.ReadRequest(bufReader)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
defer r.Body.Close()
// Validate Host
if !strings.Contains(r.Host, s.host) {
w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "unknown host: %s", r.Host)
return
}
// Extract Subdomain
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
if strings.Count(subdomain, ".") != 0 {
w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
return
}
// Handle Control Endpoints
if subdomain == "" {
s.handleAsHTTP(w, r)
return
}
// Handle Tunnels
s.mu.RLock()
tunnelConn, exists := s.tunnels[subdomain]
s.mu.RUnlock()
if exists {
log.Infof("relaying %s to tunnel", subdomain)
// Reconstruct Data & Proxy Connection
allReader := io.MultiReader(&capturedData, r.Body)
s.proxyRawConnection(conn, tunnelConn, allReader)
}
}
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
// Authorize Control Endpoints
apiKey := r.URL.Query().Get("apiKey")
if apiKey != s.cfg.APIKey {