This commit is contained in:
Evan Reichard 2025-09-19 21:42:28 -04:00
parent 7f8fb011ce
commit 4ba4fe381f
5 changed files with 63 additions and 103 deletions

View File

@ -9,14 +9,9 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"reichard.io/conduit/types"
) )
type TunnelMessage struct {
Type string `json:"type"`
StreamID string `json:"stream_id"`
Data []byte `json:"data,omitempty"`
}
var serverAddr string var serverAddr string
var linkCmd = &cobra.Command{ var linkCmd = &cobra.Command{
@ -69,7 +64,7 @@ func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error)
} }
// Connect Server WS // Connect Server WS
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?vhost=%s", wsScheme, serverURL.Host, tunnelName) wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s", wsScheme, serverURL.Host, tunnelName)
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect: %v", err) return nil, fmt.Errorf("failed to connect: %v", err)
@ -100,7 +95,7 @@ func (t *Tunnel) Start() error {
// Handle Messages // Handle Messages
for { for {
// Read Message // Read Message
var msg TunnelMessage var msg types.Message
err := t.serverConn.ReadJSON(&msg) err := t.serverConn.ReadJSON(&msg)
if err != nil { if err != nil {
log.Errorf("Error reading from tunnel: %v", err) log.Errorf("Error reading from tunnel: %v", err)
@ -108,7 +103,7 @@ func (t *Tunnel) Start() error {
} }
switch msg.Type { switch msg.Type {
case "data": case types.MessageTypeData:
localConn, err := t.getLocalConn(msg.StreamID) localConn, err := t.getLocalConn(msg.StreamID)
if err != nil { if err != nil {
log.Errorf("Failed to get local connection: %v", err) log.Errorf("Failed to get local connection: %v", err)
@ -124,7 +119,7 @@ func (t *Tunnel) Start() error {
t.mu.Unlock() t.mu.Unlock()
} }
case "close": case types.MessageTypeClose:
t.mu.Lock() t.mu.Lock()
if localConn, exists := t.localConns[msg.StreamID]; exists { if localConn, exists := t.localConns[msg.StreamID]; exists {
localConn.Close() localConn.Close()
@ -176,8 +171,8 @@ func (t *Tunnel) startResponseRelay(streamID string, lConn net.Conn) {
break break
} }
response := TunnelMessage{ response := types.Message{
Type: "data", Type: types.MessageTypeData,
StreamID: streamID, StreamID: streamID,
Data: buffer[:n], Data: buffer[:n],
} }

1
config/config.go Normal file
View File

@ -0,0 +1 @@
package config

View File

@ -12,20 +12,15 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/conduit/types"
) )
type TunnelConnection struct { type TunnelConnection struct {
*websocket.Conn *websocket.Conn
vhost string name string
streams map[string]chan []byte // StreamID -> data channel streams map[string]chan []byte // StreamID -> data channel
} }
type TunnelMessage struct {
Type string `json:"type"`
StreamID string `json:"stream_id"`
Data []byte `json:"data,omitempty"`
}
type Server struct { type Server struct {
tunnels map[string]*TunnelConnection tunnels map[string]*TunnelConnection
mu sync.RWMutex mu sync.RWMutex
@ -77,7 +72,7 @@ func (s *Server) extractSubdomain(host string) string {
return "" return ""
} }
func (s *Server) handleStatus(conn net.Conn) { func (s *Server) getStatus(conn net.Conn) {
s.mu.RLock() s.mu.RLock()
count := len(s.tunnels) count := len(s.tunnels)
s.mu.RUnlock() s.mu.RUnlock()
@ -113,7 +108,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
// Send initial data with stream ID // Send initial data with stream ID
msg := TunnelMessage{ msg := types.Message{
Type: "data", Type: "data",
StreamID: streamID, StreamID: streamID,
Data: initialData, Data: initialData,
@ -143,7 +138,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
s.mu.Unlock() s.mu.Unlock()
// Send close message // Send close message
closeMsg := TunnelMessage{ closeMsg := types.Message{
Type: "close", Type: "close",
StreamID: streamID, StreamID: streamID,
} }
@ -159,7 +154,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
return return
} }
msg := TunnelMessage{ msg := types.Message{
Type: "data", Type: "data",
StreamID: streamID, StreamID: streamID,
Data: buffer[:n], Data: buffer[:n],
@ -215,25 +210,24 @@ func (s *Server) handleAsHTTP(conn net.Conn, initialData []byte) {
reader := bufio.NewReader(bytes.NewReader(initialData)) reader := bufio.NewReader(bytes.NewReader(initialData))
req, err := http.ReadRequest(reader) req, err := http.ReadRequest(reader)
if err != nil { if err != nil {
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) _, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
return return
} }
// Handle Control Endpoints // Handle Control Endpoints
switch req.URL.Path { switch req.URL.Path {
case "/_conduit/tunnel": case "/_conduit/tunnel":
s.handleTunnelUpgrade(conn, req) s.createTunnel(conn, req)
return
case "/_conduit/status": case "/_conduit/status":
s.handleStatus(conn) s.getStatus(conn)
default: default:
conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n")) _, _ = conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n"))
} }
} }
func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) { func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
for { for {
var msg TunnelMessage var msg types.Message
err := tunnel.ReadJSON(&msg) err := tunnel.ReadJSON(&msg)
if err != nil { if err != nil {
break break
@ -253,84 +247,85 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
} }
} }
} }
func (s *Server) handleTunnelUpgrade(conn net.Conn, req *http.Request) { func (s *Server) createTunnel(conn net.Conn, req *http.Request) {
vhost := req.URL.Query().Get("vhost") // Get Tunnel Name
if vhost == "" { tunnelName := req.URL.Query().Get("tunnelName")
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing vhost parameter")) if tunnelName == "" {
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing tunnelName parameter"))
return return
} }
// Create a fake ResponseWriter that writes to our raw connection // Validate Unique
fakeWriter := &fakeResponseWriter{conn: conn} if _, exists := s.tunnels[tunnelName]; exists {
conn.Write([]byte("HTTP/1.1 409 Conflict\r\n\r\nTunnel already registered"))
return
}
// Use the upgrader // Upgrade Connection
wsConn, err := s.upgrader.Upgrade(fakeWriter, req, nil) wsConn, err := s.upgrader.Upgrade(&rawResponseWriter{conn: conn}, req, nil)
if err != nil { if err != nil {
log.Errorf("WebSocket upgrade failed: %v", err) log.Errorf("WebSocket upgrade failed: %v", err)
return return
} }
// Create TunnelConnection // Create & Cache TunnelConnection
tunnel := &TunnelConnection{ tunnel := &TunnelConnection{
Conn: wsConn, Conn: wsConn,
vhost: vhost, name: tunnelName,
streams: make(map[string]chan []byte), streams: make(map[string]chan []byte),
} }
s.mu.Lock() s.mu.Lock()
s.tunnels[vhost] = tunnel s.tunnels[tunnelName] = tunnel
s.mu.Unlock() s.mu.Unlock()
log.Infof("Tunnel established: %s", tunnelName)
log.Infof("Tunnel established: %s", vhost)
// Keep connection alive and handle cleanup // Keep connection alive and handle cleanup
defer func() { defer func() {
s.mu.Lock() s.mu.Lock()
delete(s.tunnels, vhost) delete(s.tunnels, tunnelName)
s.mu.Unlock() s.mu.Unlock()
wsConn.Close() wsConn.Close()
log.Infof("Tunnel closed: %s", vhost) log.Infof("Tunnel closed: %s", tunnelName)
}() }()
// Handle tunnel messages // Handle tunnel messages
s.handleTunnelMessages(tunnel) s.handleTunnelMessages(tunnel)
} }
type fakeResponseWriter struct { type rawResponseWriter struct {
conn net.Conn conn net.Conn
header http.Header header http.Header
} }
func (f *fakeResponseWriter) Header() http.Header { func (f *rawResponseWriter) Header() http.Header {
if f.header == nil { if f.header == nil {
f.header = make(http.Header) f.header = make(http.Header)
} }
return f.header return f.header
} }
func (f *fakeResponseWriter) Write(data []byte) (int, error) { func (f *rawResponseWriter) Write(data []byte) (int, error) {
return f.conn.Write(data) return f.conn.Write(data)
} }
func (f *fakeResponseWriter) WriteHeader(statusCode int) { func (f *rawResponseWriter) WriteHeader(statusCode int) {
// Write HTTP status line // Write Status
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
f.conn.Write([]byte(status)) _, _ = f.conn.Write([]byte(status))
// Write headers // Write Headers
for key, values := range f.header { for key, values := range f.header {
for _, value := range values { for _, value := range values {
f.conn.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value))) _, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value)
} }
} }
// End headers // End Headers
f.conn.Write([]byte("\r\n")) _, _ = f.conn.Write([]byte("\r\n"))
} }
// Implement http.Hijacker func (f *rawResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (f *fakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Return Raw Connection & ReadWriter
// Return the raw connection and create a ReadWriter for it
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
return f.conn, rw, nil return f.conn, rw, nil
} }

View File

@ -1,45 +0,0 @@
package server
import (
"bufio"
"fmt"
"net"
"net/http"
)
type tunnelWriter struct {
conn net.Conn
header http.Header
}
func (f *tunnelWriter) Header() http.Header {
if f.header == nil {
f.header = make(http.Header)
}
return f.header
}
func (f *tunnelWriter) Write(data []byte) (int, error) {
return f.conn.Write(data)
}
func (f *tunnelWriter) WriteHeader(statusCode int) {
// Write HTTP status line
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
f.conn.Write([]byte(status))
// Write headers
for key, values := range f.header {
for _, value := range values {
f.conn.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value)))
}
}
// End headers
f.conn.Write([]byte("\r\n"))
}
func (f *tunnelWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
return f.conn, rw, nil
}

14
types/message.go Normal file
View File

@ -0,0 +1,14 @@
package types
type MessageType string
const (
MessageTypeData MessageType = "data"
MessageTypeClose MessageType = "close"
)
type Message struct {
Type MessageType `json:"type"`
StreamID string `json:"stream_id"`
Data []byte `json:"data,omitempty"`
}