commit 5bb9052fa4fc723bc75545147ebc371901d23855 Author: Evan Reichard Date: Fri Sep 19 14:59:07 2025 -0400 initial commit diff --git a/.golangci.toml b/.golangci.toml new file mode 100644 index 0000000..bc546a4 --- /dev/null +++ b/.golangci.toml @@ -0,0 +1,6 @@ +#:schema https://golangci-lint.run/jsonschema/golangci.jsonschema.json +version = "2" + +[[linters.exclusions.rules]] +linters = [ "errcheck" ] +source = "^\\s*defer\\s+" diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..a9786d2 --- /dev/null +++ b/client/client.go @@ -0,0 +1,159 @@ +package client + +import ( + "fmt" + "net" + "net/url" + "sync" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" + "reichard.io/conduit/config" + "reichard.io/conduit/types" +) + +func NewTunnel(cfg *config.ClientConfig) (*Tunnel, error) { + // Parse Server URL + serverURL, err := url.Parse(cfg.ServerAddress) + if err != nil { + return nil, err + } + + // Parse Scheme + var wsScheme string + switch serverURL.Scheme { + case "https": + wsScheme = "wss" + case "http": + wsScheme = "ws" + default: + return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme) + } + + // Create Tunnel Name + if cfg.TunnelName == "" { + cfg.TunnelName = generateTunnelName() + log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName) + } + + // Connect Server WS + wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey) + serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to connect: %v", err) + } + + return &Tunnel{ + name: cfg.TunnelName, + target: cfg.TunnelTarget, + serverURL: serverURL, + serverConn: serverConn, + localConns: make(map[string]net.Conn), + }, nil + +} + +type Tunnel struct { + name string + target string + serverURL *url.URL + + serverConn *websocket.Conn + localConns map[string]net.Conn + mu sync.RWMutex +} + +func (t *Tunnel) Start() error { + log.Infof("starting tunnel: %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target) + defer t.serverConn.Close() + + // Handle Messages + for { + // Read Message + var msg types.Message + err := t.serverConn.ReadJSON(&msg) + if err != nil { + log.Errorf("error reading from tunnel: %v", err) + break + } + + switch msg.Type { + case types.MessageTypeData: + localConn, err := t.getLocalConn(msg.StreamID) + if err != nil { + log.Errorf("failed to get local connection: %v", err) + continue + } + + // Write data to local connection + if _, err := localConn.Write(msg.Data); err != nil { + log.Errorf("error writing to local connection: %v", err) + localConn.Close() + t.mu.Lock() + delete(t.localConns, msg.StreamID) + t.mu.Unlock() + } + + case types.MessageTypeClose: + t.mu.Lock() + if localConn, exists := t.localConns[msg.StreamID]; exists { + localConn.Close() + delete(t.localConns, msg.StreamID) + } + t.mu.Unlock() + } + } + + return nil +} + +func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) { + // Get Cached Connection + t.mu.RLock() + localConn, exists := t.localConns[streamID] + t.mu.RUnlock() + if exists { + return localConn, nil + } + + // Initiate Connection & Cache + localConn, err := net.Dial("tcp", t.target) + if err != nil { + log.Errorf("failed to connect to %s: %v", t.target, err) + return nil, err + } + t.mu.Lock() + t.localConns[streamID] = localConn + t.mu.Unlock() + + // Start Response Relay & Return Connection + go t.startResponseRelay(streamID, localConn) + return localConn, nil +} + +func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) { + defer func() { + t.mu.Lock() + delete(t.localConns, streamID) + t.mu.Unlock() + localConn.Close() + }() + + buffer := make([]byte, 4096) + for { + n, err := localConn.Read(buffer) + if err != nil { + break + } + + response := types.Message{ + Type: types.MessageTypeData, + StreamID: streamID, + Data: buffer[:n], + } + + if err := t.serverConn.WriteJSON(response); err != nil { + break + } + } +} diff --git a/client/name.go b/client/name.go new file mode 100644 index 0000000..97aba3e --- /dev/null +++ b/client/name.go @@ -0,0 +1,23 @@ +package client + +import ( + "fmt" + "math/rand" +) + +var colors = []string{ + "red", "blue", "green", "yellow", "purple", "orange", + "pink", "brown", "black", "white", "gray", "cyan", +} + +var animals = []string{ + "cat", "dog", "bird", "fish", "lion", "tiger", + "bear", "wolf", "fox", "deer", "rabbit", "mouse", +} + +func generateTunnelName() string { + color := colors[rand.Intn(len(colors))] + animal := animals[rand.Intn(len(animals))] + number := rand.Intn(900) + 100 + return fmt.Sprintf("%s-%s-%d", color, animal, number) +} diff --git a/cmd/link.go b/cmd/link.go new file mode 100644 index 0000000..53802c8 --- /dev/null +++ b/cmd/link.go @@ -0,0 +1,39 @@ +package cmd + +import ( + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "reichard.io/conduit/client" + "reichard.io/conduit/config" +) + +var linkCmd = &cobra.Command{ + Use: "link ", + Short: "Create a conduit tunnel", + Run: func(cmd *cobra.Command, args []string) { + // Get Client Config + cfg, err := config.GetClientConfig(cmd.Flags()) + if err != nil { + log.Fatal("failed to get client config:", err) + } + + // Create Tunnel + tunnel, err := client.NewTunnel(cfg) + if err != nil { + log.Fatal("failed to create tunnel:", err) + } + + // Start Tunnel + log.Infof("creating TCP tunnel: %s -> %s", cfg.TunnelName, cfg.TunnelTarget) + if err := tunnel.Start(); err != nil { + log.Fatal("failed to start tunnel:", err) + } + }, +} + +func init() { + configDefs := config.GetConfigDefs[config.ClientConfig]() + for _, d := range configDefs { + linkCmd.Flags().String(d.Key, d.Default, d.Description) + } +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..7706746 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var rootCmd = &cobra.Command{ + Use: "conduit", + Short: "A tunneling service similar to ngrok", + Long: `Conduit allows you to expose local services through secure tunnels`, +} + +func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func init() { + rootCmd.AddCommand(serveCmd) + rootCmd.AddCommand(linkCmd) +} diff --git a/cmd/serve.go b/cmd/serve.go new file mode 100644 index 0000000..c63a4ff --- /dev/null +++ b/cmd/serve.go @@ -0,0 +1,40 @@ +package cmd + +import ( + "reichard.io/conduit/config" + "reichard.io/conduit/server" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +var serveCmd = &cobra.Command{ + Use: "serve", + Short: "Start the conduit server", + Long: `Start the conduit server to handle incoming tunnel requests`, + Run: func(cmd *cobra.Command, args []string) { + // Get Server Config + cfg, err := config.GetServerConfig(cmd.Flags()) + if err != nil { + log.Fatal("failed to get server config:", err) + } + + // Create Server + srv, err := server.NewServer(cfg) + if err != nil { + log.Fatal("failed to create server:", err) + } + + // Start Server + if err := srv.Start(); err != nil { + log.Fatal("failed to start server:", err) + } + }, +} + +func init() { + configDefs := config.GetConfigDefs[config.ServerConfig]() + for _, d := range configDefs { + serveCmd.Flags().String(d.Key, d.Default, d.Description) + } +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..58e6abf --- /dev/null +++ b/config/config.go @@ -0,0 +1,153 @@ +package config + +import ( + "errors" + "fmt" + "net/url" + "os" + "reflect" + "strings" + + "github.com/spf13/pflag" +) + +type ConfigDef struct { + Key string + Env string + Default string + Description string +} + +type BaseConfig struct { + ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"` + APIKey string `json:"api_key" description:"API Key for the conduit API"` +} + +func (c *BaseConfig) Validate() error { + if c.APIKey == "" { + return errors.New("api_key is required") + } + if c.ServerAddress == "" { + return errors.New("server is required") + } + if _, err := url.Parse(c.ServerAddress); err != nil { + return fmt.Errorf("server is invalid: %w", err) + } + return nil +} + +type ServerConfig struct { + BaseConfig + BindAddress string `json:"bind" default:"0.0.0.0:8080" description:"Address the conduit server listens on"` +} + +type ClientConfig struct { + BaseConfig + TunnelName string `json:"name" description:"Tunnel name"` + TunnelTarget string `json:"target" description:"Tunnel target address"` +} + +func (c *ClientConfig) Validate() error { + if err := c.BaseConfig.Validate(); err != nil { + return err + } + if c.TunnelTarget == "" { + return fmt.Errorf("target is required") + } + return nil +} + +func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) { + defs := GetConfigDefs[ServerConfig]() + + cfgValues := make(map[string]string) + for _, def := range defs { + cfgValues[def.Key] = getConfigValue(cmdFlags, def) + } + + cfg := &ServerConfig{ + BaseConfig: BaseConfig{ + ServerAddress: cfgValues["server"], + APIKey: cfgValues["api_key"], + }, + BindAddress: cfgValues["bind"], + } + + return cfg, cfg.Validate() +} + +func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) { + defs := GetConfigDefs[ClientConfig]() + + cfgValues := make(map[string]string) + for _, def := range defs { + cfgValues[def.Key] = getConfigValue(cmdFlags, def) + } + + cfg := &ClientConfig{ + BaseConfig: BaseConfig{ + ServerAddress: cfgValues["server"], + APIKey: cfgValues["api_key"], + }, + TunnelName: cfgValues["name"], + TunnelTarget: cfgValues["target"], + } + + return cfg, cfg.Validate() +} + +func GetConfigDefs[T ServerConfig | ClientConfig]() []ConfigDef { + var defs []ConfigDef + processFields(reflect.TypeFor[T](), &defs) + return defs +} + +func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string { + // 1. Get Flags First + if cmdFlags != nil { + if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" { + return val + } + } + + // 2. Environment Variables Next + if envVal := os.Getenv(def.Env); envVal != "" { + return envVal + } + + // 3. Defaults Last + return def.Default +} + +func processFields(t reflect.Type, defs *[]ConfigDef) { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // Process Embedded (BaseConfig) + if field.Anonymous { + processFields(field.Type, defs) + continue + } + + // Extract Struct Tags + jsonTag := field.Tag.Get("json") + defaultTag := field.Tag.Get("default") + descriptionTag := field.Tag.Get("description") + + // Skip JSON Fields + if jsonTag == "" { + continue + } + + // Get Key & Env + key := strings.Split(jsonTag, ",")[0] + env := "CONDUIT_" + strings.ToUpper(key) + + *defs = append(*defs, ConfigDef{ + Key: key, + Env: env, + Default: defaultTag, + Description: descriptionTag, + }) + } +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..fa2eda9 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1758216857, + "narHash": "sha256-h1BW2y7CY4LI9w61R02wPaOYfmYo82FyRqHIwukQ6SY=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "d2ed99647a4b195f0bcc440f76edfa10aeb3b743", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-25.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..38b37b8 --- /dev/null +++ b/flake.nix @@ -0,0 +1,26 @@ +{ + description = "Development Environment"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + devShells.default = pkgs.mkShell { + packages = with pkgs; [ + go + golangci-lint + ]; + shellHook = '' + export PATH=$PATH:~/go/bin + ''; + }; + } + ); +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..29ae61b --- /dev/null +++ b/go.mod @@ -0,0 +1,12 @@ +module reichard.io/conduit + +go 1.24.4 + +require ( + github.com/gorilla/websocket v1.5.3 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/spf13/cobra v1.10.1 // indirect + github.com/spf13/pflag v1.0.9 // indirect + golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7cb5bd1 --- /dev/null +++ b/go.sum @@ -0,0 +1,22 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= +github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..0fcc792 --- /dev/null +++ b/main.go @@ -0,0 +1,7 @@ +package main + +import "reichard.io/conduit/cmd" + +func main() { + cmd.Execute() +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..aa4dd24 --- /dev/null +++ b/server/server.go @@ -0,0 +1,301 @@ +package server + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" + "reichard.io/conduit/config" + "reichard.io/conduit/types" +) + +type TunnelConnection struct { + *websocket.Conn + name string + streams map[string]chan []byte +} + +type Server struct { + host string + cfg *config.ServerConfig + mu sync.RWMutex + + upgrader websocket.Upgrader + tunnels map[string]*TunnelConnection +} + +func NewServer(cfg *config.ServerConfig) (*Server, error) { + serverURL, err := url.Parse(cfg.ServerAddress) + if err != nil { + return nil, fmt.Errorf("failed to parse server address: %v", err) + } else if serverURL.Host == "" { + return nil, errors.New("invalid server address") + } + + return &Server{ + cfg: cfg, + host: serverURL.Host, + tunnels: make(map[string]*TunnelConnection), + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + }, nil +} + +func (s *Server) Start() error { + // Raw TCP Listener - This is necessary so we can conditionally either relay + // the raw TCP connection, or handle conduit control server API requests. + listener, err := net.Listen("tcp", s.cfg.BindAddress) + if err != nil { + return err + } + defer listener.Close() + + // Start Listening + log.Infof("conduit server listening on %s", s.cfg.BindAddress) + for { + conn, err := listener.Accept() + if err != nil { + log.Printf("error accepting connection: %v", err) + continue + } + + go s.handleRawConnection(conn) + } +} + +func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) { + s.mu.RLock() + count := len(s.tunnels) + s.mu.RUnlock() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + + response := fmt.Sprintf(`{"tunnels": %d}`, count) + _, _ = w.Write([]byte(response)) +} + +func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, dataReader io.Reader) { + defer clientConn.Close() + + // Create Identifiers + streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) + responseChan := make(chan []byte, 100) + + // Register Stream + s.mu.Lock() + if tunnelConn.streams == nil { + tunnelConn.streams = make(map[string]chan []byte) + } + tunnelConn.streams[streamID] = responseChan + s.mu.Unlock() + + // Clean Up + defer func() { + s.mu.Lock() + delete(tunnelConn.streams, streamID) + close(responseChan) + s.mu.Unlock() + + // Send Close + closeMsg := types.Message{ + Type: types.MessageTypeClose, + StreamID: streamID, + } + _ = tunnelConn.WriteJSON(closeMsg) + }() + + // Read & Send Chunks + go func() { + buffer := make([]byte, 4096) + for { + n, err := dataReader.Read(buffer) + if err != nil { + return + } + + if err := tunnelConn.WriteJSON(types.Message{ + Type: types.MessageTypeData, + StreamID: streamID, + Data: buffer[:n], + }); err != nil { + return + } + } + }() + + // Return Response Data + for data := range responseChan { + if _, err := clientConn.Write(data); err != nil { + break + } + } +} + +func (s *Server) handleRawConnection(conn net.Conn) { + defer conn.Close() + + // Capture Consumed Data - When determining where to route the request, we + // have to read the host headers. This requires reading from the buffer, so + // if we later decide to tunnel the TCP connection we need to reconstruct the + // data from the buffer. + var capturedData bytes.Buffer + teeReader := io.TeeReader(conn, &capturedData) + bufReader := bufio.NewReader(teeReader) + + // Create HTTP Request & Writer + w := &connResponseWriter{conn: conn} + r, err := http.ReadRequest(bufReader) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Validate Host + if !strings.Contains(r.Host, s.host) { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, "unknown host: %s", r.Host) + return + } + + // Extract Subdomain + subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") + if strings.Count(subdomain, ".") != 0 { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host) + return + } + + // Handle Control Endpoints + if subdomain == "" { + s.handleAsHTTP(w, r) + return + } + + // Handle Tunnels + s.mu.RLock() + tunnelConn, exists := s.tunnels[subdomain] + s.mu.RUnlock() + if exists { + log.Infof("relaying %s to tunnel", subdomain) + + // Reconstruct Data & Proxy Connection + allReader := io.MultiReader(&capturedData, r.Body) + s.proxyRawConnection(conn, tunnelConn, allReader) + } +} + +func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { + // Authorize Control Endpoints + apiKey := r.URL.Query().Get("apiKey") + if apiKey != s.cfg.APIKey { + log.Error("unauthorized client") + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Handle Control Endpoints + switch r.URL.Path { + case "/_conduit/tunnel": + s.createTunnel(w, r) + case "/_conduit/status": + s.getStatus(w, r) + default: + w.WriteHeader(http.StatusNotFound) + } +} + +func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { + for { + var msg types.Message + err := tunnel.ReadJSON(&msg) + if err != nil { + return + } + + if msg.StreamID == "" { + log.Infof("tunnel %s missing streamID", tunnel.name) + continue + } + + switch msg.Type { + case types.MessageTypeClose: + return + case types.MessageTypeData: + s.mu.RLock() + streamChan, exists := tunnel.streams[msg.StreamID] + if !exists { + log.Infof("stream %s does not exist", msg.StreamID) + s.mu.RUnlock() + continue + } + + select { + case streamChan <- msg.Data: + case <-time.After(time.Second): + log.Warnf("stream %s channel full, dropping data", msg.StreamID) + } + s.mu.RUnlock() + } + } +} +func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { + // Get Tunnel Name + tunnelName := r.URL.Query().Get("tunnelName") + if tunnelName == "" { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("Missing tunnelName parameter")) + return + } + + // Validate Unique + if _, exists := s.tunnels[tunnelName]; exists { + w.WriteHeader(http.StatusConflict) + _, _ = w.Write([]byte("Tunnel already registered")) + return + } + + // Upgrade Connection + wsConn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Errorf("websocket upgrade failed: %v", err) + return + } + + // Create & Cache TunnelConnection + tunnel := &TunnelConnection{ + Conn: wsConn, + name: tunnelName, + streams: make(map[string]chan []byte), + } + s.mu.Lock() + s.tunnels[tunnelName] = tunnel + s.mu.Unlock() + log.Infof("tunnel established: %s", tunnelName) + + // Keep connection alive and handle cleanup + defer func() { + s.mu.Lock() + delete(s.tunnels, tunnelName) + s.mu.Unlock() + _ = wsConn.Close() + log.Infof("tunnel closed: %s", tunnelName) + }() + + // Handle tunnel messages + s.handleTunnelMessages(tunnel) +} diff --git a/server/writer.go b/server/writer.go new file mode 100644 index 0000000..601cc40 --- /dev/null +++ b/server/writer.go @@ -0,0 +1,48 @@ +package server + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +var _ http.ResponseWriter = (*connResponseWriter)(nil) + +type connResponseWriter struct { + conn net.Conn + header http.Header +} + +func (f *connResponseWriter) Header() http.Header { + if f.header == nil { + f.header = make(http.Header) + } + return f.header +} + +func (f *connResponseWriter) Write(data []byte) (int, error) { + return f.conn.Write(data) +} + +func (f *connResponseWriter) WriteHeader(statusCode int) { + // Write Status + status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) + _, _ = f.conn.Write([]byte(status)) + + // Write Headers + for key, values := range f.header { + for _, value := range values { + _, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value) + } + } + + // End Headers + _, _ = f.conn.Write([]byte("\r\n")) +} + +func (f *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Return Raw Connection & ReadWriter + rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) + return f.conn, rw, nil +} diff --git a/types/message.go b/types/message.go new file mode 100644 index 0000000..3d54160 --- /dev/null +++ b/types/message.go @@ -0,0 +1,14 @@ +package types + +type MessageType string + +const ( + MessageTypeData MessageType = "data" + MessageTypeClose MessageType = "close" +) + +type Message struct { + Type MessageType `json:"type"` + StreamID string `json:"stream_id"` + Data []byte `json:"data,omitempty"` +}