Compare commits

..

9 Commits
0.0.1 ... main

Author SHA1 Message Date
0722e5f032 chore: tunnel recorder & slight refactor
All checks were successful
continuous-integration/drone/push Build is passing
2025-09-27 17:49:59 -04:00
20c1388cf4 chore: better source tracking
All checks were successful
continuous-integration/drone/push Build is passing
2025-09-23 09:24:09 -04:00
0333680a2b chore: move to sync map 2025-09-23 09:04:06 -04:00
de23b3e815 log error
All checks were successful
continuous-integration/drone/push Build is passing
2025-09-22 23:26:58 -04:00
2e73689762 http vs tcp tunnel
All checks were successful
continuous-integration/drone/push Build is passing
2025-09-22 23:04:15 -04:00
d5de31eda7 fix infinite close
All checks were successful
continuous-integration/drone/push Build is passing
2025-09-22 15:30:54 -04:00
b8714e52de wip 2
All checks were successful
continuous-integration/drone/push Build is passing
2025-09-21 18:41:47 -04:00
f5741ef60b wip 1 2025-09-21 13:14:45 -04:00
31add1984b fix env vars
All checks were successful
continuous-integration/drone/push Build is passing
2025-09-20 21:29:40 -04:00
21 changed files with 938 additions and 329 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
cover.html cover.html
.DS_Store

View File

@ -1,159 +0,0 @@
package client
import (
"fmt"
"net"
"net/url"
"sync"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"reichard.io/conduit/config"
"reichard.io/conduit/types"
)
func NewTunnel(cfg *config.ClientConfig) (*Tunnel, error) {
// Parse Server URL
serverURL, err := url.Parse(cfg.ServerAddress)
if err != nil {
return nil, err
}
// Parse Scheme
var wsScheme string
switch serverURL.Scheme {
case "https":
wsScheme = "wss"
case "http":
wsScheme = "ws"
default:
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
}
// Create Tunnel Name
if cfg.TunnelName == "" {
cfg.TunnelName = generateTunnelName()
log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName)
}
// Connect Server WS
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey)
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to connect: %v", err)
}
return &Tunnel{
name: cfg.TunnelName,
target: cfg.TunnelTarget,
serverURL: serverURL,
serverConn: serverConn,
localConns: make(map[string]net.Conn),
}, nil
}
type Tunnel struct {
name string
target string
serverURL *url.URL
serverConn *websocket.Conn
localConns map[string]net.Conn
mu sync.RWMutex
}
func (t *Tunnel) Start() error {
log.Infof("starting tunnel: %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target)
defer t.serverConn.Close()
// Handle Messages
for {
// Read Message
var msg types.Message
err := t.serverConn.ReadJSON(&msg)
if err != nil {
log.Errorf("error reading from tunnel: %v", err)
break
}
switch msg.Type {
case types.MessageTypeData:
localConn, err := t.getLocalConn(msg.StreamID)
if err != nil {
log.Errorf("failed to get local connection: %v", err)
continue
}
// Write data to local connection
if _, err := localConn.Write(msg.Data); err != nil {
log.Errorf("error writing to local connection: %v", err)
localConn.Close()
t.mu.Lock()
delete(t.localConns, msg.StreamID)
t.mu.Unlock()
}
case types.MessageTypeClose:
t.mu.Lock()
if localConn, exists := t.localConns[msg.StreamID]; exists {
localConn.Close()
delete(t.localConns, msg.StreamID)
}
t.mu.Unlock()
}
}
return nil
}
func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) {
// Get Cached Connection
t.mu.RLock()
localConn, exists := t.localConns[streamID]
t.mu.RUnlock()
if exists {
return localConn, nil
}
// Initiate Connection & Cache
localConn, err := net.Dial("tcp", t.target)
if err != nil {
log.Errorf("failed to connect to %s: %v", t.target, err)
return nil, err
}
t.mu.Lock()
t.localConns[streamID] = localConn
t.mu.Unlock()
// Start Response Relay & Return Connection
go t.startResponseRelay(streamID, localConn)
return localConn, nil
}
func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) {
defer func() {
t.mu.Lock()
delete(t.localConns, streamID)
t.mu.Unlock()
localConn.Close()
}()
buffer := make([]byte, 4096)
for {
n, err := localConn.Read(buffer)
if err != nil {
break
}
response := types.Message{
Type: types.MessageTypeData,
StreamID: streamID,
Data: buffer[:n],
}
if err := t.serverConn.WriteJSON(response); err != nil {
break
}
}
}

View File

