diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..a9786d2 --- /dev/null +++ b/client/client.go @@ -0,0 +1,159 @@ +package client + +import ( + "fmt" + "net" + "net/url" + "sync" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" + "reichard.io/conduit/config" + "reichard.io/conduit/types" +) + +func NewTunnel(cfg *config.ClientConfig) (*Tunnel, error) { + // Parse Server URL + serverURL, err := url.Parse(cfg.ServerAddress) + if err != nil { + return nil, err + } + + // Parse Scheme + var wsScheme string + switch serverURL.Scheme { + case "https": + wsScheme = "wss" + case "http": + wsScheme = "ws" + default: + return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme) + } + + // Create Tunnel Name + if cfg.TunnelName == "" { + cfg.TunnelName = generateTunnelName() + log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName) + } + + // Connect Server WS + wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey) + serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to connect: %v", err) + } + + return &Tunnel{ + name: cfg.TunnelName, + target: cfg.TunnelTarget, + serverURL: serverURL, + serverConn: serverConn, + localConns: make(map[string]net.Conn), + }, nil + +} + +type Tunnel struct { + name string + target string + serverURL *url.URL + + serverConn *websocket.Conn + localConns map[string]net.Conn + mu sync.RWMutex +} + +func (t *Tunnel) Start() error { + log.Infof("starting tunnel: %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target) + defer t.serverConn.Close() + + // Handle Messages + for { + // Read Message + var msg types.Message + err := t.serverConn.ReadJSON(&msg) + if err != nil { + log.Errorf("error reading from tunnel: %v", err) + break + } + + switch msg.Type { + case types.MessageTypeData: + localConn, err := t.getLocalConn(msg.StreamID) + if err != nil { + log.Errorf("failed to get local connection: %v", err) + continue + } + + // Write data to local connection + if _, err := localConn.Write(msg.Data); err != nil { + log.Errorf("error writing to local connection: %v", err) + localConn.Close() + t.mu.Lock() + delete(t.localConns, msg.StreamID) + t.mu.Unlock() + } + + case types.MessageTypeClose: + t.mu.Lock() + if localConn, exists := t.localConns[msg.StreamID]; exists { + localConn.Close() + delete(t.localConns, msg.StreamID) + } + t.mu.Unlock() + } + } + + return nil +} + +func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) { + // Get Cached Connection + t.mu.RLock() + localConn, exists := t.localConns[streamID] + t.mu.RUnlock() + if exists { + return localConn, nil + } + + // Initiate Connection & Cache + localConn, err := net.Dial("tcp", t.target) + if err != nil { + log.Errorf("failed to connect to %s: %v", t.target, err) + return nil, err + } + t.mu.Lock() + t.localConns[streamID] = localConn + t.mu.Unlock() + + // Start Response Relay & Return Connection + go t.startResponseRelay(streamID, localConn) + return localConn, nil +} + +func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) { + defer func() { + t.mu.Lock() + delete(t.localConns, streamID) + t.mu.Unlock() + localConn.Close() + }() + + buffer := make([]byte, 4096) + for { + n, err := localConn.Read(buffer) + if err != nil { + break + } + + response := types.Message{ + Type: types.MessageTypeData, + StreamID: streamID, + Data: buffer[:n], + } + + if err := t.serverConn.WriteJSON(response); err != nil { + break + } + } +} diff --git a/client/name.go b/client/name.go new file mode 100644 index 0000000..97aba3e --- /dev/null +++ b/client/name.go @@ -0,0 +1,23 @@ +package client + +import ( + "fmt" + "math/rand" +) + +var colors = []string{ + "red", "blue", "green", "yellow", "purple", "orange", + "pink", "brown", "black", "white", "gray", "cyan", +} + +var animals = []string{ + "cat", "dog", "bird", "fish", "lion", "tiger", + "bear", "wolf", "fox", "deer", "rabbit", "mouse", +} + +func generateTunnelName() string { + color := colors[rand.Intn(len(colors))] + animal := animals[rand.Intn(len(animals))] + number := rand.Intn(900) + 100 + return fmt.Sprintf("%s-%s-%d", color, animal, number) +} diff --git a/cmd/link.go b/cmd/link.go index e3f75c2..53802c8 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -1,181 +1,39 @@ package cmd import ( - "fmt" - "net" - "net/url" - "sync" - - "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "reichard.io/conduit/types" + "reichard.io/conduit/client" + "reichard.io/conduit/config" ) -var serverAddr string - var linkCmd = &cobra.Command{ Use: "link ", - Short: "Create a tunnel link", - Args: cobra.ExactArgs(2), + Short: "Create a conduit tunnel", Run: func(cmd *cobra.Command, args []string) { - tunnelName := args[0] - tunnelTarget := args[1] + // Get Client Config + cfg, err := config.GetClientConfig(cmd.Flags()) + if err != nil { + log.Fatal("failed to get client config:", err) + } // Create Tunnel - tunnel, err := NewTunnel(tunnelName, tunnelTarget, serverAddr) + tunnel, err := client.NewTunnel(cfg) if err != nil { - log.Fatal("Failed to start tunnel:", err) + log.Fatal("failed to create tunnel:", err) } // Start Tunnel - log.Infof("Creating TCP tunnel: %s -> %s\n", tunnelName, tunnelTarget) + log.Infof("creating TCP tunnel: %s -> %s", cfg.TunnelName, cfg.TunnelTarget) if err := tunnel.Start(); err != nil { - log.Fatal("Failed to start tunnel:", err) + log.Fatal("failed to start tunnel:", err) } }, } func init() { - linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "http://localhost:8080", "Conduit server address") -} - -func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) { - // Parse Server URL - serverURL, err := url.Parse(serverAddress) - if err != nil { - return nil, err - } - - // Parse Scheme - var wsScheme string - switch serverURL.Scheme { - case "https": - wsScheme = "wss" - case "http": - wsScheme = "ws" - default: - return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme) - } - - // Connect Server WS - wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s", wsScheme, serverURL.Host, tunnelName) - serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to connect: %v", err) - } - - return &Tunnel{ - name: tunnelName, - target: tunnelTarget, - serverURL: serverURL, - serverConn: serverConn, - localConns: make(map[string]net.Conn), - }, nil - -} - -type Tunnel struct { - name string - target string - serverURL *url.URL - - serverConn *websocket.Conn - localConns map[string]net.Conn - mu sync.RWMutex -} - -func (t *Tunnel) Start() error { - log.Infof("TCP Tunnel active! %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target) - defer t.serverConn.Close() - - // Handle Messages - for { - // Read Message - var msg types.Message - err := t.serverConn.ReadJSON(&msg) - if err != nil { - log.Errorf("Error reading from tunnel: %v", err) - break - } - - switch msg.Type { - case types.MessageTypeData: - localConn, err := t.getLocalConn(msg.StreamID) - if err != nil { - log.Errorf("Failed to get local connection: %v", err) - continue - } - - // Write data to local connection - if _, err := localConn.Write(msg.Data); err != nil { - log.Errorf("Error writing to local connection: %v", err) - localConn.Close() - t.mu.Lock() - delete(t.localConns, msg.StreamID) - t.mu.Unlock() - } - - case types.MessageTypeClose: - t.mu.Lock() - if localConn, exists := t.localConns[msg.StreamID]; exists { - localConn.Close() - delete(t.localConns, msg.StreamID) - } - t.mu.Unlock() - } - } - - return nil -} - -func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) { - // Get Cached Connection - t.mu.RLock() - localConn, exists := t.localConns[streamID] - t.mu.RUnlock() - if exists { - return localConn, nil - } - - // Initiate Connection & Cache - localConn, err := net.Dial("tcp", t.target) - if err != nil { - log.Errorf("Failed to connect to %s: %v", t.target, err) - return nil, err - } - t.mu.Lock() - t.localConns[streamID] = localConn - t.mu.Unlock() - - // Start Response Relay & Return Connection - go t.startResponseRelay(streamID, localConn) - return localConn, nil -} - -func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) { - defer func() { - t.mu.Lock() - delete(t.localConns, streamID) - t.mu.Unlock() - localConn.Close() - }() - - buffer := make([]byte, 4096) - for { - n, err := localConn.Read(buffer) - if err != nil { - break - } - - response := types.Message{ - Type: types.MessageTypeData, - StreamID: streamID, - Data: buffer[:n], - } - - if err := t.serverConn.WriteJSON(response); err != nil { - break - } + configDefs := config.GetConfigDefs[config.ClientConfig]() + for _, d := range configDefs { + linkCmd.Flags().String(d.Key, d.Default, d.Description) } } diff --git a/cmd/serve.go b/cmd/serve.go index 0457a60..a9d5a78 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -1,30 +1,35 @@ package cmd import ( - "fmt" - "log" - + "reichard.io/conduit/config" "reichard.io/conduit/server" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) -var port string - var serveCmd = &cobra.Command{ Use: "serve", Short: "Start the conduit server", Long: `Start the conduit server to handle incoming tunnel requests`, Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("Starting Conduit server on port %s...\n", port) + // Get Server Config + cfg, err := config.GetServerConfig(cmd.Flags()) + if err != nil { + log.Fatal("failed to get server config:", err) + } - srv := server.NewServer() - if err := srv.Start(":" + port); err != nil { - log.Fatal("Failed to start server:", err) + // Start Server + srv := server.NewServer(cfg) + if err := srv.Start(); err != nil { + log.Fatal("failed to start server:", err) } }, } func init() { - serveCmd.Flags().StringVarP(&port, "port", "p", "8080", "Port to run the server on") + configDefs := config.GetConfigDefs[config.ServerConfig]() + for _, d := range configDefs { + serveCmd.Flags().String(d.Key, d.Default, d.Description) + } } diff --git a/config/config.go b/config/config.go index d912156..bb08335 100644 --- a/config/config.go +++ b/config/config.go @@ -1 +1,145 @@ package config + +import ( + "fmt" + "os" + "reflect" + "strings" + + "github.com/spf13/pflag" +) + +type ConfigDef struct { + Key string + Env string + Default string + Description string +} + +type BaseConfig struct { + ServerAddress string `json:"address" description:"Conduit server address" default:"http://localhost:8080"` + APIKey string `json:"api_key" description:"API Key for the conduit API"` +} + +func (c *BaseConfig) Validate() error { + if c.APIKey == "" { + return fmt.Errorf("api_key is required") + } + return nil +} + +type ServerConfig struct { + BaseConfig + BindAddress string `json:"bind" default:"0.0.0.0:8080" description:"Address the conduit server listens on"` +} + +type ClientConfig struct { + BaseConfig + TunnelName string `json:"name" description:"Tunnel name"` + TunnelTarget string `json:"target" description:"Tunnel target address"` +} + +func (c *ClientConfig) Validate() error { + if err := c.BaseConfig.Validate(); err != nil { + return err + } + if c.TunnelTarget == "" { + return fmt.Errorf("target is required") + } + return nil +} + +func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) { + defs := GetConfigDefs[ServerConfig]() + + cfgValues := make(map[string]string) + for _, def := range defs { + cfgValues[def.Key] = getConfigValue(cmdFlags, def) + } + + cfg := &ServerConfig{ + BaseConfig: BaseConfig{ + ServerAddress: cfgValues["server"], + APIKey: cfgValues["api_key"], + }, + BindAddress: cfgValues["bind"], + } + + return cfg, cfg.Validate() +} + +func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) { + defs := GetConfigDefs[ClientConfig]() + + cfgValues := make(map[string]string) + for _, def := range defs { + cfgValues[def.Key] = getConfigValue(cmdFlags, def) + } + + cfg := &ClientConfig{ + BaseConfig: BaseConfig{ + ServerAddress: cfgValues["address"], + APIKey: cfgValues["api_key"], + }, + TunnelName: cfgValues["name"], + TunnelTarget: cfgValues["target"], + } + + return cfg, cfg.Validate() +} + +func GetConfigDefs[T ServerConfig | ClientConfig]() []ConfigDef { + var defs []ConfigDef + processFields(reflect.TypeFor[T](), &defs) + return defs +} + +func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string { + // 1. Get Flags First + if cmdFlags != nil { + if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" { + return val + } + } + + // 2. Environment Variables Next + if envVal := os.Getenv(def.Env); envVal != "" { + return envVal + } + + // 3. Defaults Last + return def.Default +} + +func processFields(t reflect.Type, defs *[]ConfigDef) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // Process Embedded (BaseConfig) + if field.Anonymous { + processFields(field.Type, defs) + continue + } + + // Extract Struct Tags + jsonTag := field.Tag.Get("json") + defaultTag := field.Tag.Get("default") + descriptionTag := field.Tag.Get("description") + + // Skip JSON Fields + if jsonTag == "" { + continue + } + + // Get Key & Env + key := strings.Split(jsonTag, ",")[0] + env := "CONDUIT_" + strings.ToUpper(key) + + *defs = append(*defs, ConfigDef{ + Key: key, + Env: env, + Default: defaultTag, + Description: descriptionTag, + }) + } +} diff --git a/server/server.go b/server/server.go index 9039ec9..7d7a9cb 100644 --- a/server/server.go +++ b/server/server.go @@ -13,23 +13,26 @@ import ( "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" + "reichard.io/conduit/config" "reichard.io/conduit/types" ) type TunnelConnection struct { *websocket.Conn name string - streams map[string]chan []byte // StreamID -> data channel + streams map[string]chan []byte } type Server struct { tunnels map[string]*TunnelConnection upgrader websocket.Upgrader + cfg *config.ServerConfig mu sync.RWMutex } -func NewServer() *Server { +func NewServer(cfg *config.ServerConfig) *Server { return &Server{ + cfg: cfg, tunnels: make(map[string]*TunnelConnection), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { @@ -39,20 +42,21 @@ func NewServer() *Server { } } -func (s *Server) Start(addr string) error { - // Raw TCP listener instead of http.ListenAndServe - listener, err := net.Listen("tcp", addr) +func (s *Server) Start() error { + // Raw TCP Listener - This is necessary so we can conditionally either relay + // the raw TCP connection, or handle conduit control server API requests. + listener, err := net.Listen("tcp", s.cfg.BindAddress) if err != nil { return err } defer listener.Close() - log.Infof("Conduit server listening on %s", addr) - + // Start Listening + log.Infof("conduit server listening on %s", s.cfg.BindAddress) for { conn, err := listener.Accept() if err != nil { - log.Printf("Error accepting connection: %v", err) + log.Printf("error accepting connection: %v", err) continue } @@ -76,7 +80,7 @@ func (s *Server) extractSubdomain(peakReader io.Reader) string { // Extract Subdomain parts := strings.Split(host, ".") - if len(parts) >= 1 { + if len(parts) > 1 { return parts[0] } @@ -152,9 +156,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne } } -// peakData limits how much we read as we only need to determine -// the host to figure out whether we should proxy or not. -func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) { +func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) { peek := make([]byte, 8192) n, err := conn.Read(peek) if err != nil { @@ -163,13 +165,13 @@ func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) { peekedData := peek[:n] combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn) - return bytes.NewReader(peekedData), combinedReader, nil } func (s *Server) handleRawConnection(conn net.Conn) { defer conn.Close() + // Detect Tunnel peakReader, allReader, _ := s.peekData(conn) if subdomain := s.extractSubdomain(peakReader); subdomain != "" { s.mu.RLock() @@ -177,25 +179,32 @@ func (s *Server) handleRawConnection(conn net.Conn) { s.mu.RUnlock() if exists { - log.Infof("Relaying %s to tunnel", subdomain) - + log.Infof("relaying %s to tunnel", subdomain) s.proxyRawConnection(conn, tunnelConn, allReader) - return } + return } - // Otherwise, handle as control server (recreate HTTP request and use net/http) + // Control Endpoints s.handleAsHTTP(conn, allReader) } func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) { // Create HTTP Request & Writer + w := &connResponseWriter{conn: conn} r, err := http.ReadRequest(bufio.NewReader(allReader)) if err != nil { - _, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) + w.WriteHeader(http.StatusBadRequest) + return + } + + // Authorize Control Endpoints + apiKey := r.URL.Query().Get("apiKey") + if apiKey != s.cfg.APIKey { + log.Error("unauthorized client") + w.WriteHeader(http.StatusUnauthorized) return } - w := &connResponseWriter{conn: conn} // Handle Control Endpoints switch r.URL.Path { @@ -217,7 +226,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { } if msg.StreamID == "" { - log.Infof("Tunnel %s missing streamID", tunnel.name) + log.Infof("tunnel %s missing streamID", tunnel.name) continue } @@ -228,7 +237,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { s.mu.RLock() streamChan, exists := tunnel.streams[msg.StreamID] if !exists { - log.Infof("Stream %s does not exist", msg.StreamID) + log.Infof("stream %s does not exist", msg.StreamID) s.mu.RUnlock() continue } @@ -236,7 +245,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { select { case streamChan <- msg.Data: case <-time.After(time.Second): - log.Infof("Stream %s channel full, dropping data", msg.StreamID) + log.Warnf("stream %s channel full, dropping data", msg.StreamID) } s.mu.RUnlock() } @@ -261,7 +270,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { // Upgrade Connection wsConn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { - log.Errorf("WebSocket upgrade failed: %v", err) + log.Errorf("websocket upgrade failed: %v", err) return } @@ -274,7 +283,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { s.mu.Lock() s.tunnels[tunnelName] = tunnel s.mu.Unlock() - log.Infof("Tunnel established: %s", tunnelName) + log.Infof("tunnel established: %s", tunnelName) // Keep connection alive and handle cleanup defer func() { @@ -282,7 +291,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { delete(s.tunnels, tunnelName) s.mu.Unlock() _ = wsConn.Close() - log.Infof("Tunnel closed: %s", tunnelName) + log.Infof("tunnel closed: %s", tunnelName) }() // Handle tunnel messages