conduit/cmd/link.go
2025-09-19 14:59:07 -04:00

137 lines
3.2 KiB
Go

package cmd
import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/spf13/cobra"
)
// Tunnel wire messages. The client reads TunnelRequest values from the tunnel
// WebSocket as JSON and answers each one with a TunnelResponse that carries
// the same ID. []byte fields are encoded by encoding/json as base64 strings.
type (
	// TunnelRequest is an HTTP request relayed to this client over the tunnel.
	TunnelRequest struct {
		ID      string            `json:"id"`      // correlates the request with its TunnelResponse
		Method  string            `json:"method"`  // HTTP method, e.g. GET
		URL     string            `json:"url"`     // appended to http://localhost:<port> when replayed
		Headers map[string]string `json:"headers"` // one value per header name
		Body    []byte            `json:"body"`
	}

	// TunnelResponse is the reply sent back over the tunnel for a TunnelRequest.
	TunnelResponse struct {
		ID         string            `json:"id"` // copied from the originating TunnelRequest
		StatusCode int               `json:"status_code"`
		Headers    map[string]string `json:"headers"` // one value per header name
		Body       []byte            `json:"body"`
	}
)
var (
	// serverAddr is the host:port of the Conduit server; it is bound to the
	// --server/-s flag of linkCmd.
	serverAddr string

	// linkCmd implements `link <vhost_location> <local_port>`: it opens a
	// tunnel for vhost_location that forwards to localhost:<local_port> and
	// exits the process via log.Fatalf if the tunnel cannot be started.
	linkCmd = &cobra.Command{
		Use:   "link <vhost_location> <local_port>",
		Short: "Create a tunnel link",
		Args:  cobra.ExactArgs(2),
		Run: func(cmd *cobra.Command, args []string) {
			vhostLoc, localPort := args[0], args[1]
			fmt.Printf("Creating tunnel: %s -> localhost:%s\n", vhostLoc, localPort)
			if err := startTunnel(vhostLoc, localPort); err != nil {
				// log.Fatal formats with fmt.Sprint, which inserts no space after
				// a string operand; Fatalf keeps "...: <err>" readable.
				log.Fatalf("Failed to start tunnel: %v", err)
			}
		},
	}
)
// init registers linkCmd's flags; --server/-s populates serverAddr and
// defaults to localhost:8080.
func init() {
	flags := linkCmd.Flags()
	flags.StringVarP(&serverAddr, "server", "s", "localhost:8080", "Conduit server address")
}
// startTunnel connects to the Conduit server at serverAddr, registers vhost,
// and forwards every TunnelRequest it receives to http://localhost:<localPort>,
// handling each request in its own goroutine.
//
// It blocks until reading from the WebSocket fails (for example when the
// server closes the connection); that failure is logged and startTunnel
// returns nil. A non-nil error is returned only when the initial dial fails.
func startTunnel(vhost, localPort string) error {
	// Escape vhost so characters such as '&', '#' or spaces cannot corrupt
	// the query string sent to the server.
	wsURL := fmt.Sprintf("ws://%s/_conduit/tunnel?vhost=%s", serverAddr, url.QueryEscape(vhost))
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}
	// NOTE(review): handler goroutines still in flight when this returns will
	// write to an already-closed connection.
	defer conn.Close()

	// NOTE(review): the public domain is hard-coded as example.com; confirm
	// what the server actually serves vhosts under.
	fmt.Printf("Tunnel active! %s.example.com -> localhost:%s\n", vhost, localPort)

	for {
		// A fresh value each iteration: the goroutine below owns &req.
		var req TunnelRequest
		if err := conn.ReadJSON(&req); err != nil {
			log.Printf("Error reading request: %v", err)
			return nil
		}
		go handleTunnelRequest(conn, &req, localPort)
	}
}
// localRequestTimeout bounds one round trip to the local service, including
// reading the response body, so a hung service cannot pin a handler
// goroutine indefinitely.
const localRequestTimeout = 5 * time.Minute

