chore: tunnel recorder & slight refactor
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
20c1388cf4
commit
0722e5f032
@ -1,45 +0,0 @@
|
|||||||
package client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"reichard.io/conduit/config"
|
|
||||||
"reichard.io/conduit/tunnel"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) {
|
|
||||||
// Parse Server URL
|
|
||||||
serverURL, err := url.Parse(cfg.ServerAddress)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse Scheme
|
|
||||||
var wsScheme string
|
|
||||||
switch serverURL.Scheme {
|
|
||||||
case "https":
|
|
||||||
wsScheme = "wss"
|
|
||||||
case "http":
|
|
||||||
wsScheme = "ws"
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create Tunnel Name
|
|
||||||
if cfg.TunnelName == "" {
|
|
||||||
cfg.TunnelName = generateTunnelName()
|
|
||||||
log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect Server WS
|
|
||||||
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey)
|
|
||||||
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to connect: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverURL, serverConn)
|
|
||||||
}
|
|
@ -1,6 +1,8 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"reichard.io/conduit/config"
|
"reichard.io/conduit/config"
|
||||||
"reichard.io/conduit/server"
|
"reichard.io/conduit/server"
|
||||||
|
|
||||||
@ -19,8 +21,11 @@ var serveCmd = &cobra.Command{
|
|||||||
log.Fatal("failed to get server config:", err)
|
log.Fatal("failed to get server config:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// Create Server
|
// Create Server
|
||||||
srv, err := server.NewServer(cfg)
|
srv, err := server.NewServer(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to create server:", err)
|
log.Fatal("failed to create server:", err)
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"reichard.io/conduit/client"
|
|
||||||
"reichard.io/conduit/config"
|
"reichard.io/conduit/config"
|
||||||
|
"reichard.io/conduit/store"
|
||||||
|
"reichard.io/conduit/tunnel"
|
||||||
)
|
)
|
||||||
|
|
||||||
var tunnelCmd = &cobra.Command{
|
var tunnelCmd = &cobra.Command{
|
||||||
@ -17,12 +20,22 @@ var tunnelCmd = &cobra.Command{
|
|||||||
log.Fatal("failed to get client config:", err)
|
log.Fatal("failed to get client config:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create Forwarder
|
||||||
|
tunnelForwarder, err := tunnel.NewForwarder(cfg.TunnelTarget, store.NewTunnelStore(100))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("failed to create tunnel forwarder:", err)
|
||||||
|
}
|
||||||
|
go tunnelForwarder.Start(ctx)
|
||||||
|
|
||||||
// Create Tunnel
|
// Create Tunnel
|
||||||
tunnel, err := client.NewTunnel(cfg)
|
tunnel, err := tunnel.NewClientTunnel(cfg, tunnelForwarder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to create tunnel:", err)
|
log.Fatal("failed to create tunnel:", err)
|
||||||
}
|
}
|
||||||
tunnel.Start()
|
tunnel.Start(ctx)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,6 +23,8 @@ type ConfigDef struct {
|
|||||||
type BaseConfig struct {
|
type BaseConfig struct {
|
||||||
ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"`
|
ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"`
|
||||||
APIKey string `json:"api_key" description:"API Key for the conduit API"`
|
APIKey string `json:"api_key" description:"API Key for the conduit API"`
|
||||||
|
LogLevel string `json:"log_level" default:"info" description:"Log level"`
|
||||||
|
LogFormat string `json:"log_format" default:"text" description:"Log format - text or json"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *BaseConfig) Validate() error {
|
func (c *BaseConfig) Validate() error {
|
||||||
@ -35,6 +37,9 @@ func (c *BaseConfig) Validate() error {
|
|||||||
if _, err := url.Parse(c.ServerAddress); err != nil {
|
if _, err := url.Parse(c.ServerAddress); err != nil {
|
||||||
return fmt.Errorf("server is invalid: %w", err)
|
return fmt.Errorf("server is invalid: %w", err)
|
||||||
}
|
}
|
||||||
|
if c.LogFormat != "text" && c.LogFormat != "json" {
|
||||||
|
return fmt.Errorf("log format must be 'text' or 'json'")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,13 +73,13 @@ func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg := &ServerConfig{
|
cfg := &ServerConfig{
|
||||||
BaseConfig: BaseConfig{
|
BaseConfig: getBaseConfig(cfgValues),
|
||||||
ServerAddress: cfgValues["server"],
|
|
||||||
APIKey: cfgValues["api_key"],
|
|
||||||
},
|
|
||||||
BindAddress: cfgValues["bind"],
|
BindAddress: cfgValues["bind"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize Logger
|
||||||
|
initLogger(cfg.BaseConfig)
|
||||||
|
|
||||||
return cfg, cfg.Validate()
|
return cfg, cfg.Validate()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,14 +92,14 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg := &ClientConfig{
|
cfg := &ClientConfig{
|
||||||
BaseConfig: BaseConfig{
|
BaseConfig: getBaseConfig(cfgValues),
|
||||||
ServerAddress: cfgValues["server"],
|
|
||||||
APIKey: cfgValues["api_key"],
|
|
||||||
},
|
|
||||||
TunnelName: cfgValues["name"],
|
TunnelName: cfgValues["name"],
|
||||||
TunnelTarget: cfgValues["target"],
|
TunnelTarget: cfgValues["target"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize Logger
|
||||||
|
initLogger(cfg.BaseConfig)
|
||||||
|
|
||||||
return cfg, cfg.Validate()
|
return cfg, cfg.Validate()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,6 +113,15 @@ func GetVersion() string {
|
|||||||
return version
|
return version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getBaseConfig(cfgValues map[string]string) BaseConfig {
|
||||||
|
return BaseConfig{
|
||||||
|
ServerAddress: cfgValues["server"],
|
||||||
|
APIKey: cfgValues["api_key"],
|
||||||
|
LogLevel: cfgValues["log_level"],
|
||||||
|
LogFormat: cfgValues["log_format"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string {
|
func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string {
|
||||||
// 1. Get Flags First
|
// 1. Get Flags First
|
||||||
if cmdFlags != nil {
|
if cmdFlags != nil {
|
||||||
|
61
config/logging.go
Normal file
61
config/logging.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initLogger(cfg BaseConfig) {
|
||||||
|
// Parse Log Level
|
||||||
|
logLevel, err := log.ParseLevel(cfg.LogLevel)
|
||||||
|
if err != nil {
|
||||||
|
logLevel = log.InfoLevel
|
||||||
|
}
|
||||||
|
log.SetLevel(logLevel)
|
||||||
|
|
||||||
|
// Create Log Formatter
|
||||||
|
var logFormatter log.Formatter
|
||||||
|
switch cfg.LogFormat {
|
||||||
|
case "json":
|
||||||
|
log.SetReportCaller(true)
|
||||||
|
logFormatter = &log.JSONFormatter{
|
||||||
|
TimestampFormat: time.RFC3339,
|
||||||
|
CallerPrettyfier: prettyCaller,
|
||||||
|
}
|
||||||
|
case "text":
|
||||||
|
logFormatter = &log.TextFormatter{
|
||||||
|
TimestampFormat: time.RFC3339,
|
||||||
|
FullTimestamp: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetFormatter(&utcFormatter{logFormatter})
|
||||||
|
}
|
||||||
|
|
||||||
|
func prettyCaller(f *runtime.Frame) (function string, file string) {
|
||||||
|
purgePrefix := "reichard.io/conduit/"
|
||||||
|
|
||||||
|
pathName := strings.Replace(f.Func.Name(), purgePrefix, "", 1)
|
||||||
|
parts := strings.Split(pathName, ".")
|
||||||
|
|
||||||
|
filepath, line := f.Func.FileLine(f.PC)
|
||||||
|
splitFilePath := strings.Split(filepath, "/")
|
||||||
|
|
||||||
|
fileName := fmt.Sprintf("%s/%s@%d", parts[0], splitFilePath[len(splitFilePath)-1], line)
|
||||||
|
functionName := strings.Replace(pathName, parts[0]+".", "", 1)
|
||||||
|
|
||||||
|
return functionName, fileName
|
||||||
|
}
|
||||||
|
|
||||||
|
type utcFormatter struct {
|
||||||
|
log.Formatter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cf utcFormatter) Format(e *log.Entry) ([]byte, error) {
|
||||||
|
e.Time = e.Time.UTC()
|
||||||
|
return cf.Formatter.Format(e)
|
||||||
|
}
|
1
go.mod
1
go.mod
@ -3,6 +3,7 @@ module reichard.io/conduit
|
|||||||
go 1.24.4
|
go 1.24.4
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/gorilla/websocket v1.5.3 // indirect
|
github.com/gorilla/websocket v1.5.3 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -1,6 +1,8 @@
|
|||||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
|
@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -31,6 +32,7 @@ type TunnelInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
|
ctx context.Context
|
||||||
host string
|
host string
|
||||||
cfg *config.ServerConfig
|
cfg *config.ServerConfig
|
||||||
|
|
||||||
@ -38,7 +40,7 @@ type Server struct {
|
|||||||
tunnels *maps.Map[string, *tunnel.Tunnel]
|
tunnels *maps.Map[string, *tunnel.Tunnel]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) {
|
||||||
serverURL, err := url.Parse(cfg.ServerAddress)
|
serverURL, err := url.Parse(cfg.ServerAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse server address: %v", err)
|
return nil, fmt.Errorf("failed to parse server address: %v", err)
|
||||||
@ -47,6 +49,7 @@ func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Server{
|
return &Server{
|
||||||
|
ctx: ctx,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
host: serverURL.Host,
|
host: serverURL.Host,
|
||||||
tunnels: maps.New[string, *tunnel.Tunnel](),
|
tunnels: maps.New[string, *tunnel.Tunnel](),
|
||||||
@ -163,17 +166,21 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add & Start Stream
|
// Create Stream
|
||||||
reconstructedConn := newReconstructedConn(conn, &capturedData)
|
reconstructedConn := newReconstructedConn(conn, &capturedData)
|
||||||
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
||||||
if err := conduitTunnel.AddStream(streamID, reconstructedConn); err != nil {
|
tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr)
|
||||||
|
|
||||||
|
// Add Stream
|
||||||
|
if err := conduitTunnel.AddStream(tunnelStream, streamID); err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
_, _ = fmt.Fprintf(w, "failed to add stream: %v", err)
|
_, _ = fmt.Fprintf(w, "failed to add stream: %v", err)
|
||||||
|
log.WithError(err).Error("failed to add stream")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("tunnel %q connection from %s", tunnelName, r.RemoteAddr)
|
// Start Stream
|
||||||
_ = conduitTunnel.StartStream(streamID, r.RemoteAddr)
|
conduitTunnel.StartStream(tunnelStream, streamID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -222,13 +229,11 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Create Tunnel
|
// Create Tunnel
|
||||||
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
|
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
|
||||||
s.tunnels.Set(tunnelName, conduitTunnel)
|
s.tunnels.Set(tunnelName, conduitTunnel)
|
||||||
log.Infof("tunnel %q created from %s", tunnelName, r.RemoteAddr)
|
|
||||||
|
|
||||||
// Start Tunnel - This is blocking
|
// Start Tunnel - This is blocking
|
||||||
conduitTunnel.Start()
|
conduitTunnel.Start(s.ctx)
|
||||||
|
|
||||||
// Cleanup Tunnel
|
// Cleanup Tunnel
|
||||||
s.tunnels.Delete(tunnelName)
|
s.tunnels.Delete(tunnelName)
|
||||||
_ = wsConn.Close()
|
_ = wsConn.Close()
|
||||||
log.Infof("tunnel %q closed from %s", tunnelName, r.RemoteAddr)
|
|
||||||
}
|
}
|
||||||
|
18
store/context.go
Normal file
18
store/context.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type contextKey struct{}
|
||||||
|
|
||||||
|
var recordIDKey = contextKey{}
|
||||||
|
|
||||||
|
func withRecord(ctx context.Context, rec *TunnelRecord) context.Context {
|
||||||
|
return context.WithValue(ctx, recordIDKey, rec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRecord(ctx context.Context) (*TunnelRecord, bool) {
|
||||||
|
id, ok := ctx.Value(recordIDKey).(*TunnelRecord)
|
||||||
|
return id, ok
|
||||||
|
}
|
196
store/store.go
Normal file
196
store/store.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultQueueSize = 100
|
||||||
|
maxQueueSize = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrRecordNotFound = errors.New("record not found")
|
||||||
|
|
||||||
|
type TunnelStore interface {
|
||||||
|
Get(before time.Time, count int) (results []*TunnelRecord, more bool)
|
||||||
|
RecordTCP()
|
||||||
|
RecordRequest(req *http.Request)
|
||||||
|
RecordResponse(resp *http.Response) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelRecord struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
Time time.Time
|
||||||
|
URL *url.URL
|
||||||
|
Method string
|
||||||
|
Status int
|
||||||
|
|
||||||
|
RequestHeaders http.Header
|
||||||
|
RequestBodyType string
|
||||||
|
RequestBody []byte
|
||||||
|
|
||||||
|
ResponseHeaders http.Header
|
||||||
|
ResponseBodyType string
|
||||||
|
ResponseBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTunnelStore(queueSize int) TunnelStore {
|
||||||
|
if queueSize <= 0 {
|
||||||
|
queueSize = defaultQueueSize
|
||||||
|
} else if queueSize > maxQueueSize {
|
||||||
|
queueSize = maxQueueSize
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tunnelStoreImpl{queueSize: queueSize}
|
||||||
|
}
|
||||||
|
|
||||||
|
type tunnelStoreImpl struct {
|
||||||
|
orderedRecords []*TunnelRecord
|
||||||
|
queueSize int
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStoreImpl) Get(before time.Time, count int) ([]*TunnelRecord, bool) {
|
||||||
|
// Find First
|
||||||
|
start := -1
|
||||||
|
for i, r := range s.orderedRecords {
|
||||||
|
if r.Time.Before(before) {
|
||||||
|
start = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not Found
|
||||||
|
if start == -1 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subslice Records
|
||||||
|
end := min(start+count, len(s.orderedRecords))
|
||||||
|
results := s.orderedRecords[start:end]
|
||||||
|
more := end < len(s.orderedRecords)
|
||||||
|
|
||||||
|
return results, more
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStoreImpl) RecordRequest(req *http.Request) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
url := *req.URL
|
||||||
|
rec := &TunnelRecord{
|
||||||
|
ID: uuid.New(),
|
||||||
|
Time: time.Now(),
|
||||||
|
URL: &url,
|
||||||
|
Method: req.Method,
|
||||||
|
RequestHeaders: req.Header,
|
||||||
|
RequestBodyType: req.Header.Get("Content-Type"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if bodyData, err := getRequestBody(req); err == nil {
|
||||||
|
rec.RequestBody = bodyData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add Record & Truncate
|
||||||
|
s.orderedRecords = append(s.orderedRecords, rec)
|
||||||
|
if len(s.orderedRecords) > s.queueSize {
|
||||||
|
s.orderedRecords = s.orderedRecords[len(s.orderedRecords)-s.queueSize:]
|
||||||
|
}
|
||||||
|
|
||||||
|
*req = *req.WithContext(withRecord(req.Context(), rec))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStoreImpl) RecordResponse(resp *http.Response) error {
|
||||||
|
rec, found := getRecord(resp.Request.Context())
|
||||||
|
if !found {
|
||||||
|
return ErrRecordNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
rec.ResponseHeaders = resp.Header
|
||||||
|
rec.ResponseBodyType = resp.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
if bodyData, err := getResponseBody(resp); err == nil {
|
||||||
|
rec.ResponseBody = bodyData
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tunnelStoreImpl) RecordTCP() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRequestBody(req *http.Request) ([]byte, error) {
|
||||||
|
if req.ContentLength == 0 || req.Body == nil || req.Body == http.NoBody {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isTextContentType(req.Header.Get("Content-Type")) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read Body
|
||||||
|
bodyBytes, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore Body
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
return bodyBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getResponseBody(resp *http.Response) ([]byte, error) {
|
||||||
|
if resp.ContentLength == 0 || resp.Body == nil || resp.Body == http.NoBody {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isTextContentType(resp.Header.Get("Content-Type")) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read Body
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore Body
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
return bodyBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTextContentType(contentType string) bool {
|
||||||
|
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(mediaType, "text/") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch mediaType {
|
||||||
|
case "application/json":
|
||||||
|
return true
|
||||||
|
case "application/xml":
|
||||||
|
return true
|
||||||
|
case "application/x-www-form-urlencoded":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
43
tunnel/forwarder.go
Normal file
43
tunnel/forwarder.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"reichard.io/conduit/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ForwarderType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ForwarderTCP ForwarderType = iota
|
||||||
|
ForwarderHTTP
|
||||||
|
)
|
||||||
|
|
||||||
|
type Forwarder interface {
|
||||||
|
Type() ForwarderType
|
||||||
|
Initialize() (Stream, error)
|
||||||
|
Start(context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) {
|
||||||
|
// Get Target URL
|
||||||
|
targetURL, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Connection Builder
|
||||||
|
var forwarder Forwarder
|
||||||
|
switch targetURL.Scheme {
|
||||||
|
case "http", "https":
|
||||||
|
forwarder, err = newHTTPForwarder(targetURL, tunnelStore)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
forwarder = newTCPForwarder(target, tunnelStore)
|
||||||
|
}
|
||||||
|
|
||||||
|
return forwarder, nil
|
||||||
|
}
|
104
tunnel/http.go
104
tunnel/http.go
@ -1,104 +0,0 @@
|
|||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HTTPConnectionBuilder(targetURL *url.URL) (ConnBuilder, error) {
|
|
||||||
multiConnListener := newMultiConnListener()
|
|
||||||
|
|
||||||
// Create Reverse Proxy
|
|
||||||
proxy := &httputil.ReverseProxy{
|
|
||||||
Director: func(req *http.Request) {
|
|
||||||
req.Host = targetURL.Host
|
|
||||||
req.URL.Host = targetURL.Host
|
|
||||||
req.URL.Scheme = targetURL.Scheme
|
|
||||||
},
|
|
||||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start HTTP Proxy
|
|
||||||
go func() {
|
|
||||||
defer multiConnListener.Close()
|
|
||||||
_ = http.Serve(multiConnListener, proxy)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Return Connection Builder
|
|
||||||
return func() (conn io.ReadWriteCloser, err error) {
|
|
||||||
clientConn, serverConn := net.Pipe()
|
|
||||||
|
|
||||||
if err := multiConnListener.addConn(serverConn); err != nil {
|
|
||||||
_ = clientConn.Close()
|
|
||||||
_ = serverConn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return clientConn, nil
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type multiConnListener struct {
|
|
||||||
connCh chan net.Conn
|
|
||||||
closed chan struct{}
|
|
||||||
once sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMultiConnListener() *multiConnListener {
|
|
||||||
return &multiConnListener{
|
|
||||||
connCh: make(chan net.Conn, 100),
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *multiConnListener) Accept() (net.Conn, error) {
|
|
||||||
select {
|
|
||||||
case conn := <-l.connCh:
|
|
||||||
if conn == nil {
|
|
||||||
return nil, fmt.Errorf("listener closed")
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
case <-l.closed:
|
|
||||||
return nil, fmt.Errorf("listener closed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *multiConnListener) Close() error {
|
|
||||||
l.once.Do(func() {
|
|
||||||
close(l.closed)
|
|
||||||
// Drain any remaining connections
|
|
||||||
go func() {
|
|
||||||
for conn := range l.connCh {
|
|
||||||
if conn != nil {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
close(l.connCh)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *multiConnListener) Addr() net.Addr {
|
|
||||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *multiConnListener) addConn(conn net.Conn) error {
|
|
||||||
select {
|
|
||||||
case l.connCh <- conn:
|
|
||||||
return nil
|
|
||||||
case <-l.closed:
|
|
||||||
conn.Close()
|
|
||||||
return fmt.Errorf("listener is closed")
|
|
||||||
default:
|
|
||||||
conn.Close()
|
|
||||||
return fmt.Errorf("connection queue full")
|
|
||||||
}
|
|
||||||
}
|
|
132
tunnel/http_forwarder.go
Normal file
132
tunnel/http_forwarder.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"reichard.io/conduit/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newHTTPForwarder(targetURL *url.URL, tunnelStore store.TunnelStore) (Forwarder, error) {
|
||||||
|
return &httpConnBuilder{
|
||||||
|
multiConnListener: newMultiConnListener(),
|
||||||
|
tunnelStore: tunnelStore,
|
||||||
|
targetURL: targetURL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpConnBuilder struct {
|
||||||
|
multiConnListener *multiConnListener
|
||||||
|
tunnelStore store.TunnelStore
|
||||||
|
targetURL *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *httpConnBuilder) Type() ForwarderType {
|
||||||
|
return ForwarderHTTP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *httpConnBuilder) Start(ctx context.Context) error {
|
||||||
|
// Create Reverse Proxy Server
|
||||||
|
server := &http.Server{
|
||||||
|
Handler: &httputil.ReverseProxy{
|
||||||
|
Director: func(req *http.Request) {
|
||||||
|
req.Host = c.targetURL.Host
|
||||||
|
req.URL.Host = c.targetURL.Host
|
||||||
|
req.URL.Scheme = c.targetURL.Scheme
|
||||||
|
c.tunnelStore.RecordRequest(req)
|
||||||
|
},
|
||||||
|
ModifyResponse: c.tunnelStore.RecordResponse,
|
||||||
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context & Cleanup
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
server.Shutdown(ctx)
|
||||||
|
c.multiConnListener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start HTTP Proxy
|
||||||
|
if err := server.Serve(c.multiConnListener); err != nil && err != http.ErrServerClosed {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *httpConnBuilder) Initialize() (Stream, error) {
|
||||||
|
clientConn, serverConn := net.Pipe()
|
||||||
|
|
||||||
|
if err := c.multiConnListener.addConn(serverConn); err != nil {
|
||||||
|
_ = clientConn.Close()
|
||||||
|
_ = serverConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &streamImpl{clientConn, c.targetURL.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type multiConnListener struct {
|
||||||
|
connCh chan net.Conn
|
||||||
|
closed chan struct{}
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMultiConnListener() *multiConnListener {
|
||||||
|
return &multiConnListener{
|
||||||
|
connCh: make(chan net.Conn, 100),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *multiConnListener) Accept() (net.Conn, error) {
|
||||||
|
select {
|
||||||
|
case conn := <-l.connCh:
|
||||||
|
if conn == nil {
|
||||||
|
return nil, fmt.Errorf("listener closed")
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
case <-l.closed:
|
||||||
|
return nil, fmt.Errorf("listener closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *multiConnListener) Close() error {
|
||||||
|
l.once.Do(func() {
|
||||||
|
close(l.closed)
|
||||||
|
// Drain any remaining connections
|
||||||
|
go func() {
|
||||||
|
for conn := range l.connCh {
|
||||||
|
if conn != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
close(l.connCh)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *multiConnListener) Addr() net.Addr {
|
||||||
|
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *multiConnListener) addConn(conn net.Conn) error {
|
||||||
|
select {
|
||||||
|
case l.connCh <- conn:
|
||||||
|
return nil
|
||||||
|
case <-l.closed:
|
||||||
|
conn.Close()
|
||||||
|
return fmt.Errorf("listener is closed")
|
||||||
|
default:
|
||||||
|
conn.Close()
|
||||||
|
return fmt.Errorf("connection queue full")
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package client
|
package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
26
tunnel/stream.go
Normal file
26
tunnel/stream.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ Stream = (*streamImpl)(nil)
|
||||||
|
|
||||||
|
type Stream interface {
|
||||||
|
io.ReadWriteCloser
|
||||||
|
Source() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStream(conn net.Conn, source string) Stream {
|
||||||
|
return &streamImpl{conn, source}
|
||||||
|
}
|
||||||
|
|
||||||
|
type streamImpl struct {
|
||||||
|
net.Conn
|
||||||
|
source string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *streamImpl) Source() string {
|
||||||
|
return s.source
|
||||||
|
}
|
37
tunnel/tcp_forwarder.go
Normal file
37
tunnel/tcp_forwarder.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"reichard.io/conduit/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTCPForwarder(target string, tunnelStore store.TunnelStore) Forwarder {
|
||||||
|
return &tcpConnBuilder{
|
||||||
|
target: target,
|
||||||
|
tunnelStore: tunnelStore,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type tcpConnBuilder struct {
|
||||||
|
target string
|
||||||
|
tunnelStore store.TunnelStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *tcpConnBuilder) Type() ForwarderType {
|
||||||
|
return ForwarderTCP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *tcpConnBuilder) Initialize() (Stream, error) {
|
||||||
|
conn, err := net.Dial("tcp", l.target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &streamImpl{conn, l.target}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *tcpConnBuilder) Start(ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
262
tunnel/tunnel.go
262
tunnel/tunnel.go
@ -1,76 +1,88 @@
|
|||||||
package tunnel
|
package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"reichard.io/conduit/config"
|
||||||
"reichard.io/conduit/pkg/maps"
|
"reichard.io/conduit/pkg/maps"
|
||||||
"reichard.io/conduit/types"
|
"reichard.io/conduit/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnBuilder func() (conn io.ReadWriteCloser, err error)
|
// NewServerTunnel creates a new tunnel with name and websocket connection. The tunnel is
|
||||||
|
// generally instantiated after an upgrade request from the server.
|
||||||
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
||||||
return &Tunnel{
|
return &Tunnel{
|
||||||
name: name,
|
name: name,
|
||||||
streams: maps.New[string, io.ReadWriteCloser](),
|
streams: maps.New[string, Stream](),
|
||||||
wsConn: wsConn,
|
wsConn: wsConn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) {
|
// NewClientTunnel creates a new tunnel with the provided configuration and forwarder. A
|
||||||
// Get Target URL
|
// forwarder is effectively the protocol being forwarded. For example HTTP (Proxy), and TCP.
|
||||||
targetURL, err := url.Parse(target)
|
func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, error) {
|
||||||
|
// Parse Server URL
|
||||||
|
serverURL, err := url.Parse(cfg.ServerAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Derive Conduit URL
|
// Parse Scheme
|
||||||
conduitURL := *serverURL
|
var wsScheme string
|
||||||
conduitURL.Host = name + "." + conduitURL.Host
|
switch serverURL.Scheme {
|
||||||
|
case "https":
|
||||||
// Get Connection Builder
|
wsScheme = "wss"
|
||||||
var connBuilder ConnBuilder
|
case "http":
|
||||||
switch targetURL.Scheme {
|
wsScheme = "ws"
|
||||||
case "http", "https":
|
|
||||||
log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target)
|
|
||||||
connBuilder, err = HTTPConnectionBuilder(targetURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target)
|
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
|
||||||
connBuilder = func() (conn io.ReadWriteCloser, err error) {
|
|
||||||
return net.Dial("tcp", target)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create Tunnel Name
|
||||||
|
if cfg.TunnelName == "" {
|
||||||
|
cfg.TunnelName = generateTunnelName()
|
||||||
|
log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect Server WS
|
||||||
|
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey)
|
||||||
|
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Tunnel{
|
return &Tunnel{
|
||||||
name: name,
|
name: cfg.TunnelName,
|
||||||
wsConn: wsConn,
|
wsConn: serverConn,
|
||||||
streams: maps.New[string, io.ReadWriteCloser](),
|
streams: maps.New[string, Stream](),
|
||||||
connBuilder: connBuilder,
|
forwarder: forwarder,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tunnel struct {
|
type Tunnel struct {
|
||||||
|
ctx context.Context
|
||||||
name string
|
name string
|
||||||
wsConn *websocket.Conn
|
wsConn *websocket.Conn
|
||||||
streams *maps.Map[string, io.ReadWriteCloser]
|
streams *maps.Map[string, Stream]
|
||||||
connBuilder ConnBuilder
|
forwarder Forwarder
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) Start() {
|
func (t *Tunnel) Start(ctx context.Context) {
|
||||||
|
log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||||
|
defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||||
|
|
||||||
|
t.ctx = ctx
|
||||||
|
|
||||||
|
// Start Message Receiver
|
||||||
for {
|
for {
|
||||||
var msg types.Message
|
msg, err := t.readWSWithContext(ctx)
|
||||||
err := t.wsConn.ReadJSON(&msg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -81,105 +93,57 @@ func (t *Tunnel) Start() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure Stream
|
// Get Stream
|
||||||
if err := t.initStreamConnection(msg.StreamID); err != nil {
|
stream, err := t.getStream(msg.StreamID)
|
||||||
log.WithError(err).Errorf("failed to initialize stream %s connection", t.name)
|
if err != nil {
|
||||||
|
if msg.Type != types.MessageTypeClose {
|
||||||
|
log.WithError(err).Errorf("failed to get stream %s", msg.StreamID)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle Messages
|
// Handle Messages
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case types.MessageTypeClose:
|
case types.MessageTypeClose:
|
||||||
_ = t.CloseStream(msg.StreamID)
|
_ = t.closeStream(stream, msg.StreamID)
|
||||||
case types.MessageTypeData:
|
case types.MessageTypeData:
|
||||||
_ = t.WriteStream(msg.StreamID, msg.Data)
|
_, err = stream.Write(msg.Data)
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) initStreamConnection(streamID string) error {
|
// Log Error
|
||||||
if t.connBuilder == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, found := t.streams.Get(streamID); found {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := t.connBuilder()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
log.WithError(err).Errorf("failed to handle message %s", msg.StreamID)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := t.AddStream(streamID, conn); err != nil {
|
func (t *Tunnel) readWSWithContext(ctx context.Context) (*types.Message, error) {
|
||||||
return err
|
type result struct {
|
||||||
|
msg *types.Message
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
go t.StartStream(streamID, "")
|
resultChan := make(chan result, 1)
|
||||||
return nil
|
go func() {
|
||||||
|
var msg types.Message
|
||||||
|
err := t.wsConn.ReadJSON(&msg)
|
||||||
|
resultChan <- result{&msg, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case result := <-resultChan:
|
||||||
|
return result.msg, result.err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
|
func (t *Tunnel) AddStream(stream Stream, streamID string) error {
|
||||||
if t.streams.HasKey(streamID) {
|
if t.streams.HasKey(streamID) {
|
||||||
return fmt.Errorf("stream %s already exists", streamID)
|
return fmt.Errorf("stream %s already exists", streamID)
|
||||||
}
|
}
|
||||||
t.streams.Set(streamID, conn)
|
log.Infof("tunnel %q initiated stream with %s", t.name, stream.Source())
|
||||||
return nil
|
t.streams.Set(streamID, stream)
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
|
|
||||||
// Get Stream
|
|
||||||
conn, found := t.streams.Get(streamID)
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("stream %s does not exist", streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close Stream
|
|
||||||
defer func() {
|
|
||||||
_ = t.sendWS(&types.Message{
|
|
||||||
Type: types.MessageTypeClose,
|
|
||||||
StreamID: streamID,
|
|
||||||
SourceAddr: sourceAddr,
|
|
||||||
})
|
|
||||||
|
|
||||||
t.CloseStream(streamID)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Start Stream
|
|
||||||
buffer := make([]byte, 4096)
|
|
||||||
for {
|
|
||||||
n, err := conn.Read(buffer)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := t.sendWS(&types.Message{
|
|
||||||
Type: types.MessageTypeData,
|
|
||||||
StreamID: streamID,
|
|
||||||
Data: buffer[:n],
|
|
||||||
SourceAddr: sourceAddr,
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
|
||||||
// Get Stream
|
|
||||||
conn, found := t.streams.Get(streamID)
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("stream %s does not exist", streamID)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := conn.Write(data)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tunnel) CloseStream(streamID string) error {
|
|
||||||
if conn, ok := t.streams.Get(streamID); ok {
|
|
||||||
t.streams.Delete(streamID)
|
|
||||||
return conn.Close()
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -187,6 +151,78 @@ func (t *Tunnel) Source() string {
|
|||||||
return t.wsConn.RemoteAddr().String()
|
return t.wsConn.RemoteAddr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) StartStream(stream Stream, streamID string) error {
|
||||||
|
// Close Stream
|
||||||
|
defer t.closeStream(stream, streamID)
|
||||||
|
|
||||||
|
// Start Stream
|
||||||
|
for {
|
||||||
|
data, err := t.readStreamWithContext(t.ctx, stream)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := t.sendWS(&types.Message{
|
||||||
|
Type: types.MessageTypeData,
|
||||||
|
StreamID: streamID,
|
||||||
|
Data: data,
|
||||||
|
SourceAddr: stream.Source(),
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) closeStream(stream Stream, streamID string) error {
|
||||||
|
log.Infof("tunnel %q closed stream with %s", t.name, stream.Source())
|
||||||
|
t.streams.Delete(streamID)
|
||||||
|
return stream.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) getStream(streamID string) (Stream, error) {
|
||||||
|
// Check Existing Stream
|
||||||
|
if stream, found := t.streams.Get(streamID); found {
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Forwarder
|
||||||
|
if t.forwarder == nil {
|
||||||
|
return nil, fmt.Errorf("stream %s does not exist", streamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize Forwarder & Add Stream
|
||||||
|
stream, err := t.forwarder.Initialize()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := t.AddStream(stream, streamID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
go t.StartStream(stream, streamID)
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) readStreamWithContext(ctx context.Context, stream Stream) ([]byte, error) {
|
||||||
|
type result struct {
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
resultChan := make(chan result, 1)
|
||||||
|
go func() {
|
||||||
|
buffer := make([]byte, 4096)
|
||||||
|
n, err := stream.Read(buffer)
|
||||||
|
resultChan <- result{buffer[:n], err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case result := <-resultChan:
|
||||||
|
return result.data, result.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tunnel) sendWS(msg *types.Message) error {
|
func (t *Tunnel) sendWS(msg *types.Message) error {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user