config & auth

This commit is contained in:
Evan Reichard 2025-09-20 16:14:10 -04:00
parent 5d9684b27e
commit 2fba07f4b3
6 changed files with 390 additions and 192 deletions

159
client/client.go Normal file
View File

@ -0,0 +1,159 @@
package client
import (
"fmt"
"net"
"net/url"
"sync"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"reichard.io/conduit/config"
"reichard.io/conduit/types"
)
func NewTunnel(cfg *config.ClientConfig) (*Tunnel, error) {
// Parse Server URL
serverURL, err := url.Parse(cfg.ServerAddress)
if err != nil {
return nil, err
}
// Parse Scheme
var wsScheme string
switch serverURL.Scheme {
case "https":
wsScheme = "wss"
case "http":
wsScheme = "ws"
default:
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
}
// Create Tunnel Name
if cfg.TunnelName == "" {
cfg.TunnelName = generateTunnelName()
log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName)
}
// Connect Server WS
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey)
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to connect: %v", err)
}
return &Tunnel{
name: cfg.TunnelName,
target: cfg.TunnelTarget,
serverURL: serverURL,
serverConn: serverConn,
localConns: make(map[string]net.Conn),
}, nil
}
type Tunnel struct {
name string
target string
serverURL *url.URL
serverConn *websocket.Conn
localConns map[string]net.Conn
mu sync.RWMutex
}
func (t *Tunnel) Start() error {
log.Infof("starting tunnel: %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target)
defer t.serverConn.Close()
// Handle Messages
for {
// Read Message
var msg types.Message
err := t.serverConn.ReadJSON(&msg)
if err != nil {
log.Errorf("error reading from tunnel: %v", err)
break
}
switch msg.Type {
case types.MessageTypeData:
localConn, err := t.getLocalConn(msg.StreamID)
if err != nil {
log.Errorf("failed to get local connection: %v", err)
continue
}
// Write data to local connection
if _, err := localConn.Write(msg.Data); err != nil {
log.Errorf("error writing to local connection: %v", err)
localConn.Close()
t.mu.Lock()
delete(t.localConns, msg.StreamID)
t.mu.Unlock()
}
case types.MessageTypeClose:
t.mu.Lock()
if localConn, exists := t.localConns[msg.StreamID]; exists {
localConn.Close()
delete(t.localConns, msg.StreamID)
}
t.mu.Unlock()
}
}
return nil
}
func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) {
// Get Cached Connection
t.mu.RLock()
localConn, exists := t.localConns[streamID]
t.mu.RUnlock()
if exists {
return localConn, nil
}
// Initiate Connection & Cache
localConn, err := net.Dial("tcp", t.target)
if err != nil {
log.Errorf("failed to connect to %s: %v", t.target, err)
return nil, err
}
t.mu.Lock()
t.localConns[streamID] = localConn
t.mu.Unlock()
// Start Response Relay & Return Connection
go t.startResponseRelay(streamID, localConn)
return localConn, nil
}
func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) {
defer func() {
t.mu.Lock()
delete(t.localConns, streamID)
t.mu.Unlock()
localConn.Close()
}()
buffer := make([]byte, 4096)
for {
n, err := localConn.Read(buffer)
if err != nil {
break
}
response := types.Message{
Type: types.MessageTypeData,
StreamID: streamID,
Data: buffer[:n],
}
if err := t.serverConn.WriteJSON(response); err != nil {
break
}
}
}

23
client/name.go Normal file
View File

@ -0,0 +1,23 @@
package client
import (
"fmt"
"math/rand"
)
var colors = []string{
"red", "blue", "green", "yellow", "purple", "orange",
"pink", "brown", "black", "white", "gray", "cyan",
}
var animals = []string{
"cat", "dog", "bird", "fish", "lion", "tiger",
"bear", "wolf", "fox", "deer", "rabbit", "mouse",
}
func generateTunnelName() string {
color := colors[rand.Intn(len(colors))]
animal := animals[rand.Intn(len(animals))]
number := rand.Intn(900) + 100
return fmt.Sprintf("%s-%s-%d", color, animal, number)
}

