From 84ea3802b84f631bf7baa3c391eb03b758ec2a71 Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Fri, 19 Sep 2025 14:59:07 -0400 Subject: [PATCH] initial commit --- cmd/link.go | 136 +++++++++++++++++++++++++++ cmd/root.go | 26 ++++++ cmd/serve.go | 30 ++++++ flake.lock | 61 ++++++++++++ flake.nix | 26 ++++++ go.mod | 10 ++ go.sum | 12 +++ main.go | 7 ++ server/server.go | 237 +++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 545 insertions(+) create mode 100644 cmd/link.go create mode 100644 cmd/root.go create mode 100644 cmd/serve.go create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 server/server.go diff --git a/cmd/link.go b/cmd/link.go new file mode 100644 index 0000000..7d547e4 --- /dev/null +++ b/cmd/link.go @@ -0,0 +1,136 @@ +package cmd + +import ( + "bytes" + "fmt" + "io" + "log" + "net/http" + + "github.com/gorilla/websocket" + "github.com/spf13/cobra" +) + +type TunnelRequest struct { + ID string `json:"id"` + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body []byte `json:"body"` +} + +type TunnelResponse struct { + ID string `json:"id"` + StatusCode int `json:"status_code"` + Headers map[string]string `json:"headers"` + Body []byte `json:"body"` +} + +var serverAddr string + +var linkCmd = &cobra.Command{ + Use: "link ", + Short: "Create a tunnel link", + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + vhostLoc := args[0] + localPort := args[1] + + fmt.Printf("Creating tunnel: %s -> localhost:%s\n", vhostLoc, localPort) + + if err := startTunnel(vhostLoc, localPort); err != nil { + log.Fatal("Failed to start tunnel:", err) + } + }, +} + +func init() { + linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "localhost:8080", "Conduit server address") +} + +func startTunnel(vhost, localPort string) error { + // Connect to WebSocket + wsURL := fmt.Sprintf("ws://%s/_conduit/tunnel?vhost=%s", serverAddr, vhost) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + return fmt.Errorf("failed to connect: %v", err) + } + defer conn.Close() + + fmt.Printf("Tunnel active! %s.example.com -> localhost:%s\n", vhost, localPort) + + // Handle incoming requests + for { + var req TunnelRequest + if err := conn.ReadJSON(&req); err != nil { + log.Printf("Error reading request: %v", err) + break + } + + go handleTunnelRequest(conn, &req, localPort) + } + + return nil +} + +func handleTunnelRequest(conn *websocket.Conn, req *TunnelRequest, localPort string) { + // Make request to local service + localURL := fmt.Sprintf("http://localhost:%s%s", localPort, req.URL) + + httpReq, err := http.NewRequest(req.Method, localURL, bytes.NewReader(req.Body)) + if err != nil { + sendErrorResponse(conn, req.ID, 500, "Failed to create request") + return + } + + // Set headers + for k, v := range req.Headers { + httpReq.Header.Set(k, v) + } + + // Make the request + client := &http.Client{} + resp, err := client.Do(httpReq) + if err != nil { + sendErrorResponse(conn, req.ID, 502, "Failed to reach local service") + return + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + sendErrorResponse(conn, req.ID, 500, "Failed to read response") + return + } + + // Convert response headers + headers := make(map[string]string) + for k, v := range resp.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + + // Send response back + tunnelResp := &TunnelResponse{ + ID: req.ID, + StatusCode: resp.StatusCode, + Headers: headers, + Body: body, + } + + if err := conn.WriteJSON(tunnelResp); err != nil { + log.Printf("Error sending response: %v", err) + } +} + +func sendErrorResponse(conn *websocket.Conn, reqID string, statusCode int, message string) { + resp := &TunnelResponse{ + ID: reqID, + StatusCode: statusCode, + Headers: map[string]string{"Content-Type": "text/plain"}, + Body: []byte(message), + } + conn.WriteJSON(resp) +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..7706746 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var rootCmd = &cobra.Command{ + Use: "conduit", + Short: "A tunneling service similar to ngrok", + Long: `Conduit allows you to expose local services through secure tunnels`, +} + +func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func init() { + rootCmd.AddCommand(serveCmd) + rootCmd.AddCommand(linkCmd) +} diff --git a/cmd/serve.go b/cmd/serve.go new file mode 100644 index 0000000..0457a60 --- /dev/null +++ b/cmd/serve.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "fmt" + "log" + + "reichard.io/conduit/server" + + "github.com/spf13/cobra" +) + +var port string + +var serveCmd = &cobra.Command{ + Use: "serve", + Short: "Start the conduit server", + Long: `Start the conduit server to handle incoming tunnel requests`, + Run: func(cmd *cobra.Command, args []string) { + fmt.Printf("Starting Conduit server on port %s...\n", port) + + srv := server.NewServer() + if err := srv.Start(":" + port); err != nil { + log.Fatal("Failed to start server:", err) + } + }, +} + +func init() { + serveCmd.Flags().StringVarP(&port, "port", "p", "8080", "Port to run the server on") +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..cbf5f41 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1754292888, + "narHash": "sha256-1ziydHSiDuSnaiPzCQh1mRFBsM2d2yRX9I+5OPGEmIE=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ce01daebf8489ba97bd1609d185ea276efdeb121", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-25.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..38b37b8 --- /dev/null +++ b/flake.nix @@ -0,0 +1,26 @@ +{ + description = "Development Environment"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + devShells.default = pkgs.mkShell { + packages = with pkgs; [ + go + golangci-lint + ]; + shellHook = '' + export PATH=$PATH:~/go/bin + ''; + }; + } + ); +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..51ef378 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module reichard.io/conduit + +go 1.24.4 + +require ( + github.com/gorilla/websocket v1.5.3 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/cobra v1.10.1 // indirect + github.com/spf13/pflag v1.0.9 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a3d862c --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= +github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..0fcc792 --- /dev/null +++ b/main.go @@ -0,0 +1,7 @@ +package main + +import "reichard.io/conduit/cmd" + +func main() { + cmd.Execute() +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..89856ab --- /dev/null +++ b/server/server.go @@ -0,0 +1,237 @@ +package server + +import ( + "fmt" + "io" + "log" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type TunnelConnection struct { + conn *websocket.Conn + vhost string + responses map[string]chan *TunnelResponse + mu sync.RWMutex +} + +type TunnelRequest struct { + ID string `json:"id"` + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body []byte `json:"body"` +} + +type TunnelResponse struct { + ID string `json:"id"` + StatusCode int `json:"status_code"` + Headers map[string]string `json:"headers"` + Body []byte `json:"body"` +} + +type Server struct { + tunnels map[string]*TunnelConnection + mu sync.RWMutex + upgrader websocket.Upgrader +} + +func NewServer() *Server { + return &Server{ + tunnels: make(map[string]*TunnelConnection), + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + } +} + +func (s *Server) Start(addr string) error { + http.HandleFunc("/", s.handleRequest) + http.HandleFunc("/_conduit/tunnel", s.handleTunnel) + http.HandleFunc("/_conduit/status", s.handleStatus) + + log.Printf("Conduit server listening on %s", addr) + return http.ListenAndServe(addr, nil) +} + +func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { + host := r.Host + subdomain := s.extractSubdomain(host) + + if subdomain == "" { + http.Error(w, "Invalid host", http.StatusBadRequest) + return + } + + s.mu.RLock() + tunnel, exists := s.tunnels[subdomain] + s.mu.RUnlock() + + if !exists { + http.Error(w, "Tunnel not found", http.StatusNotFound) + return + } + + // Create unique request ID + reqID := fmt.Sprintf("%d", time.Now().UnixNano()) + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Convert headers + headers := make(map[string]string) + for k, v := range r.Header { + if len(v) > 0 { + headers[k] = v[0] + } + } + + // Create response channel for this request + respChan := make(chan *TunnelResponse, 1) + tunnel.mu.Lock() + tunnel.responses[reqID] = respChan + tunnel.mu.Unlock() + + // Clean up response channel when done + defer func() { + tunnel.mu.Lock() + delete(tunnel.responses, reqID) + tunnel.mu.Unlock() + }() + + tunnelReq := &TunnelRequest{ + ID: reqID, + Method: r.Method, + URL: r.URL.String(), + Headers: headers, + Body: body, + } + + // Send request to tunnel + if err := tunnel.conn.WriteJSON(tunnelReq); err != nil { + log.Printf("Error sending request to tunnel: %v", err) + http.Error(w, "Tunnel communication error", http.StatusServiceUnavailable) + return + } + + // Wait for response + select { + case resp := <-respChan: + // Write response headers + for k, v := range resp.Headers { + w.Header().Set(k, v) + } + w.WriteHeader(resp.StatusCode) + w.Write(resp.Body) + + case <-time.After(30 * time.Second): + http.Error(w, "Tunnel timeout", http.StatusGatewayTimeout) + } +} + +func (s *Server) handleTunnel(w http.ResponseWriter, r *http.Request) { + vhost := r.URL.Query().Get("vhost") + if vhost == "" { + http.Error(w, "Missing vhost parameter", http.StatusBadRequest) + return + } + + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade failed: %v", err) + return + } + + tunnel := &TunnelConnection{ + conn: conn, + vhost: vhost, + responses: make(map[string]chan *TunnelResponse), + } + + s.mu.Lock() + s.tunnels[vhost] = tunnel + s.mu.Unlock() + + log.Printf("Tunnel established: %s", vhost) + + // Handle tunnel communication + s.handleTunnelConnection(tunnel) +} + +func (s *Server) handleTunnelConnection(tunnel *TunnelConnection) { + defer func() { + s.mu.Lock() + delete(s.tunnels, tunnel.vhost) + s.mu.Unlock() + tunnel.conn.Close() + log.Printf("Tunnel closed: %s", tunnel.vhost) + }() + + // Handle incoming responses from client + for { + var resp TunnelResponse + if err := tunnel.conn.ReadJSON(&resp); err != nil { + log.Printf("Error reading response from tunnel %s: %v", tunnel.vhost, err) + return + } + + // Find the response channel for this request + tunnel.mu.RLock() + respChan, exists := tunnel.responses[resp.ID] + tunnel.mu.RUnlock() + + if exists { + select { + case respChan <- &resp: + // Response delivered + default: + log.Printf("Response channel full for request %s", resp.ID) + } + } else { + log.Printf("No response channel found for request %s", resp.ID) + } + } +} + +func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { + s.mu.RLock() + defer s.mu.RUnlock() + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"tunnels": %d, "active": [`, len(s.tunnels)) + + first := true + for vhost := range s.tunnels { + if !first { + fmt.Fprint(w, ",") + } + fmt.Fprintf(w, `{"vhost":"%s"}`, vhost) + first = false + } + + fmt.Fprint(w, "]}") +} + +func (s *Server) extractSubdomain(host string) string { + if idx := strings.Index(host, ":"); idx != -1 { + host = host[:idx] + } + + parts := strings.Split(host, ".") + if len(parts) >= 1 { + return parts[0] + } + + return "" +}