@ -1,6 +1,8 @@
package cmd package cmd
import ( import (
"context"
"reichard.io/conduit/config" "reichard.io/conduit/config"
"reichard.io/conduit/server" "reichard.io/conduit/server"
@ -19,8 +21,11 @@ var serveCmd = &cobra.Command{
log.Fatal("failed to get server config:", err) log.Fatal("failed to get server config:", err)
} }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create Server // Create Server
srv, err := server.NewServer(cfg) srv, err := server.NewServer(ctx, cfg)
if err != nil { if err != nil {
log.Fatal("failed to create server:", err) log.Fatal("failed to create server:", err)
} }

View File

@ -1,10 +1,13 @@
package cmd package cmd
import ( import (
"context"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"reichard.io/conduit/client"
"reichard.io/conduit/config" "reichard.io/conduit/config"
"reichard.io/conduit/store"
"reichard.io/conduit/tunnel"
) )
var tunnelCmd = &cobra.Command{ var tunnelCmd = &cobra.Command{
@ -17,17 +20,22 @@ var tunnelCmd = &cobra.Command{
log.Fatal("failed to get client config:", err) log.Fatal("failed to get client config:", err)
} }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create Forwarder
tunnelForwarder, err := tunnel.NewForwarder(cfg.TunnelTarget, store.NewTunnelStore(100))
if err != nil {
log.Fatal("failed to create tunnel forwarder:", err)
}
go tunnelForwarder.Start(ctx)
// Create Tunnel // Create Tunnel
tunnel, err := client.NewTunnel(cfg) tunnel, err := tunnel.NewClientTunnel(cfg, tunnelForwarder)
if err != nil { if err != nil {
log.Fatal("failed to create tunnel:", err) log.Fatal("failed to create tunnel:", err)
} }
tunnel.Start(ctx)
// Start Tunnel
log.Infof("creating TCP tunnel: %s -> %s", cfg.TunnelName, cfg.TunnelTarget)
if err := tunnel.Start(); err != nil {
log.Fatal("failed to start tunnel:", err)
}
}, },
} }

View File