View File

@ -1,181 +1,39 @@
package cmd package cmd
import ( import (
"fmt"
"net"
"net/url"
"sync"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"reichard.io/conduit/types" "reichard.io/conduit/client"
"reichard.io/conduit/config"
) )
var serverAddr string
var linkCmd = &cobra.Command{ var linkCmd = &cobra.Command{
Use: "link <name> <host:port>", Use: "link <name> <host:port>",
Short: "Create a tunnel link", Short: "Create a conduit tunnel",
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
tunnelName := args[0] // Get Client Config
tunnelTarget := args[1] cfg, err := config.GetClientConfig(cmd.Flags())
if err != nil {
log.Fatal("failed to get client config:", err)
}
// Create Tunnel // Create Tunnel
tunnel, err := NewTunnel(tunnelName, tunnelTarget, serverAddr) tunnel, err := client.NewTunnel(cfg)
if err != nil { if err != nil {
log.Fatal("Failed to start tunnel:", err) log.Fatal("failed to create tunnel:", err)
} }
// Start Tunnel // Start Tunnel
log.Infof("Creating TCP tunnel: %s -> %s\n", tunnelName, tunnelTarget) log.Infof("creating TCP tunnel: %s -> %s", cfg.TunnelName, cfg.TunnelTarget)
if err := tunnel.Start(); err != nil { if err := tunnel.Start(); err != nil {
log.Fatal("Failed to start tunnel:", err) log.Fatal("failed to start tunnel:", err)
} }
}, },
} }
func init() { func init() {
linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "http://localhost:8080", "Conduit server address") configDefs := config.GetConfigDefs[config.ClientConfig]()
} for _, d := range configDefs {
linkCmd.Flags().String(d.Key, d.Default, d.Description)
func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) {
// Parse Server URL
serverURL, err := url.Parse(serverAddress)
if err != nil {
return nil, err
}
// Parse Scheme
var wsScheme string
switch serverURL.Scheme {
case "https":
wsScheme = "wss"
case "http":
wsScheme = "ws"
default:
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
}
// Connect Server WS
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s", wsScheme, serverURL.Host, tunnelName)
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to connect: %v", err)
}
return &Tunnel{
name: tunnelName,
target: tunnelTarget,
serverURL: serverURL,
serverConn: serverConn,
localConns: make(map[string]net.Conn),
}, nil
}
type Tunnel struct {
name string
target string
serverURL *url.URL
serverConn *websocket.Conn
localConns map[string]net.Conn
mu sync.RWMutex
}
func (t *Tunnel) Start() error {
log.Infof("TCP Tunnel active! %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target)
defer t.serverConn.Close()
// Handle Messages
for {
// Read Message
var msg types.Message
err := t.serverConn.ReadJSON(&msg)
if err != nil {
log.Errorf("Error reading from tunnel: %v", err)
break
}
switch msg.Type {
case types.MessageTypeData:
localConn, err := t.getLocalConn(msg.StreamID)
if err != nil {
log.Errorf("Failed to get local connection: %v", err)
continue
}
// Write data to local connection
if _, err := localConn.Write(msg.Data); err != nil {
log.Errorf("Error writing to local connection: %v", err)
localConn.Close()
t.mu.Lock()
delete(t.localConns, msg.StreamID)
t.mu.Unlock()
}
case types.MessageTypeClose:
t.mu.Lock()
if localConn, exists := t.localConns[msg.StreamID]; exists {
localConn.Close()
delete(t.localConns, msg.StreamID)
}
t.mu.Unlock()
}
}
return nil
}
func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) {
// Get Cached Connection
t.mu.RLock()
localConn, exists := t.localConns[streamID]
t.mu.RUnlock()
if exists {
return localConn, nil
}
// Initiate Connection & Cache
localConn, err := net.Dial("tcp", t.target)
if err != nil {
log.Errorf("Failed to connect to %s: %v", t.target, err)
return nil, err
}
t.mu.Lock()
t.localConns[streamID] = localConn
t.mu.Unlock()
// Start Response Relay & Return Connection
go t.startResponseRelay(streamID, localConn)
return localConn, nil
}
func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) {
defer func() {
t.mu.Lock()
delete(t.localConns, streamID)
t.mu.Unlock()
localConn.Close()
}()
buffer := make([]byte, 4096)
for {
n, err := localConn.Read(buffer)
if err != nil {
break
}
response := types.Message{
Type: types.MessageTypeData,
StreamID: streamID,
Data: buffer[:n],
}
if err := t.serverConn.WriteJSON(response); err != nil {
break
}
} }
} }

