This commit is contained in:
2025-09-19 16:35:38 -04:00
parent 84ea3802b8
commit 0223d35b34
4 changed files with 413 additions and 270 deletions

View File

@@ -1,44 +1,34 @@
package cmd
import (
"bytes"
"fmt"
"io"
"log"
"net/http"
"net"
"sync"
"github.com/gorilla/websocket"
"github.com/spf13/cobra"
)
type TunnelRequest struct {
ID string `json:"id"`
Method string `json:"method"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
Body []byte `json:"body"`
}
type TunnelResponse struct {
ID string `json:"id"`
StatusCode int `json:"status_code"`
Headers map[string]string `json:"headers"`
Body []byte `json:"body"`
type TunnelMessage struct {
Type string `json:"type"`
StreamID string `json:"stream_id"`
Data []byte `json:"data,omitempty"`
}
var serverAddr string
var linkCmd = &cobra.Command{
Use: "link <vhost_location> <local_port>",
Use: "link <vhost_location> <host:port>",
Short: "Create a tunnel link",
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
vhostLoc := args[0]
localPort := args[1]
hostPort := args[1]
fmt.Printf("Creating tunnel: %s -> localhost:%s\n", vhostLoc, localPort)
fmt.Printf("Creating TCP tunnel: %s -> %s\n", vhostLoc, hostPort)
if err := startTunnel(vhostLoc, localPort); err != nil {
if err := startTCPTunnel(vhostLoc, hostPort); err != nil {
log.Fatal("Failed to start tunnel:", err)
}
},
@@ -48,8 +38,7 @@ func init() {
linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "localhost:8080", "Conduit server address")
}
func startTunnel(vhost, localPort string) error {
// Connect to WebSocket
func startTCPTunnel(vhost, hostPort string) error {
wsURL := fmt.Sprintf("ws://%s/_conduit/tunnel?vhost=%s", serverAddr, vhost)
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
@@ -57,80 +46,86 @@ func startTunnel(vhost, localPort string) error {
}
defer conn.Close()
fmt.Printf("Tunnel active! %s.example.com -> localhost:%s\n", vhost, localPort)
fmt.Printf("TCP Tunnel active! %s.example.com -> %s\n", vhost, hostPort)
// Handle incoming requests
// Track active connections
connections := make(map[string]net.Conn)
var connMutex sync.RWMutex
// Handle messages from server
for {
var req TunnelRequest
if err := conn.ReadJSON(&req); err != nil {
log.Printf("Error reading request: %v", err)
var msg TunnelMessage
err := conn.ReadJSON(&msg)
if err != nil {
log.Printf("Error reading from tunnel: %v", err)
break
}
go handleTunnelRequest(conn, &req, localPort)
switch msg.Type {
case "data":
connMutex.RLock()
localConn, exists := connections[msg.StreamID]
connMutex.RUnlock()
if !exists {
// New connection
localConn, err = net.Dial("tcp", hostPort)
if err != nil {
log.Printf("Failed to connect to %s: %v", hostPort, err)
continue
}
connMutex.Lock()
connections[msg.StreamID] = localConn
connMutex.Unlock()
// Start reading from local connection
go func(streamID string, lConn net.Conn) {
defer func() {
connMutex.Lock()
delete(connections, streamID)
connMutex.Unlock()
lConn.Close()
}()
buffer := make([]byte, 4096)
for {
n, err := lConn.Read(buffer)
if err != nil {
break
}
response := TunnelMessage{
Type: "data",
StreamID: streamID,
Data: buffer[:n],
}
if err := conn.WriteJSON(response); err != nil {
break
}
}
}(msg.StreamID, localConn)
}
// Write data to local connection
if _, err := localConn.Write(msg.Data); err != nil {
log.Printf("Error writing to local connection: %v", err)
localConn.Close()
connMutex.Lock()
delete(connections, msg.StreamID)
connMutex.Unlock()
}
case "close":
connMutex.Lock()
if localConn, exists := connections[msg.StreamID]; exists {
localConn.Close()
delete(connections, msg.StreamID)
}
connMutex.Unlock()
}
}
return nil
}
func handleTunnelRequest(conn *websocket.Conn, req *TunnelRequest, localPort string) {
// Make request to local service
localURL := fmt.Sprintf("http://localhost:%s%s", localPort, req.URL)
httpReq, err := http.NewRequest(req.Method, localURL, bytes.NewReader(req.Body))
if err != nil {
sendErrorResponse(conn, req.ID, 500, "Failed to create request")
return
}
// Set headers
for k, v := range req.Headers {
httpReq.Header.Set(k, v)
}
// Make the request
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
sendErrorResponse(conn, req.ID, 502, "Failed to reach local service")
return
}
defer resp.Body.Close()
// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
sendErrorResponse(conn, req.ID, 500, "Failed to read response")
return
}
// Convert response headers
headers := make(map[string]string)
for k, v := range resp.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
// Send response back
tunnelResp := &TunnelResponse{
ID: req.ID,
StatusCode: resp.StatusCode,
Headers: headers,
Body: body,
}
if err := conn.WriteJSON(tunnelResp); err != nil {
log.Printf("Error sending response: %v", err)
}
}
func sendErrorResponse(conn *websocket.Conn, reqID string, statusCode int, message string) {
resp := &TunnelResponse{
ID: reqID,
StatusCode: statusCode,
Headers: map[string]string{"Content-Type": "text/plain"},
Body: []byte(message),
}
conn.WriteJSON(resp)
}