// localHTTPClient is shared by every tunnelled request so connections to the
// local service are pooled rather than re-established per request.
var localHTTPClient = &http.Client{
	Timeout: localRequestTimeout,
	// Relay redirects to the remote caller instead of following them here.
	// Following them would return the target page under the original URL and
	// could make this client fetch whatever host a Location header names.
	CheckRedirect: func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	},
}

// tunnelWriteMu serialises writes to the tunnel WebSocket. gorilla/websocket
// supports at most one concurrent writer per connection, while startTunnel
// runs each handleTunnelRequest in its own goroutine.
var tunnelWriteMu sync.Mutex

// writeTunnelResponse sends resp as JSON over conn while holding
// tunnelWriteMu and returns any error from WriteJSON.
func writeTunnelResponse(conn *websocket.Conn, resp *TunnelResponse) error {
	tunnelWriteMu.Lock()
	defer tunnelWriteMu.Unlock()
	return conn.WriteJSON(resp)
}

// flattenHeaders converts h into the one-value-per-name map carried by
// TunnelResponse. Repeated values are joined with ", " (valid for list-based
// HTTP fields). Set-Cookie cannot be comma-joined, so only its first value is
// kept.
func flattenHeaders(h http.Header) map[string]string {
	flat := make(map[string]string, len(h))
	for name, values := range h {
		switch {
		case len(values) == 0:
			// Nothing to forward.
		case http.CanonicalHeaderKey(name) == "Set-Cookie":
			// NOTE(review): additional cookies are dropped because the wire
			// format holds a single string per header name.
			flat[name] = values[0]
		default:
			flat[name] = strings.Join(values, ", ")
		}
	}
	return flat
}

// handleTunnelRequest replays req against the local service on localPort and
// writes the result back over conn as a TunnelResponse carrying req.ID.
// Failures are reported to the remote side as plain-text responses:
// 500 if the request cannot be built or the body cannot be read, and 502 if
// the local service cannot be reached.
func handleTunnelRequest(conn *websocket.Conn, req *TunnelRequest, localPort string) {
	localURL := fmt.Sprintf("http://localhost:%s%s", localPort, req.URL)
	httpReq, err := http.NewRequest(req.Method, localURL, bytes.NewReader(req.Body))
	if err != nil {
		sendErrorResponse(conn, req.ID, http.StatusInternalServerError, "Failed to create request")
		return
	}
	// Header.Set canonicalises names. net/http does not send a "Host" entry
	// from Header (it uses the URL host), so the local service sees
	// Host: localhost:<port>.
	for name, value := range req.Headers {
		httpReq.Header.Set(name, value)
	}

	resp, err := localHTTPClient.Do(httpReq)
	if err != nil {
		log.Printf("Local request %s %q failed: %v", req.Method, req.URL, err)
		sendErrorResponse(conn, req.ID, http.StatusBadGateway, "Failed to reach local service")
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		sendErrorResponse(conn, req.ID, http.StatusInternalServerError, "Failed to read response")
		return
	}

	tunnelResp := &TunnelResponse{
		ID:         req.ID,
		StatusCode: resp.StatusCode,
		Headers:    flattenHeaders(resp.Header),
		Body:       body,
	}
	if err := writeTunnelResponse(conn, tunnelResp); err != nil {
		log.Printf("Error sending response: %v", err)
	}
}

// sendErrorResponse writes a text/plain TunnelResponse with the given status
// and message for reqID. A failed write is logged rather than ignored.
func sendErrorResponse(conn *websocket.Conn, reqID string, statusCode int, message string) {
	resp := &TunnelResponse{
		ID:         reqID,
		StatusCode: statusCode,
		Headers:    map[string]string{"Content-Type": "text/plain"},
		Body:       []byte(message),
	}
	if err := writeTunnelResponse(conn, resp); err != nil {
		log.Printf("Error sending error response: %v", err)
	}
}