View File

@ -1,30 +1,35 @@
package cmd package cmd
import ( import (
"fmt" "reichard.io/conduit/config"
"log"
"reichard.io/conduit/server" "reichard.io/conduit/server"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var port string
var serveCmd = &cobra.Command{ var serveCmd = &cobra.Command{
Use: "serve", Use: "serve",
Short: "Start the conduit server", Short: "Start the conduit server",
Long: `Start the conduit server to handle incoming tunnel requests`, Long: `Start the conduit server to handle incoming tunnel requests`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("Starting Conduit server on port %s...\n", port) // Get Server Config
cfg, err := config.GetServerConfig(cmd.Flags())
if err != nil {
log.Fatal("failed to get server config:", err)
}
srv := server.NewServer() // Start Server
if err := srv.Start(":" + port); err != nil { srv := server.NewServer(cfg)
log.Fatal("Failed to start server:", err) if err := srv.Start(); err != nil {
log.Fatal("failed to start server:", err)
} }
}, },
} }
func init() { func init() {
serveCmd.Flags().StringVarP(&port, "port", "p", "8080", "Port to run the server on") configDefs := config.GetConfigDefs[config.ServerConfig]()
for _, d := range configDefs {
serveCmd.Flags().String(d.Key, d.Default, d.Description)
}
} }

View File

@ -1 +1,145 @@
package config package config
import (
"fmt"
"os"
"reflect"
"strings"
"github.com/spf13/pflag"
)
type ConfigDef struct {
Key string
Env string
Default string
Description string
}
type BaseConfig struct {
ServerAddress string `json:"address" description:"Conduit server address" default:"http://localhost:8080"`
APIKey string `json:"api_key" description:"API Key for the conduit API"`
}
func (c *BaseConfig) Validate() error {
if c.APIKey == "" {
return fmt.Errorf("api_key is required")
}
return nil
}
type ServerConfig struct {
BaseConfig
BindAddress string `json:"bind" default:"0.0.0.0:8080" description:"Address the conduit server listens on"`
}
type ClientConfig struct {
BaseConfig
TunnelName string `json:"name" description:"Tunnel name"`
TunnelTarget string `json:"target" description:"Tunnel target address"`
}
func (c *ClientConfig) Validate() error {
if err := c.BaseConfig.Validate(); err != nil {
return err
}
if c.TunnelTarget == "" {
return fmt.Errorf("target is required")
}
return nil
}
func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) {
defs := GetConfigDefs[ServerConfig]()
cfgValues := make(map[string]string)
for _, def := range defs {
cfgValues[def.Key] = getConfigValue(cmdFlags, def)
}
cfg := &ServerConfig{
BaseConfig: BaseConfig{
ServerAddress: cfgValues["server"],
APIKey: cfgValues["api_key"],
},
BindAddress: cfgValues["bind"],
}
return cfg, cfg.Validate()
}
func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
defs := GetConfigDefs[ClientConfig]()
cfgValues := make(map[string]string)
for _, def := range defs {
cfgValues[def.Key] = getConfigValue(cmdFlags, def)
}
cfg := &ClientConfig{
BaseConfig: BaseConfig{
ServerAddress: cfgValues["address"],
APIKey: cfgValues["api_key"],
},
TunnelName: cfgValues["name"],
TunnelTarget: cfgValues["target"],
}
return cfg, cfg.Validate()
}
func GetConfigDefs[T ServerConfig | ClientConfig]() []ConfigDef {
var defs []ConfigDef
processFields(reflect.TypeFor[T](), &defs)
return defs
}
func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string {
// 1. Get Flags First
if cmdFlags != nil {
if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" {
return val
}
}
// 2. Environment Variables Next
if envVal := os.Getenv(def.Env); envVal != "" {
return envVal
}
// 3. Defaults Last
return def.Default
}
func processFields(t reflect.Type, defs *[]ConfigDef) {
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
// Process Embedded (BaseConfig)
if field.Anonymous {
processFields(field.Type, defs)
continue
}
// Extract Struct Tags
jsonTag := field.Tag.Get("json")
defaultTag := field.Tag.Get("default")
descriptionTag := field.Tag.Get("description")
// Skip JSON Fields
if jsonTag == "" {
continue
}
// Get Key & Env
key := strings.Split(jsonTag, ",")[0]
env := "CONDUIT_" + strings.ToUpper(key)
*defs = append(*defs, ConfigDef{
Key: key,
Env: env,
Default: defaultTag,
Description: descriptionTag,
})
}
}