@ -23,6 +23,8 @@ type ConfigDef struct {
type BaseConfig struct { type BaseConfig struct {
ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"` ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"`
APIKey string `json:"api_key" description:"API Key for the conduit API"` APIKey string `json:"api_key" description:"API Key for the conduit API"`
LogLevel string `json:"log_level" default:"info" description:"Log level"`
LogFormat string `json:"log_format" default:"text" description:"Log format - text or json"`
} }
func (c *BaseConfig) Validate() error { func (c *BaseConfig) Validate() error {
@ -35,6 +37,9 @@ func (c *BaseConfig) Validate() error {
if _, err := url.Parse(c.ServerAddress); err != nil { if _, err := url.Parse(c.ServerAddress); err != nil {
return fmt.Errorf("server is invalid: %w", err) return fmt.Errorf("server is invalid: %w", err)
} }
if c.LogFormat != "text" && c.LogFormat != "json" {
return fmt.Errorf("log format must be 'text' or 'json'")
}
return nil return nil
} }
@ -68,13 +73,13 @@ func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) {
} }
cfg := &ServerConfig{ cfg := &ServerConfig{
BaseConfig: BaseConfig{ BaseConfig: getBaseConfig(cfgValues),
ServerAddress: cfgValues["server"],
APIKey: cfgValues["api_key"],
},
BindAddress: cfgValues["bind"], BindAddress: cfgValues["bind"],
} }
// Initialize Logger
initLogger(cfg.BaseConfig)
return cfg, cfg.Validate() return cfg, cfg.Validate()
} }
@ -87,14 +92,14 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
} }
cfg := &ClientConfig{ cfg := &ClientConfig{
BaseConfig: BaseConfig{ BaseConfig: getBaseConfig(cfgValues),
ServerAddress: cfgValues["server"],
APIKey: cfgValues["api_key"],
},
TunnelName: cfgValues["name"], TunnelName: cfgValues["name"],
TunnelTarget: cfgValues["target"], TunnelTarget: cfgValues["target"],
} }
// Initialize Logger
initLogger(cfg.BaseConfig)
return cfg, cfg.Validate() return cfg, cfg.Validate()
} }
@ -108,10 +113,19 @@ func GetVersion() string {
return version return version
} }
func getBaseConfig(cfgValues map[string]string) BaseConfig {
return BaseConfig{
ServerAddress: cfgValues["server"],
APIKey: cfgValues["api_key"],
LogLevel: cfgValues["log_level"],
LogFormat: cfgValues["log_format"],
}
}
func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string { func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string {
// 1. Get Flags First // 1. Get Flags First
if cmdFlags != nil { if cmdFlags != nil {
if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" { if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" && val != def.Default {
return val return val
} }
} }

61
config/logging.go Normal file
View File

@ -0,0 +1,61 @@
package config
import (
"fmt"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
func initLogger(cfg BaseConfig) {
// Parse Log Level
logLevel, err := log.ParseLevel(cfg.LogLevel)
if err != nil {
logLevel = log.InfoLevel
}
log.SetLevel(logLevel)
// Create Log Formatter
var logFormatter log.Formatter
switch cfg.LogFormat {
case "json":
log.SetReportCaller(true)
logFormatter = &log.JSONFormatter{
TimestampFormat: time.RFC3339,
CallerPrettyfier: prettyCaller,
}
case "text":
logFormatter = &log.TextFormatter{
TimestampFormat: time.RFC3339,
FullTimestamp: true,
}
}
log.SetFormatter(&utcFormatter{logFormatter})
}
func prettyCaller(f *runtime.Frame) (function string, file string) {
purgePrefix := "reichard.io/conduit/"
pathName := strings.Replace(f.Func.Name(), purgePrefix, "", 1)
parts := strings.Split(pathName, ".")
filepath, line := f.Func.FileLine(f.PC)
splitFilePath := strings.Split(filepath, "/")
fileName := fmt.Sprintf("%s/%s@%d", parts[0], splitFilePath[len(splitFilePath)-1], line)
functionName := strings.Replace(pathName, parts[0]+".", "", 1)
return functionName, fileName
}
type utcFormatter struct {
log.Formatter
}
func (cf utcFormatter) Format(e *log.Entry) ([]byte, error) {
e.Time = e.Time.UTC()
return cf.Formatter.Format(e)
}

1
go.mod
View File

@ -3,6 +3,7 @@ module reichard.io/conduit
go 1.24.4 go 1.24.4
require ( require (
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect github.com/gorilla/websocket v1.5.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect github.com/sirupsen/logrus v1.9.3 // indirect

2
go.sum
View File

@ -1,6 +1,8 @@
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=

51
pkg/maps/map.go Normal file
View File

@ -0,0 +1,51 @@
package maps
import (
"iter"
"sync"
)
type Map[K comparable, V any] struct {
items map[K]V
mu sync.RWMutex
}
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{items: make(map[K]V)}
}
func (m *Map[K, V]) Get(key K) (V, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
v, ok := m.items[key]
return v, ok
}
func (m *Map[K, V]) Set(key K, value V) {
m.mu.Lock()
defer m.mu.Unlock()
m.items[key] = value
}
func (m *Map[K, V]) Delete(key K) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.items, key)
}
func (m *Map[K, V]) HasKey(key K) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, ok := m.items[key]
return ok
}
func (m *Map[K, V]) Entries() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
for k, v := range m.items {
if !yield(k, v) {
return
}
}
}
}

View File

@ -7,25 +7,25 @@ import (
"net/http" "net/http"
) )
var _ http.ResponseWriter = (*connResponseWriter)(nil) var _ http.ResponseWriter = (*rawHTTPResponseWriter)(nil)
type connResponseWriter struct { type rawHTTPResponseWriter struct {
conn net.Conn conn net.Conn
header http.Header header http.Header
} }
func (f *connResponseWriter) Header() http.Header { func (f *rawHTTPResponseWriter) Header() http.Header {
if f.header == nil { if f.header == nil {
f.header = make(http.Header) f.header = make(http.Header)
} }
return f.header return f.header
} }
func (f *connResponseWriter) Write(data []byte) (int, error) { func (f *rawHTTPResponseWriter) Write(data []byte) (int, error) {
return f.conn.Write(data) return f.conn.Write(data)
} }
func (f *connResponseWriter) WriteHeader(statusCode int) { func (f *rawHTTPResponseWriter) WriteHeader(statusCode int) {
// Write Status // Write Status
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
_, _ = f.conn.Write([]byte(status)) _, _ = f.conn.Write([]byte(status))
@ -41,7 +41,7 @@ func (f *connResponseWriter) WriteHeader(statusCode int) {
_, _ = f.conn.Write([]byte("\r\n")) _, _ = f.conn.Write([]byte("\r\n"))
} }
func (f *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (f *rawHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Return Raw Connection & ReadWriter // Return Raw Connection & ReadWriter
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn)) rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
return f.conn, rw, nil return f.conn, rw, nil

View File

@ -0,0 +1,30 @@
package server
import (
"bytes"
"io"
"net"
)
var _ io.ReadWriteCloser = (*reconstructedConn)(nil)
// reconstructedConn wraps a net.Conn and overrides Read to handle captured data.
type reconstructedConn struct {
net.Conn
reader io.Reader
}
// Read reads from the reconstructed reader (captured data + original conn).
func (rc *reconstructedConn) Read(p []byte) (n int, err error) {
return rc.reader.Read(p)
}
// newReconstructedConn creates a reconstructed connection that replays captured data
// before reading from the original connection.
func newReconstructedConn(conn net.Conn, capturedData *bytes.Buffer) net.Conn {
allReader := io.MultiReader(capturedData, conn)
return &reconstructedConn{
Conn: conn,
reader: allReader,
}
}

View File

@ -3,6 +3,7 @@ package server
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -11,13 +12,13 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/conduit/config" "reichard.io/conduit/config"
"reichard.io/conduit/types" "reichard.io/conduit/pkg/maps"
"reichard.io/conduit/tunnel"
) )
type InfoResponse struct { type InfoResponse struct {
@ -30,22 +31,16 @@ type TunnelInfo struct {
Target string `json:"target"` Target string `json:"target"`
} }
type TunnelConnection struct {
*websocket.Conn
name string
streams map[string]chan []byte
}
type Server struct { type Server struct {
ctx context.Context
host string host string
cfg *config.ServerConfig cfg *config.ServerConfig
mu sync.RWMutex
upgrader websocket.Upgrader upgrader websocket.Upgrader
tunnels map[string]*TunnelConnection tunnels *maps.Map[string, *tunnel.Tunnel]
} }
func NewServer(cfg *config.ServerConfig) (*Server, error) { func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) {
serverURL, err := url.Parse(cfg.ServerAddress) serverURL, err := url.Parse(cfg.ServerAddress)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse server address: %v", err) return nil, fmt.Errorf("failed to parse server address: %v", err)
@ -54,9 +49,10 @@ func NewServer(cfg *config.ServerConfig) (*Server, error) {
} }
return &Server{ return &Server{
ctx: ctx,
cfg: cfg, cfg: cfg,
host: serverURL.Host, host: serverURL.Host,
tunnels: make(map[string]*TunnelConnection), tunnels: maps.New[string, *tunnel.Tunnel](),
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true return true
@ -79,7 +75,7 @@ func (s *Server) Start() error {
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
log.Printf("error accepting connection: %v", err) log.WithError(err).Error("error accepting connection")
continue continue
} }
@ -90,14 +86,12 @@ func (s *Server) Start() error {
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) { func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
// Get Tunnels // Get Tunnels
var allTunnels []TunnelInfo var allTunnels []TunnelInfo
s.mu.RLock() for t, c := range s.tunnels.Entries() {
for t, c := range s.tunnels {
allTunnels = append(allTunnels, TunnelInfo{ allTunnels = append(allTunnels, TunnelInfo{
Name: t, Name: t,
Target: c.RemoteAddr().String(), Target: c.Source(),
}) })
} }
s.mu.RUnlock()
// Create Response // Create Response
d, err := json.MarshalIndent(InfoResponse{ d, err := json.MarshalIndent(InfoResponse{
@ -105,72 +99,17 @@ func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
Version: config.GetVersion(), Version: config.GetVersion(),
}, "", " ") }, "", " ")
if err != nil { if err != nil {
log.WithError(err).Error("failed to marshal info")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
// Send Response // Send Response
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(d) _, _ = w.Write(d)
} }
func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, dataReader io.Reader) {
defer clientConn.Close()
// Create Identifiers
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
responseChan := make(chan []byte, 100)
// Register Stream
s.mu.Lock()
if tunnelConn.streams == nil {
tunnelConn.streams = make(map[string]chan []byte)
}
tunnelConn.streams[streamID] = responseChan
s.mu.Unlock()
// Clean Up
defer func() {
s.mu.Lock()
delete(tunnelConn.streams, streamID)
close(responseChan)
s.mu.Unlock()
// Send Close
closeMsg := types.Message{
Type: types.MessageTypeClose,
StreamID: streamID,
}
_ = tunnelConn.WriteJSON(closeMsg)
}()
// Read & Send Chunks
go func() {
buffer := make([]byte, 4096)
for {
n, err := dataReader.Read(buffer)
if err != nil {
return
}
if err := tunnelConn.WriteJSON(types.Message{
Type: types.MessageTypeData,
StreamID: streamID,
Data: buffer[:n],
}); err != nil {
return
}
}
}()
// Return Response Data
for data := range responseChan {
if _, err := clientConn.Write(data); err != nil {
break
}
}
}
func (s *Server) handleRawConnection(conn net.Conn) { func (s *Server) handleRawConnection(conn net.Conn) {
defer conn.Close() defer conn.Close()
@ -183,7 +122,7 @@ func (s *Server) handleRawConnection(conn net.Conn) {
bufReader := bufio.NewReader(teeReader) bufReader := bufio.NewReader(teeReader)
// Create HTTP Request & Writer // Create HTTP Request & Writer
w := &connResponseWriter{conn: conn} w := &rawHTTPResponseWriter{conn: conn}
r, err := http.ReadRequest(bufReader) r, err := http.ReadRequest(bufReader)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
@ -199,30 +138,49 @@ func (s *Server) handleRawConnection(conn net.Conn) {
} }
// Extract Subdomain // Extract Subdomain
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
if strings.Count(subdomain, ".") != 0 { if strings.Count(tunnelName, ".") != 0 {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host) _, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
return return
} }
// Get True Host
remoteHost := conn.RemoteAddr().String()
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
remoteHost = xff
}
r.RemoteAddr = remoteHost
// Handle Control Endpoints // Handle Control Endpoints
if subdomain == "" { if tunnelName == "" {
s.handleAsHTTP(w, r) s.handleAsHTTP(w, r)
return return
} }
// Handle Tunnels // Handle Tunnels
s.mu.RLock() conduitTunnel, exists := s.tunnels.Get(tunnelName)
tunnelConn, exists := s.tunnels[subdomain] if !exists {
s.mu.RUnlock() w.WriteHeader(http.StatusNotFound)
if exists { _, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName)
log.Infof("relaying %s to tunnel", subdomain) return
// Reconstruct Data & Proxy Connection
allReader := io.MultiReader(&capturedData, r.Body)
s.proxyRawConnection(conn, tunnelConn, allReader)
} }
// Create Stream
reconstructedConn := newReconstructedConn(conn, &capturedData)
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr)
// Add Stream
if err := conduitTunnel.AddStream(tunnelStream, streamID); err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = fmt.Fprintf(w, "failed to add stream: %v", err)
log.WithError(err).Error("failed to add stream")
return
}
// Start Stream
conduitTunnel.StartStream(tunnelStream, streamID)
} }
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
@ -245,40 +203,6 @@ func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
for {
var msg types.Message
err := tunnel.ReadJSON(&msg)
if err != nil {
return
}
if msg.StreamID == "" {
log.Infof("tunnel %s missing streamID", tunnel.name)
continue
}
switch msg.Type {
case types.MessageTypeClose:
return
case types.MessageTypeData:
s.mu.RLock()
streamChan, exists := tunnel.streams[msg.StreamID]
if !exists {
log.Infof("stream %s does not exist", msg.StreamID)
s.mu.RUnlock()
continue
}
select {
case streamChan <- msg.Data:
case <-time.After(time.Second):
log.Warnf("stream %s channel full, dropping data", msg.StreamID)
}
s.mu.RUnlock()
}
}
}
func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Get Tunnel Name // Get Tunnel Name
tunnelName := r.URL.Query().Get("tunnelName") tunnelName := r.URL.Query().Get("tunnelName")
@ -289,7 +213,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
} }
// Validate Unique // Validate Unique
if _, exists := s.tunnels[tunnelName]; exists { if _, exists := s.tunnels.Get(tunnelName); exists {
w.WriteHeader(http.StatusConflict) w.WriteHeader(http.StatusConflict)
_, _ = w.Write([]byte("Tunnel already registered")) _, _ = w.Write([]byte("Tunnel already registered"))
return return
@ -302,26 +226,14 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
return return
} }
// Create & Cache TunnelConnection // Create Tunnel
tunnel := &TunnelConnection{ conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
Conn: wsConn, s.tunnels.Set(tunnelName, conduitTunnel)
name: tunnelName,
streams: make(map[string]chan []byte),
}
s.mu.Lock()
s.tunnels[tunnelName] = tunnel
s.mu.Unlock()
log.Infof("tunnel established: %s", tunnelName)
// Keep connection alive and handle cleanup // Start Tunnel - This is blocking
defer func() { conduitTunnel.Start(s.ctx)
s.mu.Lock()
delete(s.tunnels, tunnelName) // Cleanup Tunnel
s.mu.Unlock() s.tunnels.Delete(tunnelName)
_ = wsConn.Close() _ = wsConn.Close()
log.Infof("tunnel closed: %s", tunnelName)
}()
// Handle tunnel messages
s.handleTunnelMessages(tunnel)
} }

18
store/context.go Normal file
View File

@ -0,0 +1,18 @@
package store
import (
"context"
)
type contextKey struct{}
var recordIDKey = contextKey{}
func withRecord(ctx context.Context, rec *TunnelRecord) context.Context {
return context.WithValue(ctx, recordIDKey, rec)
}
func getRecord(ctx context.Context) (*TunnelRecord, bool) {
id, ok := ctx.Value(recordIDKey).(*TunnelRecord)
return id, ok
}

196
store/store.go Normal file
View File

@ -0,0 +1,196 @@
package store
import (
"bytes"
"errors"
"io"
"mime"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
const (
defaultQueueSize = 100
maxQueueSize = 100
)
var ErrRecordNotFound = errors.New("record not found")
type TunnelStore interface {
Get(before time.Time, count int) (results []*TunnelRecord, more bool)
RecordTCP()
RecordRequest(req *http.Request)
RecordResponse(resp *http.Response) error
}
type TunnelRecord struct {
ID uuid.UUID
Time time.Time
URL *url.URL
Method string
Status int
RequestHeaders http.Header
RequestBodyType string
RequestBody []byte
ResponseHeaders http.Header
ResponseBodyType string
ResponseBody []byte
}
func NewTunnelStore(queueSize int) TunnelStore {
if queueSize <= 0 {
queueSize = defaultQueueSize
} else if queueSize > maxQueueSize {
queueSize = maxQueueSize
}
return &tunnelStoreImpl{queueSize: queueSize}
}
type tunnelStoreImpl struct {
orderedRecords []*TunnelRecord
queueSize int
mu sync.Mutex
}
func (s *tunnelStoreImpl) Get(before time.Time, count int) ([]*TunnelRecord, bool) {
// Find First
start := -1
for i, r := range s.orderedRecords {
if r.Time.Before(before) {
start = i
break
}
}
// Not Found
if start == -1 {
return nil, false
}
// Subslice Records
end := min(start+count, len(s.orderedRecords))
results := s.orderedRecords[start:end]
more := end < len(s.orderedRecords)
return results, more
}
func (s *tunnelStoreImpl) RecordRequest(req *http.Request) {
s.mu.Lock()
defer s.mu.Unlock()
url := *req.URL
rec := &TunnelRecord{
ID: uuid.New(),
Time: time.Now(),
URL: &url,
Method: req.Method,
RequestHeaders: req.Header,
RequestBodyType: req.Header.Get("Content-Type"),
}
if bodyData, err := getRequestBody(req); err == nil {
rec.RequestBody = bodyData
}
// Add Record & Truncate
s.orderedRecords = append(s.orderedRecords, rec)
if len(s.orderedRecords) > s.queueSize {
s.orderedRecords = s.orderedRecords[len(s.orderedRecords)-s.queueSize:]
}
*req = *req.WithContext(withRecord(req.Context(), rec))
}
func (s *tunnelStoreImpl) RecordResponse(resp *http.Response) error {
rec, found := getRecord(resp.Request.Context())
if !found {
return ErrRecordNotFound
}
rec.ResponseHeaders = resp.Header
rec.ResponseBodyType = resp.Header.Get("Content-Type")
if bodyData, err := getResponseBody(resp); err == nil {
rec.ResponseBody = bodyData
}
return nil
}
func (s *tunnelStoreImpl) RecordTCP() {
s.mu.Lock()
defer s.mu.Unlock()
// TODO
}
func getRequestBody(req *http.Request) ([]byte, error) {
if req.ContentLength == 0 || req.Body == nil || req.Body == http.NoBody {
return nil, nil
}
if !isTextContentType(req.Header.Get("Content-Type")) {
return nil, nil
}
// Read Body
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
// Restore Body
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return bodyBytes, nil
}
func getResponseBody(resp *http.Response) ([]byte, error) {
if resp.ContentLength == 0 || resp.Body == nil || resp.Body == http.NoBody {
return nil, nil
}
if !isTextContentType(resp.Header.Get("Content-Type")) {
return nil, nil
}
// Read Body
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// Restore Body
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
return bodyBytes, nil
}
func isTextContentType(contentType string) bool {
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
return false
}
if strings.HasPrefix(mediaType, "text/") {
return true
}
switch mediaType {
case "application/json":
return true
case "application/xml":
return true
case "application/x-www-form-urlencoded":
return true
default:
return false
}
}

43
tunnel/forwarder.go Normal file
View File

@ -0,0 +1,43 @@
package tunnel
import (
"context"
"net/url"
"reichard.io/conduit/store"
)
type ForwarderType int
const (
ForwarderTCP ForwarderType = iota
ForwarderHTTP
)
type Forwarder interface {
Type() ForwarderType
Initialize() (Stream, error)
Start(context.Context) error
}
func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) {
// Get Target URL
targetURL, err := url.Parse(target)
if err != nil {
return nil, err
}
// Get Connection Builder
var forwarder Forwarder
switch targetURL.Scheme {
case "http", "https":
forwarder, err = newHTTPForwarder(targetURL, tunnelStore)
if err != nil {
return nil, err
}
default:
forwarder = newTCPForwarder(target, tunnelStore)
}
return forwarder, nil
}

132
tunnel/http_forwarder.go Normal file
View File

@ -0,0 +1,132 @@
package tunnel
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"reichard.io/conduit/store"
)
func newHTTPForwarder(targetURL *url.URL, tunnelStore store.TunnelStore) (Forwarder, error) {
return &httpConnBuilder{
multiConnListener: newMultiConnListener(),
tunnelStore: tunnelStore,
targetURL: targetURL,
}, nil
}
type httpConnBuilder struct {
multiConnListener *multiConnListener
tunnelStore store.TunnelStore
targetURL *url.URL
}
func (c *httpConnBuilder) Type() ForwarderType {
return ForwarderHTTP
}
func (c *httpConnBuilder) Start(ctx context.Context) error {
// Create Reverse Proxy Server
server := &http.Server{
Handler: &httputil.ReverseProxy{
Director: func(req *http.Request) {
req.Host = c.targetURL.Host
req.URL.Host = c.targetURL.Host
req.URL.Scheme = c.targetURL.Scheme
c.tunnelStore.RecordRequest(req)
},
ModifyResponse: c.tunnelStore.RecordResponse,
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway)
},
},
}
// Context & Cleanup
go func() {
<-ctx.Done()
server.Shutdown(ctx)
c.multiConnListener.Close()
}()
// Start HTTP Proxy
if err := server.Serve(c.multiConnListener); err != nil && err != http.ErrServerClosed {
return err
}
return nil
}
func (c *httpConnBuilder) Initialize() (Stream, error) {
clientConn, serverConn := net.Pipe()
if err := c.multiConnListener.addConn(serverConn); err != nil {
_ = clientConn.Close()
_ = serverConn.Close()
return nil, err
}
return &streamImpl{clientConn, c.targetURL.String()}, nil
}
type multiConnListener struct {
connCh chan net.Conn
closed chan struct{}
once sync.Once
}
func newMultiConnListener() *multiConnListener {
return &multiConnListener{
connCh: make(chan net.Conn, 100),
closed: make(chan struct{}),
}
}
func (l *multiConnListener) Accept() (net.Conn, error) {
select {
case conn := <-l.connCh:
if conn == nil {
return nil, fmt.Errorf("listener closed")
}
return conn, nil
case <-l.closed:
return nil, fmt.Errorf("listener closed")
}
}
func (l *multiConnListener) Close() error {
l.once.Do(func() {
close(l.closed)
// Drain any remaining connections
go func() {
for conn := range l.connCh {
if conn != nil {
conn.Close()
}
}
}()
close(l.connCh)
})
return nil
}
func (l *multiConnListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
func (l *multiConnListener) addConn(conn net.Conn) error {
select {
case l.connCh <- conn:
return nil
case <-l.closed:
conn.Close()
return fmt.Errorf("listener is closed")
default:
conn.Close()
return fmt.Errorf("connection queue full")
}
}

View File

@ -1,4 +1,4 @@
package client package tunnel
import ( import (
"fmt" "fmt"

26
tunnel/stream.go Normal file
View File

@ -0,0 +1,26 @@
package tunnel
import (
"io"
"net"
)
var _ Stream = (*streamImpl)(nil)
type Stream interface {
io.ReadWriteCloser
Source() string
}
func NewStream(conn net.Conn, source string) Stream {
return &streamImpl{conn, source}
}
type streamImpl struct {
net.Conn
source string
}
func (s *streamImpl) Source() string {
return s.source
}

37
tunnel/tcp_forwarder.go Normal file
View File

@ -0,0 +1,37 @@
package tunnel
import (
"context"
"net"
"reichard.io/conduit/store"
)
func newTCPForwarder(target string, tunnelStore store.TunnelStore) Forwarder {
return &tcpConnBuilder{
target: target,
tunnelStore: tunnelStore,
}
}
type tcpConnBuilder struct {
target string
tunnelStore store.TunnelStore
}
func (l *tcpConnBuilder) Type() ForwarderType {
return ForwarderTCP
}
func (l *tcpConnBuilder) Initialize() (Stream, error) {
conn, err := net.Dial("tcp", l.target)
if err != nil {
return nil, err
}
return &streamImpl{conn, l.target}, nil
}
func (l *tcpConnBuilder) Start(ctx context.Context) error {
return nil
}

230
tunnel/tunnel.go Normal file
View File

@ -0,0 +1,230 @@
package tunnel
import (
"context"
"fmt"
"net/url"
"sync"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"reichard.io/conduit/config"
"reichard.io/conduit/pkg/maps"
"reichard.io/conduit/types"
)
// NewServerTunnel creates a new tunnel with name and websocket connection. The tunnel is
// generally instantiated after an upgrade request from the server.
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
return &Tunnel{
name: name,
streams: maps.New[string, Stream](),
wsConn: wsConn,
}
}
// NewClientTunnel creates a new tunnel with the provided configuration and forwarder. A
// forwarder is effectively the protocol being forwarded. For example HTTP (Proxy), and TCP.
func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, error) {
// Parse Server URL
serverURL, err := url.Parse(cfg.ServerAddress)
if err != nil {
return nil, err
}
// Parse Scheme
var wsScheme string
switch serverURL.Scheme {
case "https":
wsScheme = "wss"
case "http":
wsScheme = "ws"
default:
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
}
// Create Tunnel Name
if cfg.TunnelName == "" {
cfg.TunnelName = generateTunnelName()
log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName)
}
// Connect Server WS
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey)
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to connect: %v", err)
}
return &Tunnel{
name: cfg.TunnelName,
wsConn: serverConn,
streams: maps.New[string, Stream](),
forwarder: forwarder,
}, nil
}
type Tunnel struct {
ctx context.Context
name string
wsConn *websocket.Conn
streams *maps.Map[string, Stream]
forwarder Forwarder
mu sync.Mutex
}
func (t *Tunnel) Start(ctx context.Context) {
log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
t.ctx = ctx
// Start Message Receiver
for {
msg, err := t.readWSWithContext(ctx)
if err != nil {
return
}
// Validate Stream
if msg.StreamID == "" {
log.Warnf("tunnel %s missing streamID", t.name)
continue
}
// Get Stream
stream, err := t.getStream(msg.StreamID)
if err != nil {
if msg.Type != types.MessageTypeClose {
log.WithError(err).Errorf("failed to get stream %s", msg.StreamID)
}
continue
}
// Handle Messages
switch msg.Type {
case types.MessageTypeClose:
_ = t.closeStream(stream, msg.StreamID)
case types.MessageTypeData:
_, err = stream.Write(msg.Data)
}
// Log Error
if err != nil {
log.WithError(err).Errorf("failed to handle message %s", msg.StreamID)
}
}
}
func (t *Tunnel) readWSWithContext(ctx context.Context) (*types.Message, error) {
type result struct {
msg *types.Message
err error
}
resultChan := make(chan result, 1)
go func() {
var msg types.Message
err := t.wsConn.ReadJSON(&msg)
resultChan <- result{&msg, err}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-resultChan:
return result.msg, result.err
}
}
func (t *Tunnel) AddStream(stream Stream, streamID string) error {
if t.streams.HasKey(streamID) {
return fmt.Errorf("stream %s already exists", streamID)
}
log.Infof("tunnel %q initiated stream with %s", t.name, stream.Source())
t.streams.Set(streamID, stream)
return nil
}
func (t *Tunnel) Source() string {
return t.wsConn.RemoteAddr().String()
}
func (t *Tunnel) StartStream(stream Stream, streamID string) error {
// Close Stream
defer t.closeStream(stream, streamID)
// Start Stream
for {
data, err := t.readStreamWithContext(t.ctx, stream)
if err != nil {
return err
}
if err := t.sendWS(&types.Message{
Type: types.MessageTypeData,
StreamID: streamID,
Data: data,
SourceAddr: stream.Source(),
}); err != nil {
return err
}
}
}
func (t *Tunnel) closeStream(stream Stream, streamID string) error {
log.Infof("tunnel %q closed stream with %s", t.name, stream.Source())
t.streams.Delete(streamID)
return stream.Close()
}
func (t *Tunnel) getStream(streamID string) (Stream, error) {
// Check Existing Stream
if stream, found := t.streams.Get(streamID); found {
return stream, nil
}
// Check Forwarder
if t.forwarder == nil {
return nil, fmt.Errorf("stream %s does not exist", streamID)
}
// Initialize Forwarder & Add Stream
stream, err := t.forwarder.Initialize()
if err != nil {
return nil, err
}
if err := t.AddStream(stream, streamID); err != nil {
return nil, err
}
go t.StartStream(stream, streamID)
return stream, nil
}
func (t *Tunnel) readStreamWithContext(ctx context.Context, stream Stream) ([]byte, error) {
type result struct {
data []byte
err error
}
resultChan := make(chan result, 1)
go func() {
buffer := make([]byte, 4096)
n, err := stream.Read(buffer)
resultChan <- result{buffer[:n], err}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-resultChan:
return result.data, result.err
}
}
func (t *Tunnel) sendWS(msg *types.Message) error {
t.mu.Lock()
defer t.mu.Unlock()
return t.wsConn.WriteJSON(msg)
}

View File

@ -10,5 +10,6 @@ const (
type Message struct { type Message struct {
Type MessageType `json:"type"` Type MessageType `json:"type"`
StreamID string `json:"stream_id"` StreamID string `json:"stream_id"`
SourceAddr string `json:"source_addr"`
Data []byte `json:"data,omitempty"` Data []byte `json:"data,omitempty"`
} }