137 lines
3.2 KiB
Go
137 lines
3.2 KiB
Go
package cmd
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
type TunnelRequest struct {
|
|
ID string `json:"id"`
|
|
Method string `json:"method"`
|
|
URL string `json:"url"`
|
|
Headers map[string]string `json:"headers"`
|
|
Body []byte `json:"body"`
|
|
}
|
|
|
|
type TunnelResponse struct {
|
|
ID string `json:"id"`
|
|
StatusCode int `json:"status_code"`
|
|
Headers map[string]string `json:"headers"`
|
|
Body []byte `json:"body"`
|
|
}
|
|
|
|
var serverAddr string
|
|
|
|
var linkCmd = &cobra.Command{
|
|
Use: "link <vhost_location> <local_port>",
|
|
Short: "Create a tunnel link",
|
|
Args: cobra.ExactArgs(2),
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
vhostLoc := args[0]
|
|
localPort := args[1]
|
|
|
|
fmt.Printf("Creating tunnel: %s -> localhost:%s\n", vhostLoc, localPort)
|
|
|
|
if err := startTunnel(vhostLoc, localPort); err != nil {
|
|
log.Fatal("Failed to start tunnel:", err)
|
|
}
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "localhost:8080", "Conduit server address")
|
|
}
|
|
|
|
func startTunnel(vhost, localPort string) error {
|
|
// Connect to WebSocket
|
|
wsURL := fmt.Sprintf("ws://%s/_conduit/tunnel?vhost=%s", serverAddr, vhost)
|
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
fmt.Printf("Tunnel active! %s.example.com -> localhost:%s\n", vhost, localPort)
|
|
|
|
// Handle incoming requests
|
|
for {
|
|
var req TunnelRequest
|
|
if err := conn.ReadJSON(&req); err != nil {
|
|
log.Printf("Error reading request: %v", err)
|
|
break
|
|
}
|
|
|
|
go handleTunnelRequest(conn, &req, localPort)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func handleTunnelRequest(conn *websocket.Conn, req *TunnelRequest, localPort string) {
|
|
// Make request to local service
|
|
localURL := fmt.Sprintf("http://localhost:%s%s", localPort, req.URL)
|
|
|
|
httpReq, err := http.NewRequest(req.Method, localURL, bytes.NewReader(req.Body))
|
|
if err != nil {
|
|
sendErrorResponse(conn, req.ID, 500, "Failed to create request")
|
|
return
|
|
}
|
|
|
|
// Set headers
|
|
for k, v := range req.Headers {
|
|
httpReq.Header.Set(k, v)
|
|
}
|
|
|
|
// Make the request
|
|
client := &http.Client{}
|
|
resp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
sendErrorResponse(conn, req.ID, 502, "Failed to reach local service")
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Read response body
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
sendErrorResponse(conn, req.ID, 500, "Failed to read response")
|
|
return
|
|
}
|
|
|
|
// Convert response headers
|
|
headers := make(map[string]string)
|
|
for k, v := range resp.Header {
|
|
if len(v) > 0 {
|
|
headers[k] = v[0]
|
|
}
|
|
}
|
|
|
|
// Send response back
|
|
tunnelResp := &TunnelResponse{
|
|
ID: req.ID,
|
|
StatusCode: resp.StatusCode,
|
|
Headers: headers,
|
|
Body: body,
|
|
}
|
|
|
|
if err := conn.WriteJSON(tunnelResp); err != nil {
|
|
log.Printf("Error sending response: %v", err)
|
|
}
|
|
}
|
|
|
|
func sendErrorResponse(conn *websocket.Conn, reqID string, statusCode int, message string) {
|
|
resp := &TunnelResponse{
|
|
ID: reqID,
|
|
StatusCode: statusCode,
|
|
Headers: map[string]string{"Content-Type": "text/plain"},
|
|
Body: []byte(message),
|
|
}
|
|
conn.WriteJSON(resp)
|
|
}
|