works
This commit is contained in:
175
cmd/link.go
175
cmd/link.go
@@ -1,44 +1,34 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type TunnelRequest struct {
|
||||
ID string `json:"id"`
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body []byte `json:"body"`
|
||||
}
|
||||
|
||||
type TunnelResponse struct {
|
||||
ID string `json:"id"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
Body []byte `json:"body"`
|
||||
type TunnelMessage struct {
|
||||
Type string `json:"type"`
|
||||
StreamID string `json:"stream_id"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var serverAddr string
|
||||
|
||||
var linkCmd = &cobra.Command{
|
||||
Use: "link <vhost_location> <local_port>",
|
||||
Use: "link <vhost_location> <host:port>",
|
||||
Short: "Create a tunnel link",
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
vhostLoc := args[0]
|
||||
localPort := args[1]
|
||||
hostPort := args[1]
|
||||
|
||||
fmt.Printf("Creating tunnel: %s -> localhost:%s\n", vhostLoc, localPort)
|
||||
fmt.Printf("Creating TCP tunnel: %s -> %s\n", vhostLoc, hostPort)
|
||||
|
||||
if err := startTunnel(vhostLoc, localPort); err != nil {
|
||||
if err := startTCPTunnel(vhostLoc, hostPort); err != nil {
|
||||
log.Fatal("Failed to start tunnel:", err)
|
||||
}
|
||||
},
|
||||
@@ -48,8 +38,7 @@ func init() {
|
||||
linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "localhost:8080", "Conduit server address")
|
||||
}
|
||||
|
||||
func startTunnel(vhost, localPort string) error {
|
||||
// Connect to WebSocket
|
||||
func startTCPTunnel(vhost, hostPort string) error {
|
||||
wsURL := fmt.Sprintf("ws://%s/_conduit/tunnel?vhost=%s", serverAddr, vhost)
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
@@ -57,80 +46,86 @@ func startTunnel(vhost, localPort string) error {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
fmt.Printf("Tunnel active! %s.example.com -> localhost:%s\n", vhost, localPort)
|
||||
fmt.Printf("TCP Tunnel active! %s.example.com -> %s\n", vhost, hostPort)
|
||||
|
||||
// Handle incoming requests
|
||||
// Track active connections
|
||||
connections := make(map[string]net.Conn)
|
||||
var connMutex sync.RWMutex
|
||||
|
||||
// Handle messages from server
|
||||
for {
|
||||
var req TunnelRequest
|
||||
if err := conn.ReadJSON(&req); err != nil {
|
||||
log.Printf("Error reading request: %v", err)
|
||||
var msg TunnelMessage
|
||||
err := conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
log.Printf("Error reading from tunnel: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
go handleTunnelRequest(conn, &req, localPort)
|
||||
switch msg.Type {
|
||||
case "data":
|
||||
connMutex.RLock()
|
||||
localConn, exists := connections[msg.StreamID]
|
||||
connMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// New connection
|
||||
localConn, err = net.Dial("tcp", hostPort)
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to %s: %v", hostPort, err)
|
||||
continue
|
||||
}
|
||||
|
||||
connMutex.Lock()
|
||||
connections[msg.StreamID] = localConn
|
||||
connMutex.Unlock()
|
||||
|
||||
// Start reading from local connection
|
||||
go func(streamID string, lConn net.Conn) {
|
||||
defer func() {
|
||||
connMutex.Lock()
|
||||
delete(connections, streamID)
|
||||
connMutex.Unlock()
|
||||
lConn.Close()
|
||||
}()
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := lConn.Read(buffer)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
response := TunnelMessage{
|
||||
Type: "data",
|
||||
StreamID: streamID,
|
||||
Data: buffer[:n],
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(response); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}(msg.StreamID, localConn)
|
||||
}
|
||||
|
||||
// Write data to local connection
|
||||
if _, err := localConn.Write(msg.Data); err != nil {
|
||||
log.Printf("Error writing to local connection: %v", err)
|
||||
localConn.Close()
|
||||
connMutex.Lock()
|
||||
delete(connections, msg.StreamID)
|
||||
connMutex.Unlock()
|
||||
}
|
||||
|
||||
case "close":
|
||||
connMutex.Lock()
|
||||
if localConn, exists := connections[msg.StreamID]; exists {
|
||||
localConn.Close()
|
||||
delete(connections, msg.StreamID)
|
||||
}
|
||||
connMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleTunnelRequest(conn *websocket.Conn, req *TunnelRequest, localPort string) {
|
||||
// Make request to local service
|
||||
localURL := fmt.Sprintf("http://localhost:%s%s", localPort, req.URL)
|
||||
|
||||
httpReq, err := http.NewRequest(req.Method, localURL, bytes.NewReader(req.Body))
|
||||
if err != nil {
|
||||
sendErrorResponse(conn, req.ID, 500, "Failed to create request")
|
||||
return
|
||||
}
|
||||
|
||||
// Set headers
|
||||
for k, v := range req.Headers {
|
||||
httpReq.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// Make the request
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
sendErrorResponse(conn, req.ID, 502, "Failed to reach local service")
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
sendErrorResponse(conn, req.ID, 500, "Failed to read response")
|
||||
return
|
||||
}
|
||||
|
||||
// Convert response headers
|
||||
headers := make(map[string]string)
|
||||
for k, v := range resp.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Send response back
|
||||
tunnelResp := &TunnelResponse{
|
||||
ID: req.ID,
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(tunnelResp); err != nil {
|
||||
log.Printf("Error sending response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendErrorResponse(conn *websocket.Conn, reqID string, statusCode int, message string) {
|
||||
resp := &TunnelResponse{
|
||||
ID: reqID,
|
||||
StatusCode: statusCode,
|
||||
Headers: map[string]string{"Content-Type": "text/plain"},
|
||||
Body: []byte(message),
|
||||
}
|
||||
conn.WriteJSON(resp)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user