View File

@ -13,23 +13,26 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/conduit/config"
"reichard.io/conduit/types" "reichard.io/conduit/types"
) )
type TunnelConnection struct { type TunnelConnection struct {
*websocket.Conn *websocket.Conn
name string name string
streams map[string]chan []byte // StreamID -> data channel streams map[string]chan []byte
} }
type Server struct { type Server struct {
tunnels map[string]*TunnelConnection tunnels map[string]*TunnelConnection
upgrader websocket.Upgrader upgrader websocket.Upgrader
cfg *config.ServerConfig
mu sync.RWMutex mu sync.RWMutex
} }
func NewServer() *Server { func NewServer(cfg *config.ServerConfig) *Server {
return &Server{ return &Server{
cfg: cfg,
tunnels: make(map[string]*TunnelConnection), tunnels: make(map[string]*TunnelConnection),
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
@ -39,20 +42,21 @@ func NewServer() *Server {
} }
} }
func (s *Server) Start(addr string) error { func (s *Server) Start() error {
// Raw TCP listener instead of http.ListenAndServe // Raw TCP Listener - This is necessary so we can conditionally either relay
listener, err := net.Listen("tcp", addr) // the raw TCP connection, or handle conduit control server API requests.
listener, err := net.Listen("tcp", s.cfg.BindAddress)
if err != nil { if err != nil {
return err return err
} }
defer listener.Close() defer listener.Close()
log.Infof("Conduit server listening on %s", addr) // Start Listening
log.Infof("conduit server listening on %s", s.cfg.BindAddress)
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
log.Printf("Error accepting connection: %v", err) log.Printf("error accepting connection: %v", err)
continue continue
} }
@ -76,7 +80,7 @@ func (s *Server) extractSubdomain(peakReader io.Reader) string {
// Extract Subdomain // Extract Subdomain
parts := strings.Split(host, ".") parts := strings.Split(host, ".")
if len(parts) >= 1 { if len(parts) > 1 {
return parts[0] return parts[0]
} }
@ -152,9 +156,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
} }
} }
// peakData limits how much we read as we only need to determine func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) {
// the host to figure out whether we should proxy or not.
func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
peek := make([]byte, 8192) peek := make([]byte, 8192)
n, err := conn.Read(peek) n, err := conn.Read(peek)
if err != nil { if err != nil {
@ -163,13 +165,13 @@ func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
peekedData := peek[:n] peekedData := peek[:n]
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn) combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
return bytes.NewReader(peekedData), combinedReader, nil return bytes.NewReader(peekedData), combinedReader, nil
} }
func (s *Server) handleRawConnection(conn net.Conn) { func (s *Server) handleRawConnection(conn net.Conn) {
defer conn.Close() defer conn.Close()
// Detect Tunnel
peakReader, allReader, _ := s.peekData(conn) peakReader, allReader, _ := s.peekData(conn)
if subdomain := s.extractSubdomain(peakReader); subdomain != "" { if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
s.mu.RLock() s.mu.RLock()
@ -177,25 +179,32 @@ func (s *Server) handleRawConnection(conn net.Conn) {
s.mu.RUnlock() s.mu.RUnlock()
if exists { if exists {
log.Infof("Relaying %s to tunnel", subdomain) log.Infof("relaying %s to tunnel", subdomain)
s.proxyRawConnection(conn, tunnelConn, allReader) s.proxyRawConnection(conn, tunnelConn, allReader)
}
return return
} }
}
// Otherwise, handle as control server (recreate HTTP request and use net/http) // Control Endpoints
s.handleAsHTTP(conn, allReader) s.handleAsHTTP(conn, allReader)
} }
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) { func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
// Create HTTP Request & Writer // Create HTTP Request & Writer
w := &connResponseWriter{conn: conn}
r, err := http.ReadRequest(bufio.NewReader(allReader)) r, err := http.ReadRequest(bufio.NewReader(allReader))
if err != nil { if err != nil {
_, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) w.WriteHeader(http.StatusBadRequest)
return
}
// Authorize Control Endpoints
apiKey := r.URL.Query().Get("apiKey")
if apiKey != s.cfg.APIKey {
log.Error("unauthorized client")
w.WriteHeader(http.StatusUnauthorized)
return return
} }
w := &connResponseWriter{conn: conn}
// Handle Control Endpoints // Handle Control Endpoints
switch r.URL.Path { switch r.URL.Path {
@ -217,7 +226,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
} }
if msg.StreamID == "" { if msg.StreamID == "" {
log.Infof("Tunnel %s missing streamID", tunnel.name) log.Infof("tunnel %s missing streamID", tunnel.name)
continue continue
} }
@ -228,7 +237,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
s.mu.RLock() s.mu.RLock()
streamChan, exists := tunnel.streams[msg.StreamID] streamChan, exists := tunnel.streams[msg.StreamID]
if !exists { if !exists {
log.Infof("Stream %s does not exist", msg.StreamID) log.Infof("stream %s does not exist", msg.StreamID)
s.mu.RUnlock() s.mu.RUnlock()
continue continue
} }
@ -236,7 +245,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
select { select {
case streamChan <- msg.Data: case streamChan <- msg.Data:
case <-time.After(time.Second): case <-time.After(time.Second):
log.Infof("Stream %s channel full, dropping data", msg.StreamID) log.Warnf("stream %s channel full, dropping data", msg.StreamID)
} }
s.mu.RUnlock() s.mu.RUnlock()
} }
@ -261,7 +270,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Upgrade Connection // Upgrade Connection
wsConn, err := s.upgrader.Upgrade(w, r, nil) wsConn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Errorf("WebSocket upgrade failed: %v", err) log.Errorf("websocket upgrade failed: %v", err)
return return
} }
@ -274,7 +283,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
s.mu.Lock() s.mu.Lock()
s.tunnels[tunnelName] = tunnel s.tunnels[tunnelName] = tunnel
s.mu.Unlock() s.mu.Unlock()
log.Infof("Tunnel established: %s", tunnelName) log.Infof("tunnel established: %s", tunnelName)
// Keep connection alive and handle cleanup // Keep connection alive and handle cleanup
defer func() { defer func() {
@ -282,7 +291,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
delete(s.tunnels, tunnelName) delete(s.tunnels, tunnelName)
s.mu.Unlock() s.mu.Unlock()
_ = wsConn.Close() _ = wsConn.Close()
log.Infof("Tunnel closed: %s", tunnelName) log.Infof("tunnel closed: %s", tunnelName)
}() }()
// Handle tunnel messages // Handle tunnel messages