From 0722e5f032f55bc34e1ef9602fa2d221c2da8f0e Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Sat, 27 Sep 2025 17:49:59 -0400 Subject: [PATCH] chore: tunnel recorder & slight refactor --- client/client.go | 45 ------- cmd/serve.go | 7 +- cmd/tunnel.go | 19 ++- config/config.go | 30 +++-- config/logging.go | 61 +++++++++ go.mod | 1 + go.sum | 2 + server/server.go | 21 +-- store/context.go | 18 +++ store/store.go | 196 +++++++++++++++++++++++++++ tunnel/forwarder.go | 43 ++++++ tunnel/http.go | 104 --------------- tunnel/http_forwarder.go | 132 ++++++++++++++++++ {client => tunnel}/name.go | 2 +- tunnel/stream.go | 26 ++++ tunnel/tcp_forwarder.go | 37 ++++++ tunnel/tunnel.go | 266 +++++++++++++++++++++---------------- 17 files changed, 725 insertions(+), 285 deletions(-) delete mode 100644 client/client.go create mode 100644 config/logging.go create mode 100644 store/context.go create mode 100644 store/store.go create mode 100644 tunnel/forwarder.go delete mode 100644 tunnel/http.go create mode 100644 tunnel/http_forwarder.go rename {client => tunnel}/name.go (97%) create mode 100644 tunnel/stream.go create mode 100644 tunnel/tcp_forwarder.go diff --git a/client/client.go b/client/client.go deleted file mode 100644 index b8cad0a..0000000 --- a/client/client.go +++ /dev/null @@ -1,45 +0,0 @@ -package client - -import ( - "fmt" - "net/url" - - "github.com/gorilla/websocket" - log "github.com/sirupsen/logrus" - "reichard.io/conduit/config" - "reichard.io/conduit/tunnel" -) - -func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) { - // Parse Server URL - serverURL, err := url.Parse(cfg.ServerAddress) - if err != nil { - return nil, err - } - - // Parse Scheme - var wsScheme string - switch serverURL.Scheme { - case "https": - wsScheme = "wss" - case "http": - wsScheme = "ws" - default: - return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme) - } - - // Create Tunnel Name - if cfg.TunnelName == "" { - cfg.TunnelName = generateTunnelName() - log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName) - } - - // Connect Server WS - wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey) - serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to connect: %v", err) - } - - return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverURL, serverConn) -} diff --git a/cmd/serve.go b/cmd/serve.go index c63a4ff..b1b4ce6 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -1,6 +1,8 @@ package cmd import ( + "context" + "reichard.io/conduit/config" "reichard.io/conduit/server" @@ -19,8 +21,11 @@ var serveCmd = &cobra.Command{ log.Fatal("failed to get server config:", err) } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Create Server - srv, err := server.NewServer(cfg) + srv, err := server.NewServer(ctx, cfg) if err != nil { log.Fatal("failed to create server:", err) } diff --git a/cmd/tunnel.go b/cmd/tunnel.go index 1643366..a9f48ab 100644 --- a/cmd/tunnel.go +++ b/cmd/tunnel.go @@ -1,10 +1,13 @@ package cmd import ( + "context" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "reichard.io/conduit/client" "reichard.io/conduit/config" + "reichard.io/conduit/store" + "reichard.io/conduit/tunnel" ) var tunnelCmd = &cobra.Command{ @@ -17,12 +20,22 @@ var tunnelCmd = &cobra.Command{ log.Fatal("failed to get client config:", err) } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create Forwarder + tunnelForwarder, err := tunnel.NewForwarder(cfg.TunnelTarget, store.NewTunnelStore(100)) + if err != nil { + log.Fatal("failed to create tunnel forwarder:", err) + } + go tunnelForwarder.Start(ctx) + // Create Tunnel - tunnel, err := client.NewTunnel(cfg) + tunnel, err := tunnel.NewClientTunnel(cfg, tunnelForwarder) if err != nil { log.Fatal("failed to create tunnel:", err) } - tunnel.Start() + tunnel.Start(ctx) }, } diff --git a/config/config.go b/config/config.go index 8d9aa28..7c969d4 100644 --- a/config/config.go +++ b/config/config.go @@ -23,6 +23,8 @@ type ConfigDef struct { type BaseConfig struct { ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"` APIKey string `json:"api_key" description:"API Key for the conduit API"` + LogLevel string `json:"log_level" default:"info" description:"Log level"` + LogFormat string `json:"log_format" default:"text" description:"Log format - text or json"` } func (c *BaseConfig) Validate() error { @@ -35,6 +37,9 @@ func (c *BaseConfig) Validate() error { if _, err := url.Parse(c.ServerAddress); err != nil { return fmt.Errorf("server is invalid: %w", err) } + if c.LogFormat != "text" && c.LogFormat != "json" { + return fmt.Errorf("log format must be 'text' or 'json'") + } return nil } @@ -68,13 +73,13 @@ func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) { } cfg := &ServerConfig{ - BaseConfig: BaseConfig{ - ServerAddress: cfgValues["server"], - APIKey: cfgValues["api_key"], - }, + BaseConfig: getBaseConfig(cfgValues), BindAddress: cfgValues["bind"], } + // Initialize Logger + initLogger(cfg.BaseConfig) + return cfg, cfg.Validate() } @@ -87,14 +92,14 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) { } cfg := &ClientConfig{ - BaseConfig: BaseConfig{ - ServerAddress: cfgValues["server"], - APIKey: cfgValues["api_key"], - }, + BaseConfig: getBaseConfig(cfgValues), TunnelName: cfgValues["name"], TunnelTarget: cfgValues["target"], } + // Initialize Logger + initLogger(cfg.BaseConfig) + return cfg, cfg.Validate() } @@ -108,6 +113,15 @@ func GetVersion() string { return version } +func getBaseConfig(cfgValues map[string]string) BaseConfig { + return BaseConfig{ + ServerAddress: cfgValues["server"], + APIKey: cfgValues["api_key"], + LogLevel: cfgValues["log_level"], + LogFormat: cfgValues["log_format"], + } +} + func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string { // 1. Get Flags First if cmdFlags != nil { diff --git a/config/logging.go b/config/logging.go new file mode 100644 index 0000000..018b079 --- /dev/null +++ b/config/logging.go @@ -0,0 +1,61 @@ +package config + +import ( + "fmt" + "runtime" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +func initLogger(cfg BaseConfig) { + // Parse Log Level + logLevel, err := log.ParseLevel(cfg.LogLevel) + if err != nil { + logLevel = log.InfoLevel + } + log.SetLevel(logLevel) + + // Create Log Formatter + var logFormatter log.Formatter + switch cfg.LogFormat { + case "json": + log.SetReportCaller(true) + logFormatter = &log.JSONFormatter{ + TimestampFormat: time.RFC3339, + CallerPrettyfier: prettyCaller, + } + case "text": + logFormatter = &log.TextFormatter{ + TimestampFormat: time.RFC3339, + FullTimestamp: true, + } + } + + log.SetFormatter(&utcFormatter{logFormatter}) +} + +func prettyCaller(f *runtime.Frame) (function string, file string) { + purgePrefix := "reichard.io/conduit/" + + pathName := strings.Replace(f.Func.Name(), purgePrefix, "", 1) + parts := strings.Split(pathName, ".") + + filepath, line := f.Func.FileLine(f.PC) + splitFilePath := strings.Split(filepath, "/") + + fileName := fmt.Sprintf("%s/%s@%d", parts[0], splitFilePath[len(splitFilePath)-1], line) + functionName := strings.Replace(pathName, parts[0]+".", "", 1) + + return functionName, fileName +} + +type utcFormatter struct { + log.Formatter +} + +func (cf utcFormatter) Format(e *log.Entry) ([]byte, error) { + e.Time = e.Time.UTC() + return cf.Formatter.Format(e) +} diff --git a/go.mod b/go.mod index 29ae61b..ef9f828 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module reichard.io/conduit go 1.24.4 require ( + github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect diff --git a/go.sum b/go.sum index 7cb5bd1..f72deec 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/server/server.go b/server/server.go index 5d5ec46..964c925 100644 --- a/server/server.go +++ b/server/server.go @@ -3,6 +3,7 @@ package server import ( "bufio" "bytes" + "context" "encoding/json" "errors" "fmt" @@ -31,6 +32,7 @@ type TunnelInfo struct { } type Server struct { + ctx context.Context host string cfg *config.ServerConfig @@ -38,7 +40,7 @@ type Server struct { tunnels *maps.Map[string, *tunnel.Tunnel] } -func NewServer(cfg *config.ServerConfig) (*Server, error) { +func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) { serverURL, err := url.Parse(cfg.ServerAddress) if err != nil { return nil, fmt.Errorf("failed to parse server address: %v", err) @@ -47,6 +49,7 @@ func NewServer(cfg *config.ServerConfig) (*Server, error) { } return &Server{ + ctx: ctx, cfg: cfg, host: serverURL.Host, tunnels: maps.New[string, *tunnel.Tunnel](), @@ -163,17 +166,21 @@ func (s *Server) handleRawConnection(conn net.Conn) { return } - // Add & Start Stream + // Create Stream reconstructedConn := newReconstructedConn(conn, &capturedData) streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) - if err := conduitTunnel.AddStream(streamID, reconstructedConn); err != nil { + tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr) + + // Add Stream + if err := conduitTunnel.AddStream(tunnelStream, streamID); err != nil { w.WriteHeader(http.StatusInternalServerError) _, _ = fmt.Fprintf(w, "failed to add stream: %v", err) + log.WithError(err).Error("failed to add stream") return } - log.Infof("tunnel %q connection from %s", tunnelName, r.RemoteAddr) - _ = conduitTunnel.StartStream(streamID, r.RemoteAddr) + // Start Stream + conduitTunnel.StartStream(tunnelStream, streamID) } func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { @@ -222,13 +229,11 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { // Create Tunnel conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn) s.tunnels.Set(tunnelName, conduitTunnel) - log.Infof("tunnel %q created from %s", tunnelName, r.RemoteAddr) // Start Tunnel - This is blocking - conduitTunnel.Start() + conduitTunnel.Start(s.ctx) // Cleanup Tunnel s.tunnels.Delete(tunnelName) _ = wsConn.Close() - log.Infof("tunnel %q closed from %s", tunnelName, r.RemoteAddr) } diff --git a/store/context.go b/store/context.go new file mode 100644 index 0000000..8cca97a --- /dev/null +++ b/store/context.go @@ -0,0 +1,18 @@ +package store + +import ( + "context" +) + +type contextKey struct{} + +var recordIDKey = contextKey{} + +func withRecord(ctx context.Context, rec *TunnelRecord) context.Context { + return context.WithValue(ctx, recordIDKey, rec) +} + +func getRecord(ctx context.Context) (*TunnelRecord, bool) { + id, ok := ctx.Value(recordIDKey).(*TunnelRecord) + return id, ok +} diff --git a/store/store.go b/store/store.go new file mode 100644 index 0000000..4fbb840 --- /dev/null +++ b/store/store.go @@ -0,0 +1,196 @@ +package store + +import ( + "bytes" + "errors" + "io" + "mime" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +const ( + defaultQueueSize = 100 + maxQueueSize = 100 +) + +var ErrRecordNotFound = errors.New("record not found") + +type TunnelStore interface { + Get(before time.Time, count int) (results []*TunnelRecord, more bool) + RecordTCP() + RecordRequest(req *http.Request) + RecordResponse(resp *http.Response) error +} + +type TunnelRecord struct { + ID uuid.UUID + Time time.Time + URL *url.URL + Method string + Status int + + RequestHeaders http.Header + RequestBodyType string + RequestBody []byte + + ResponseHeaders http.Header + ResponseBodyType string + ResponseBody []byte +} + +func NewTunnelStore(queueSize int) TunnelStore { + if queueSize <= 0 { + queueSize = defaultQueueSize + } else if queueSize > maxQueueSize { + queueSize = maxQueueSize + } + + return &tunnelStoreImpl{queueSize: queueSize} +} + +type tunnelStoreImpl struct { + orderedRecords []*TunnelRecord + queueSize int + mu sync.Mutex +} + +func (s *tunnelStoreImpl) Get(before time.Time, count int) ([]*TunnelRecord, bool) { + // Find First + start := -1 + for i, r := range s.orderedRecords { + if r.Time.Before(before) { + start = i + break + } + } + + // Not Found + if start == -1 { + return nil, false + } + + // Subslice Records + end := min(start+count, len(s.orderedRecords)) + results := s.orderedRecords[start:end] + more := end < len(s.orderedRecords) + + return results, more +} + +func (s *tunnelStoreImpl) RecordRequest(req *http.Request) { + s.mu.Lock() + defer s.mu.Unlock() + + url := *req.URL + rec := &TunnelRecord{ + ID: uuid.New(), + Time: time.Now(), + URL: &url, + Method: req.Method, + RequestHeaders: req.Header, + RequestBodyType: req.Header.Get("Content-Type"), + } + + if bodyData, err := getRequestBody(req); err == nil { + rec.RequestBody = bodyData + } + + // Add Record & Truncate + s.orderedRecords = append(s.orderedRecords, rec) + if len(s.orderedRecords) > s.queueSize { + s.orderedRecords = s.orderedRecords[len(s.orderedRecords)-s.queueSize:] + } + + *req = *req.WithContext(withRecord(req.Context(), rec)) +} + +func (s *tunnelStoreImpl) RecordResponse(resp *http.Response) error { + rec, found := getRecord(resp.Request.Context()) + if !found { + return ErrRecordNotFound + } + + rec.ResponseHeaders = resp.Header + rec.ResponseBodyType = resp.Header.Get("Content-Type") + + if bodyData, err := getResponseBody(resp); err == nil { + rec.ResponseBody = bodyData + } + + return nil +} + +func (s *tunnelStoreImpl) RecordTCP() { + s.mu.Lock() + defer s.mu.Unlock() + + // TODO +} + +func getRequestBody(req *http.Request) ([]byte, error) { + if req.ContentLength == 0 || req.Body == nil || req.Body == http.NoBody { + return nil, nil + } + + if !isTextContentType(req.Header.Get("Content-Type")) { + return nil, nil + } + + // Read Body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + // Restore Body + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + return bodyBytes, nil +} + +func getResponseBody(resp *http.Response) ([]byte, error) { + if resp.ContentLength == 0 || resp.Body == nil || resp.Body == http.NoBody { + return nil, nil + } + + if !isTextContentType(resp.Header.Get("Content-Type")) { + return nil, nil + } + + // Read Body + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Restore Body + resp.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + return bodyBytes, nil +} + +func isTextContentType(contentType string) bool { + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return false + } + + if strings.HasPrefix(mediaType, "text/") { + return true + } + + switch mediaType { + case "application/json": + return true + case "application/xml": + return true + case "application/x-www-form-urlencoded": + return true + default: + return false + } +} diff --git a/tunnel/forwarder.go b/tunnel/forwarder.go new file mode 100644 index 0000000..92fa763 --- /dev/null +++ b/tunnel/forwarder.go @@ -0,0 +1,43 @@ +package tunnel + +import ( + "context" + "net/url" + + "reichard.io/conduit/store" +) + +type ForwarderType int + +const ( + ForwarderTCP ForwarderType = iota + ForwarderHTTP +) + +type Forwarder interface { + Type() ForwarderType + Initialize() (Stream, error) + Start(context.Context) error +} + +func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) { + // Get Target URL + targetURL, err := url.Parse(target) + if err != nil { + return nil, err + } + + // Get Connection Builder + var forwarder Forwarder + switch targetURL.Scheme { + case "http", "https": + forwarder, err = newHTTPForwarder(targetURL, tunnelStore) + if err != nil { + return nil, err + } + default: + forwarder = newTCPForwarder(target, tunnelStore) + } + + return forwarder, nil +} diff --git a/tunnel/http.go b/tunnel/http.go deleted file mode 100644 index dfb90e3..0000000 --- a/tunnel/http.go +++ /dev/null @@ -1,104 +0,0 @@ -package tunnel - -import ( - "fmt" - "io" - "net" - "net/http" - "net/http/httputil" - "net/url" - "sync" -) - -func HTTPConnectionBuilder(targetURL *url.URL) (ConnBuilder, error) { - multiConnListener := newMultiConnListener() - - // Create Reverse Proxy - proxy := &httputil.ReverseProxy{ - Director: func(req *http.Request) { - req.Host = targetURL.Host - req.URL.Host = targetURL.Host - req.URL.Scheme = targetURL.Scheme - }, - ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { - http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway) - }, - } - - // Start HTTP Proxy - go func() { - defer multiConnListener.Close() - _ = http.Serve(multiConnListener, proxy) - }() - - // Return Connection Builder - return func() (conn io.ReadWriteCloser, err error) { - clientConn, serverConn := net.Pipe() - - if err := multiConnListener.addConn(serverConn); err != nil { - _ = clientConn.Close() - _ = serverConn.Close() - return nil, err - } - - return clientConn, nil - }, nil -} - -type multiConnListener struct { - connCh chan net.Conn - closed chan struct{} - once sync.Once -} - -func newMultiConnListener() *multiConnListener { - return &multiConnListener{ - connCh: make(chan net.Conn, 100), - closed: make(chan struct{}), - } -} - -func (l *multiConnListener) Accept() (net.Conn, error) { - select { - case conn := <-l.connCh: - if conn == nil { - return nil, fmt.Errorf("listener closed") - } - return conn, nil - case <-l.closed: - return nil, fmt.Errorf("listener closed") - } -} - -func (l *multiConnListener) Close() error { - l.once.Do(func() { - close(l.closed) - // Drain any remaining connections - go func() { - for conn := range l.connCh { - if conn != nil { - conn.Close() - } - } - }() - close(l.connCh) - }) - return nil -} - -func (l *multiConnListener) Addr() net.Addr { - return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} -} - -func (l *multiConnListener) addConn(conn net.Conn) error { - select { - case l.connCh <- conn: - return nil - case <-l.closed: - conn.Close() - return fmt.Errorf("listener is closed") - default: - conn.Close() - return fmt.Errorf("connection queue full") - } -} diff --git a/tunnel/http_forwarder.go b/tunnel/http_forwarder.go new file mode 100644 index 0000000..f87fc43 --- /dev/null +++ b/tunnel/http_forwarder.go @@ -0,0 +1,132 @@ +package tunnel + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "sync" + + "reichard.io/conduit/store" +) + +func newHTTPForwarder(targetURL *url.URL, tunnelStore store.TunnelStore) (Forwarder, error) { + return &httpConnBuilder{ + multiConnListener: newMultiConnListener(), + tunnelStore: tunnelStore, + targetURL: targetURL, + }, nil +} + +type httpConnBuilder struct { + multiConnListener *multiConnListener + tunnelStore store.TunnelStore + targetURL *url.URL +} + +func (c *httpConnBuilder) Type() ForwarderType { + return ForwarderHTTP +} + +func (c *httpConnBuilder) Start(ctx context.Context) error { + // Create Reverse Proxy Server + server := &http.Server{ + Handler: &httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.Host = c.targetURL.Host + req.URL.Host = c.targetURL.Host + req.URL.Scheme = c.targetURL.Scheme + c.tunnelStore.RecordRequest(req) + }, + ModifyResponse: c.tunnelStore.RecordResponse, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway) + }, + }, + } + + // Context & Cleanup + go func() { + <-ctx.Done() + server.Shutdown(ctx) + c.multiConnListener.Close() + }() + + // Start HTTP Proxy + if err := server.Serve(c.multiConnListener); err != nil && err != http.ErrServerClosed { + return err + } + return nil +} + +func (c *httpConnBuilder) Initialize() (Stream, error) { + clientConn, serverConn := net.Pipe() + + if err := c.multiConnListener.addConn(serverConn); err != nil { + _ = clientConn.Close() + _ = serverConn.Close() + return nil, err + } + + return &streamImpl{clientConn, c.targetURL.String()}, nil +} + +type multiConnListener struct { + connCh chan net.Conn + closed chan struct{} + once sync.Once +} + +func newMultiConnListener() *multiConnListener { + return &multiConnListener{ + connCh: make(chan net.Conn, 100), + closed: make(chan struct{}), + } +} + +func (l *multiConnListener) Accept() (net.Conn, error) { + select { + case conn := <-l.connCh: + if conn == nil { + return nil, fmt.Errorf("listener closed") + } + return conn, nil + case <-l.closed: + return nil, fmt.Errorf("listener closed") + } +} + +func (l *multiConnListener) Close() error { + l.once.Do(func() { + close(l.closed) + // Drain any remaining connections + go func() { + for conn := range l.connCh { + if conn != nil { + conn.Close() + } + } + }() + close(l.connCh) + }) + return nil +} + +func (l *multiConnListener) Addr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (l *multiConnListener) addConn(conn net.Conn) error { + select { + case l.connCh <- conn: + return nil + case <-l.closed: + conn.Close() + return fmt.Errorf("listener is closed") + default: + conn.Close() + return fmt.Errorf("connection queue full") + } +} diff --git a/client/name.go b/tunnel/name.go similarity index 97% rename from client/name.go rename to tunnel/name.go index 97aba3e..2f24d57 100644 --- a/client/name.go +++ b/tunnel/name.go @@ -1,4 +1,4 @@ -package client +package tunnel import ( "fmt" diff --git a/tunnel/stream.go b/tunnel/stream.go new file mode 100644 index 0000000..affb4c2 --- /dev/null +++ b/tunnel/stream.go @@ -0,0 +1,26 @@ +package tunnel + +import ( + "io" + "net" +) + +var _ Stream = (*streamImpl)(nil) + +type Stream interface { + io.ReadWriteCloser + Source() string +} + +func NewStream(conn net.Conn, source string) Stream { + return &streamImpl{conn, source} +} + +type streamImpl struct { + net.Conn + source string +} + +func (s *streamImpl) Source() string { + return s.source +} diff --git a/tunnel/tcp_forwarder.go b/tunnel/tcp_forwarder.go new file mode 100644 index 0000000..32153ab --- /dev/null +++ b/tunnel/tcp_forwarder.go @@ -0,0 +1,37 @@ +package tunnel + +import ( + "context" + "net" + + "reichard.io/conduit/store" +) + +func newTCPForwarder(target string, tunnelStore store.TunnelStore) Forwarder { + return &tcpConnBuilder{ + target: target, + tunnelStore: tunnelStore, + } +} + +type tcpConnBuilder struct { + target string + tunnelStore store.TunnelStore +} + +func (l *tcpConnBuilder) Type() ForwarderType { + return ForwarderTCP +} + +func (l *tcpConnBuilder) Initialize() (Stream, error) { + conn, err := net.Dial("tcp", l.target) + if err != nil { + return nil, err + } + + return &streamImpl{conn, l.target}, nil +} + +func (l *tcpConnBuilder) Start(ctx context.Context) error { + return nil +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index cf54ab8..6fe48e3 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -1,76 +1,88 @@ package tunnel import ( + "context" "fmt" - "io" - "net" "net/url" "sync" "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" + "reichard.io/conduit/config" "reichard.io/conduit/pkg/maps" "reichard.io/conduit/types" ) -type ConnBuilder func() (conn io.ReadWriteCloser, err error) - +// NewServerTunnel creates a new tunnel with name and websocket connection. The tunnel is +// generally instantiated after an upgrade request from the server. func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel { return &Tunnel{ name: name, - streams: maps.New[string, io.ReadWriteCloser](), + streams: maps.New[string, Stream](), wsConn: wsConn, } } -func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) { - // Get Target URL - targetURL, err := url.Parse(target) +// NewClientTunnel creates a new tunnel with the provided configuration and forwarder. A +// forwarder is effectively the protocol being forwarded. For example HTTP (Proxy), and TCP. +func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, error) { + // Parse Server URL + serverURL, err := url.Parse(cfg.ServerAddress) if err != nil { return nil, err } - // Derive Conduit URL - conduitURL := *serverURL - conduitURL.Host = name + "." + conduitURL.Host - - // Get Connection Builder - var connBuilder ConnBuilder - switch targetURL.Scheme { - case "http", "https": - log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target) - connBuilder, err = HTTPConnectionBuilder(targetURL) - if err != nil { - return nil, err - } + // Parse Scheme + var wsScheme string + switch serverURL.Scheme { + case "https": + wsScheme = "wss" + case "http": + wsScheme = "ws" default: - log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target) - connBuilder = func() (conn io.ReadWriteCloser, err error) { - return net.Dial("tcp", target) - } + return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme) + } + + // Create Tunnel Name + if cfg.TunnelName == "" { + cfg.TunnelName = generateTunnelName() + log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName) + } + + // Connect Server WS + wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey) + serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to connect: %v", err) } return &Tunnel{ - name: name, - wsConn: wsConn, - streams: maps.New[string, io.ReadWriteCloser](), - connBuilder: connBuilder, + name: cfg.TunnelName, + wsConn: serverConn, + streams: maps.New[string, Stream](), + forwarder: forwarder, }, nil } type Tunnel struct { - name string - wsConn *websocket.Conn - streams *maps.Map[string, io.ReadWriteCloser] - connBuilder ConnBuilder + ctx context.Context + name string + wsConn *websocket.Conn + streams *maps.Map[string, Stream] + forwarder Forwarder mu sync.Mutex } -func (t *Tunnel) Start() { +func (t *Tunnel) Start(ctx context.Context) { + log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String()) + defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String()) + + t.ctx = ctx + + // Start Message Receiver for { - var msg types.Message - err := t.wsConn.ReadJSON(&msg) + msg, err := t.readWSWithContext(ctx) if err != nil { return } @@ -81,105 +93,57 @@ func (t *Tunnel) Start() { continue } - // Ensure Stream - if err := t.initStreamConnection(msg.StreamID); err != nil { - log.WithError(err).Errorf("failed to initialize stream %s connection", t.name) + // Get Stream + stream, err := t.getStream(msg.StreamID) + if err != nil { + if msg.Type != types.MessageTypeClose { + log.WithError(err).Errorf("failed to get stream %s", msg.StreamID) + } continue } // Handle Messages switch msg.Type { case types.MessageTypeClose: - _ = t.CloseStream(msg.StreamID) + _ = t.closeStream(stream, msg.StreamID) case types.MessageTypeData: - _ = t.WriteStream(msg.StreamID, msg.Data) + _, err = stream.Write(msg.Data) + } + + // Log Error + if err != nil { + log.WithError(err).Errorf("failed to handle message %s", msg.StreamID) } } } -func (t *Tunnel) initStreamConnection(streamID string) error { - if t.connBuilder == nil { - return nil +func (t *Tunnel) readWSWithContext(ctx context.Context) (*types.Message, error) { + type result struct { + msg *types.Message + err error } - if _, found := t.streams.Get(streamID); found { - return nil - } + resultChan := make(chan result, 1) + go func() { + var msg types.Message + err := t.wsConn.ReadJSON(&msg) + resultChan <- result{&msg, err} + }() - conn, err := t.connBuilder() - if err != nil { - return err + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultChan: + return result.msg, result.err } - - if err := t.AddStream(streamID, conn); err != nil { - return err - } - - go t.StartStream(streamID, "") - return nil } -func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error { +func (t *Tunnel) AddStream(stream Stream, streamID string) error { if t.streams.HasKey(streamID) { return fmt.Errorf("stream %s already exists", streamID) } - t.streams.Set(streamID, conn) - return nil -} - -func (t *Tunnel) StartStream(streamID string, sourceAddr string) error { - // Get Stream - conn, found := t.streams.Get(streamID) - if !found { - return fmt.Errorf("stream %s does not exist", streamID) - } - - // Close Stream - defer func() { - _ = t.sendWS(&types.Message{ - Type: types.MessageTypeClose, - StreamID: streamID, - SourceAddr: sourceAddr, - }) - - t.CloseStream(streamID) - }() - - // Start Stream - buffer := make([]byte, 4096) - for { - n, err := conn.Read(buffer) - if err != nil { - return err - } - - if err := t.sendWS(&types.Message{ - Type: types.MessageTypeData, - StreamID: streamID, - Data: buffer[:n], - SourceAddr: sourceAddr, - }); err != nil { - return err - } - } -} - -func (t *Tunnel) WriteStream(streamID string, data []byte) error { - // Get Stream - conn, found := t.streams.Get(streamID) - if !found { - return fmt.Errorf("stream %s does not exist", streamID) - } - - _, err := conn.Write(data) - return err -} - -func (t *Tunnel) CloseStream(streamID string) error { - if conn, ok := t.streams.Get(streamID); ok { - t.streams.Delete(streamID) - return conn.Close() - } + log.Infof("tunnel %q initiated stream with %s", t.name, stream.Source()) + t.streams.Set(streamID, stream) return nil } @@ -187,6 +151,78 @@ func (t *Tunnel) Source() string { return t.wsConn.RemoteAddr().String() } +func (t *Tunnel) StartStream(stream Stream, streamID string) error { + // Close Stream + defer t.closeStream(stream, streamID) + + // Start Stream + for { + data, err := t.readStreamWithContext(t.ctx, stream) + if err != nil { + return err + } + + if err := t.sendWS(&types.Message{ + Type: types.MessageTypeData, + StreamID: streamID, + Data: data, + SourceAddr: stream.Source(), + }); err != nil { + return err + } + } +} + +func (t *Tunnel) closeStream(stream Stream, streamID string) error { + log.Infof("tunnel %q closed stream with %s", t.name, stream.Source()) + t.streams.Delete(streamID) + return stream.Close() +} + +func (t *Tunnel) getStream(streamID string) (Stream, error) { + // Check Existing Stream + if stream, found := t.streams.Get(streamID); found { + return stream, nil + } + + // Check Forwarder + if t.forwarder == nil { + return nil, fmt.Errorf("stream %s does not exist", streamID) + } + + // Initialize Forwarder & Add Stream + stream, err := t.forwarder.Initialize() + if err != nil { + return nil, err + } + if err := t.AddStream(stream, streamID); err != nil { + return nil, err + } + go t.StartStream(stream, streamID) + return stream, nil +} + +func (t *Tunnel) readStreamWithContext(ctx context.Context, stream Stream) ([]byte, error) { + type result struct { + data []byte + err error + } + + resultChan := make(chan result, 1) + go func() { + buffer := make([]byte, 4096) + n, err := stream.Read(buffer) + resultChan <- result{buffer[:n], err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultChan: + return result.data, result.err + } +} + func (t *Tunnel) sendWS(msg *types.Message) error { t.mu.Lock() defer t.mu.Unlock()