This commit is contained in:
Evan Reichard 2025-09-19 23:53:41 -04:00
parent 4ba4fe381f
commit 5d9684b27e
4 changed files with 147 additions and 143 deletions

View File

@ -40,11 +40,6 @@ func init() {
linkCmd.Flags().StringVarP(&serverAddr, "server", "s", "http://localhost:8080", "Conduit server address")
}
type TunnelConfig struct {
// The conduit server address, e.g. https://conduit.example.com
ServerAddress string `default:"http://localhost:8080"`
}
func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error) {
// Parse Server URL
serverURL, err := url.Parse(serverAddress)
@ -73,6 +68,7 @@ func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error)
return &Tunnel{
name: tunnelName,
target: tunnelTarget,
serverURL: serverURL,
serverConn: serverConn,
localConns: make(map[string]net.Conn),
}, nil
@ -80,8 +76,9 @@ func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error)
}
type Tunnel struct {
name string
target string
name string
target string
serverURL *url.URL
serverConn *websocket.Conn
localConns map[string]net.Conn
@ -89,7 +86,7 @@ type Tunnel struct {
}
func (t *Tunnel) Start() error {
log.Infof("TCP Tunnel active! %s.example.com -> %s\n", t.name, t.target)
log.Infof("TCP Tunnel active! %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target)
defer t.serverConn.Close()
// Handle Messages
@ -156,17 +153,17 @@ func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) {
return localConn, nil
}
func (t *Tunnel) startResponseRelay(streamID string, lConn net.Conn) {
func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) {
defer func() {
t.mu.Lock()
delete(t.localConns, streamID)
t.mu.Unlock()
lConn.Close()
localConn.Close()
}()
buffer := make([]byte, 4096)
for {
n, err := lConn.Read(buffer)
n, err := localConn.Read(buffer)
if err != nil {
break
}

6
flake.lock generated
View File

@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1754292888,
"narHash": "sha256-1ziydHSiDuSnaiPzCQh1mRFBsM2d2yRX9I+5OPGEmIE=",
"lastModified": 1758216857,
"narHash": "sha256-h1BW2y7CY4LI9w61R02wPaOYfmYo82FyRqHIwukQ6SY=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ce01daebf8489ba97bd1609d185ea276efdeb121",
"rev": "d2ed99647a4b195f0bcc440f76edfa10aeb3b743",
"type": "github"
},
"original": {

View File

@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"strings"
@ -23,8 +24,8 @@ type TunnelConnection struct {
type Server struct {
tunnels map[string]*TunnelConnection
mu sync.RWMutex
upgrader websocket.Upgrader
mu sync.RWMutex
}
func NewServer() *Server {
@ -59,11 +60,21 @@ func (s *Server) Start(addr string) error {
}
}
func (s *Server) extractSubdomain(host string) string {
func (s *Server) extractSubdomain(peakReader io.Reader) string {
// Read Request
req, err := http.ReadRequest(bufio.NewReader(peakReader))
if err != nil {
return ""
}
defer req.Body.Close()
// Extract Host
host := req.Host
if idx := strings.Index(host, ":"); idx != -1 {
host = host[:idx]
}
// Extract Subdomain
parts := strings.Split(host, ".")
if len(parts) >= 1 {
return parts[0]
@ -72,57 +83,26 @@ func (s *Server) extractSubdomain(host string) string {
return ""
}
func (s *Server) getStatus(conn net.Conn) {
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
s.mu.RLock()
count := len(s.tunnels)
s.mu.RUnlock()
response := fmt.Sprintf(
"HTTP/1.1 200 OK\r\n"+
"Content-Type: application/json\r\n"+
"Content-Length: %d\r\n\r\n"+
`{"tunnels": %d}`,
len(fmt.Sprintf(`{"tunnels": %d}`, count)), count)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
conn.Write([]byte(response))
response := fmt.Sprintf(`{"tunnels": %d}`, count)
_, _ = w.Write([]byte(response))
}
func (s *Server) extractHostFromHTTP(data []byte) string {
// Simple HTTP header parsing
lines := strings.Split(string(data), "\r\n")
for _, line := range lines {
if strings.HasPrefix(strings.ToLower(line), "host:") {
parts := strings.SplitN(line, ":", 2)
if len(parts) == 2 {
return strings.TrimSpace(parts[1])
}
}
}
return ""
}
func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, initialData []byte) {
func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, dataReader io.Reader) {
defer clientConn.Close()
// Generate a unique stream ID for this connection
// Create Identifiers
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
// Send initial data with stream ID
msg := types.Message{
Type: "data",
StreamID: streamID,
Data: initialData,
}
if err := tunnelConn.WriteJSON(msg); err != nil {
log.Errorf("Error sending initial data: %v", err)
return
}
// Create a channel for this stream's responses
responseChan := make(chan []byte, 100)
// Register this stream
// Register Stream
s.mu.Lock()
if tunnelConn.streams == nil {
tunnelConn.streams = make(map[string]chan []byte)
@ -130,43 +110,41 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
tunnelConn.streams[streamID] = responseChan
s.mu.Unlock()
// Clean up when done
// Clean Up
defer func() {
s.mu.Lock()
delete(tunnelConn.streams, streamID)
close(responseChan)
s.mu.Unlock()
// Send close message
// Send Close
closeMsg := types.Message{
Type: "close",
Type: types.MessageTypeClose,
StreamID: streamID,
}
tunnelConn.WriteJSON(closeMsg)
_ = tunnelConn.WriteJSON(closeMsg)
}()
// Handle client -> tunnel
// Read & Send Chunks
go func() {
buffer := make([]byte, 4096)
for {
n, err := clientConn.Read(buffer)
n, err := dataReader.Read(buffer)
if err != nil {
return
}
msg := types.Message{
Type: "data",
if err := tunnelConn.WriteJSON(types.Message{
Type: types.MessageTypeData,
StreamID: streamID,
Data: buffer[:n],
}
if err := tunnelConn.WriteJSON(msg); err != nil {
}); err != nil {
return
}
}
}()
// Return Client Response Data
// Return Response Data
for data := range responseChan {
if _, err := clientConn.Write(data); err != nil {
break
@ -174,54 +152,59 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
}
}
// peakData limits how much we read as we only need to determine
// the host to figure out whether we should proxy or not.
func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
peek := make([]byte, 8192)
n, err := conn.Read(peek)
if err != nil {
return nil, nil, err
}
peekedData := peek[:n]
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
return bytes.NewReader(peekedData), combinedReader, nil
}
func (s *Server) handleRawConnection(conn net.Conn) {
defer conn.Close()
// Read enough to get the Host header
buffer := make([]byte, 4096)
n, err := conn.Read(buffer)
if err != nil {
return
}
// Extract host
host := s.extractHostFromHTTP(buffer[:n])
subdomain := s.extractSubdomain(host)
// If we have a registered tunnel for this subdomain, relay it
if subdomain != "" {
peakReader, allReader, _ := s.peekData(conn)
if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
s.mu.RLock()
tunnelConn, exists := s.tunnels[subdomain]
s.mu.RUnlock()
if exists {
log.Infof("Relaying %s to tunnel", subdomain)
s.proxyRawConnection(conn, tunnelConn, buffer[:n])
s.proxyRawConnection(conn, tunnelConn, allReader)
return
}
}
// Otherwise, handle as control server (recreate HTTP request and use net/http)
s.handleAsHTTP(conn, buffer[:n])
s.handleAsHTTP(conn, allReader)
}
func (s *Server) handleAsHTTP(conn net.Conn, initialData []byte) {
// Create a fake HTTP request from the raw data
reader := bufio.NewReader(bytes.NewReader(initialData))
req, err := http.ReadRequest(reader)
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
// Create HTTP Request & Writer
r, err := http.ReadRequest(bufio.NewReader(allReader))
if err != nil {
_, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
return
}
w := &connResponseWriter{conn: conn}
// Handle Control Endpoints
switch req.URL.Path {
switch r.URL.Path {
case "/_conduit/tunnel":
s.createTunnel(conn, req)
s.createTunnel(w, r)
case "/_conduit/status":
s.getStatus(conn)
s.getStatus(w, r)
default:
_, _ = conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n"))
w.WriteHeader(http.StatusNotFound)
}
}
@ -230,39 +213,53 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
var msg types.Message
err := tunnel.ReadJSON(&msg)
if err != nil {
break
return
}
// Route message to appropriate stream
if msg.Type == "data" && msg.StreamID != "" {
if msg.StreamID == "" {
log.Infof("Tunnel %s missing streamID", tunnel.name)
continue
}
switch msg.Type {
case types.MessageTypeClose:
return
case types.MessageTypeData:
s.mu.RLock()
if streamChan, exists := tunnel.streams[msg.StreamID]; exists {
select {
case streamChan <- msg.Data:
case <-time.After(time.Second):
log.Infof("Stream %s channel full, dropping data", msg.StreamID)
}
streamChan, exists := tunnel.streams[msg.StreamID]
if !exists {
log.Infof("Stream %s does not exist", msg.StreamID)
s.mu.RUnlock()
continue
}
select {
case streamChan <- msg.Data:
case <-time.After(time.Second):
log.Infof("Stream %s channel full, dropping data", msg.StreamID)
}
s.mu.RUnlock()
}
}
}
func (s *Server) createTunnel(conn net.Conn, req *http.Request) {
func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Get Tunnel Name
tunnelName := req.URL.Query().Get("tunnelName")
tunnelName := r.URL.Query().Get("tunnelName")
if tunnelName == "" {
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing tunnelName parameter"))
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("Missing tunnelName parameter"))
return
}
// Validate Unique
if _, exists := s.tunnels[tunnelName]; exists {
conn.Write([]byte("HTTP/1.1 409 Conflict\r\n\r\nTunnel already registered"))
w.WriteHeader(http.StatusConflict)
_, _ = w.Write([]byte("Tunnel already registered"))
return
}
// Upgrade Connection
wsConn, err := s.upgrader.Upgrade(&rawResponseWriter{conn: conn}, req, nil)
wsConn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("WebSocket upgrade failed: %v", err)
return
@ -284,48 +281,10 @@ func (s *Server) createTunnel(conn net.Conn, req *http.Request) {
s.mu.Lock()
delete(s.tunnels, tunnelName)
s.mu.Unlock()
wsConn.Close()
_ = wsConn.Close()
log.Infof("Tunnel closed: %s", tunnelName)
}()
// Handle tunnel messages
s.handleTunnelMessages(tunnel)
}
type rawResponseWriter struct {
conn net.Conn
header http.Header
}
func (f *rawResponseWriter) Header() http.Header {
if f.header == nil {
f.header = make(http.Header)
}
return f.header
}
func (f *rawResponseWriter) Write(data []byte) (int, error) {
return f.conn.Write(data)
}
func (f *rawResponseWriter) WriteHeader(statusCode int) {
// Write Status
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
_, _ = f.conn.Write([]byte(status))
// Write Headers
for key, values := range f.header {
for _, value := range values {
_, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value)
}
}
// End Headers
_, _ = f.conn.Write([]byte("\r\n"))
}
func (f *rawResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Return Raw Connection & ReadWriter
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
return f.conn, rw, nil
}

48
server/writer.go Normal file
View File

@ -0,0 +1,48 @@
package server
import (
"bufio"
"fmt"
"net"
"net/http"
)
var _ http.ResponseWriter = (*connResponseWriter)(nil)
type connResponseWriter struct {
conn net.Conn
header http.Header
}
func (f *connResponseWriter) Header() http.Header {
if f.header == nil {
f.header = make(http.Header)
}
return f.header
}
func (f *connResponseWriter) Write(data []byte) (int, error) {
return f.conn.Write(data)
}
func (f *connResponseWriter) WriteHeader(statusCode int) {
// Write Status
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
_, _ = f.conn.Write([]byte(status))
// Write Headers
for key, values := range f.header {
for _, value := range values {
_, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value)
}
}
// End Headers
_, _ = f.conn.Write([]byte("\r\n"))
}
func (f *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Return Raw Connection & ReadWriter
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
return f.conn, rw, nil
}