Compare commits
No commits in common. "main" and "0.0.1" have entirely different histories.
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1 @@
|
||||
cover.html
|
||||
.DS_Store
|
||||
|
159
client/client.go
Normal file
159
client/client.go
Normal file
@ -0,0 +1,159 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/conduit/config"
|
||||
"reichard.io/conduit/types"
|
||||
)
|
||||
|
||||
func NewTunnel(cfg *config.ClientConfig) (*Tunnel, error) {
|
||||
// Parse Server URL
|
||||
serverURL, err := url.Parse(cfg.ServerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse Scheme
|
||||
var wsScheme string
|
||||
switch serverURL.Scheme {
|
||||
case "https":
|
||||
wsScheme = "wss"
|
||||
case "http":
|
||||
wsScheme = "ws"
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
|
||||
}
|
||||
|
||||
// Create Tunnel Name
|
||||
if cfg.TunnelName == "" {
|
||||
cfg.TunnelName = generateTunnelName()
|
||||
log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName)
|
||||
}
|
||||
|
||||
// Connect Server WS
|
||||
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey)
|
||||
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %v", err)
|
||||
}
|
||||
|
||||
return &Tunnel{
|
||||
name: cfg.TunnelName,
|
||||
target: cfg.TunnelTarget,
|
||||
serverURL: serverURL,
|
||||
serverConn: serverConn,
|
||||
localConns: make(map[string]net.Conn),
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
name string
|
||||
target string
|
||||
serverURL *url.URL
|
||||
|
||||
serverConn *websocket.Conn
|
||||
localConns map[string]net.Conn
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (t *Tunnel) Start() error {
|
||||
log.Infof("starting tunnel: %s.%s -> %s\n", t.name, t.serverURL.Hostname(), t.target)
|
||||
defer t.serverConn.Close()
|
||||
|
||||
// Handle Messages
|
||||
for {
|
||||
// Read Message
|
||||
var msg types.Message
|
||||
err := t.serverConn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
log.Errorf("error reading from tunnel: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case types.MessageTypeData:
|
||||
localConn, err := t.getLocalConn(msg.StreamID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get local connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Write data to local connection
|
||||
if _, err := localConn.Write(msg.Data); err != nil {
|
||||
log.Errorf("error writing to local connection: %v", err)
|
||||
localConn.Close()
|
||||
t.mu.Lock()
|
||||
delete(t.localConns, msg.StreamID)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
case types.MessageTypeClose:
|
||||
t.mu.Lock()
|
||||
if localConn, exists := t.localConns[msg.StreamID]; exists {
|
||||
localConn.Close()
|
||||
delete(t.localConns, msg.StreamID)
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) getLocalConn(streamID string) (net.Conn, error) {
|
||||
// Get Cached Connection
|
||||
t.mu.RLock()
|
||||
localConn, exists := t.localConns[streamID]
|
||||
t.mu.RUnlock()
|
||||
if exists {
|
||||
return localConn, nil
|
||||
}
|
||||
|
||||
// Initiate Connection & Cache
|
||||
localConn, err := net.Dial("tcp", t.target)
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to %s: %v", t.target, err)
|
||||
return nil, err
|
||||
}
|
||||
t.mu.Lock()
|
||||
t.localConns[streamID] = localConn
|
||||
t.mu.Unlock()
|
||||
|
||||
// Start Response Relay & Return Connection
|
||||
go t.startResponseRelay(streamID, localConn)
|
||||
return localConn, nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) startResponseRelay(streamID string, localConn net.Conn) {
|
||||
defer func() {
|
||||
t.mu.Lock()
|
||||
delete(t.localConns, streamID)
|
||||
t.mu.Unlock()
|
||||
localConn.Close()
|
||||
}()
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := localConn.Read(buffer)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
response := types.Message{
|
||||
Type: types.MessageTypeData,
|
||||
StreamID: streamID,
|
||||
Data: buffer[:n],
|
||||
}
|
||||
|
||||
if err := t.serverConn.WriteJSON(response); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package tunnel
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
@ -1,8 +1,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"reichard.io/conduit/config"
|
||||
"reichard.io/conduit/server"
|
||||
|
||||
@ -21,11 +19,8 @@ var serveCmd = &cobra.Command{
|
||||
log.Fatal("failed to get server config:", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create Server
|
||||
srv, err := server.NewServer(ctx, cfg)
|
||||
srv, err := server.NewServer(cfg)
|
||||
if err != nil {
|
||||
log.Fatal("failed to create server:", err)
|
||||
}
|
||||
|
@ -1,13 +1,10 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"reichard.io/conduit/client"
|
||||
"reichard.io/conduit/config"
|
||||
"reichard.io/conduit/store"
|
||||
"reichard.io/conduit/tunnel"
|
||||
)
|
||||
|
||||
var tunnelCmd = &cobra.Command{
|
||||
@ -20,22 +17,17 @@ var tunnelCmd = &cobra.Command{
|
||||
log.Fatal("failed to get client config:", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create Forwarder
|
||||
tunnelForwarder, err := tunnel.NewForwarder(cfg.TunnelTarget, store.NewTunnelStore(100))
|
||||
if err != nil {
|
||||
log.Fatal("failed to create tunnel forwarder:", err)
|
||||
}
|
||||
go tunnelForwarder.Start(ctx)
|
||||
|
||||
// Create Tunnel
|
||||
tunnel, err := tunnel.NewClientTunnel(cfg, tunnelForwarder)
|
||||
tunnel, err := client.NewTunnel(cfg)
|
||||
if err != nil {
|
||||
log.Fatal("failed to create tunnel:", err)
|
||||
}
|
||||
tunnel.Start(ctx)
|
||||
|
||||
// Start Tunnel
|
||||
log.Infof("creating TCP tunnel: %s -> %s", cfg.TunnelName, cfg.TunnelTarget)
|
||||
if err := tunnel.Start(); err != nil {
|
||||
log.Fatal("failed to start tunnel:", err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -23,8 +23,6 @@ type ConfigDef struct {
|
||||
type BaseConfig struct {
|
||||
ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"`
|
||||
APIKey string `json:"api_key" description:"API Key for the conduit API"`
|
||||
LogLevel string `json:"log_level" default:"info" description:"Log level"`
|
||||
LogFormat string `json:"log_format" default:"text" description:"Log format - text or json"`
|
||||
}
|
||||
|
||||
func (c *BaseConfig) Validate() error {
|
||||
@ -37,9 +35,6 @@ func (c *BaseConfig) Validate() error {
|
||||
if _, err := url.Parse(c.ServerAddress); err != nil {
|
||||
return fmt.Errorf("server is invalid: %w", err)
|
||||
}
|
||||
if c.LogFormat != "text" && c.LogFormat != "json" {
|
||||
return fmt.Errorf("log format must be 'text' or 'json'")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -73,13 +68,13 @@ func GetServerConfig(cmdFlags *pflag.FlagSet) (*ServerConfig, error) {
|
||||
}
|
||||
|
||||
cfg := &ServerConfig{
|
||||
BaseConfig: getBaseConfig(cfgValues),
|
||||
BaseConfig: BaseConfig{
|
||||
ServerAddress: cfgValues["server"],
|
||||
APIKey: cfgValues["api_key"],
|
||||
},
|
||||
BindAddress: cfgValues["bind"],
|
||||
}
|
||||
|
||||
// Initialize Logger
|
||||
initLogger(cfg.BaseConfig)
|
||||
|
||||
return cfg, cfg.Validate()
|
||||
}
|
||||
|
||||
@ -92,14 +87,14 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
|
||||
}
|
||||
|
||||
cfg := &ClientConfig{
|
||||
BaseConfig: getBaseConfig(cfgValues),
|
||||
BaseConfig: BaseConfig{
|
||||
ServerAddress: cfgValues["server"],
|
||||
APIKey: cfgValues["api_key"],
|
||||
},
|
||||
TunnelName: cfgValues["name"],
|
||||
TunnelTarget: cfgValues["target"],
|
||||
}
|
||||
|
||||
// Initialize Logger
|
||||
initLogger(cfg.BaseConfig)
|
||||
|
||||
return cfg, cfg.Validate()
|
||||
}
|
||||
|
||||
@ -113,19 +108,10 @@ func GetVersion() string {
|
||||
return version
|
||||
}
|
||||
|
||||
func getBaseConfig(cfgValues map[string]string) BaseConfig {
|
||||
return BaseConfig{
|
||||
ServerAddress: cfgValues["server"],
|
||||
APIKey: cfgValues["api_key"],
|
||||
LogLevel: cfgValues["log_level"],
|
||||
LogFormat: cfgValues["log_format"],
|
||||
}
|
||||
}
|
||||
|
||||
func getConfigValue(cmdFlags *pflag.FlagSet, def ConfigDef) string {
|
||||
// 1. Get Flags First
|
||||
if cmdFlags != nil {
|
||||
if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" && val != def.Default {
|
||||
if val, err := cmdFlags.GetString(def.Key); err == nil && val != "" {
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
@ -1,61 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func initLogger(cfg BaseConfig) {
|
||||
// Parse Log Level
|
||||
logLevel, err := log.ParseLevel(cfg.LogLevel)
|
||||
if err != nil {
|
||||
logLevel = log.InfoLevel
|
||||
}
|
||||
log.SetLevel(logLevel)
|
||||
|
||||
// Create Log Formatter
|
||||
var logFormatter log.Formatter
|
||||
switch cfg.LogFormat {
|
||||
case "json":
|
||||
log.SetReportCaller(true)
|
||||
logFormatter = &log.JSONFormatter{
|
||||
TimestampFormat: time.RFC3339,
|
||||
CallerPrettyfier: prettyCaller,
|
||||
}
|
||||
case "text":
|
||||
logFormatter = &log.TextFormatter{
|
||||
TimestampFormat: time.RFC3339,
|
||||
FullTimestamp: true,
|
||||
}
|
||||
}
|
||||
|
||||
log.SetFormatter(&utcFormatter{logFormatter})
|
||||
}
|
||||
|
||||
func prettyCaller(f *runtime.Frame) (function string, file string) {
|
||||
purgePrefix := "reichard.io/conduit/"
|
||||
|
||||
pathName := strings.Replace(f.Func.Name(), purgePrefix, "", 1)
|
||||
parts := strings.Split(pathName, ".")
|
||||
|
||||
filepath, line := f.Func.FileLine(f.PC)
|
||||
splitFilePath := strings.Split(filepath, "/")
|
||||
|
||||
fileName := fmt.Sprintf("%s/%s@%d", parts[0], splitFilePath[len(splitFilePath)-1], line)
|
||||
functionName := strings.Replace(pathName, parts[0]+".", "", 1)
|
||||
|
||||
return functionName, fileName
|
||||
}
|
||||
|
||||
type utcFormatter struct {
|
||||
log.Formatter
|
||||
}
|
||||
|
||||
func (cf utcFormatter) Format(e *log.Entry) ([]byte, error) {
|
||||
e.Time = e.Time.UTC()
|
||||
return cf.Formatter.Format(e)
|
||||
}
|
1
go.mod
1
go.mod
@ -3,7 +3,6 @@ module reichard.io/conduit
|
||||
go 1.24.4
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -1,8 +1,6 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
|
@ -1,51 +0,0 @@
|
||||
package maps
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Map[K comparable, V any] struct {
|
||||
items map[K]V
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func New[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{items: make(map[K]V)}
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
v, ok := m.items[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Set(key K, value V) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.items[key] = value
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Delete(key K) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.items, key)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) HasKey(key K) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, ok := m.items[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Entries() iter.Seq2[K, V] {
|
||||
return func(yield func(K, V) bool) {
|
||||
for k, v := range m.items {
|
||||
if !yield(k, v) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,30 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
var _ io.ReadWriteCloser = (*reconstructedConn)(nil)
|
||||
|
||||
// reconstructedConn wraps a net.Conn and overrides Read to handle captured data.
|
||||
type reconstructedConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
// Read reads from the reconstructed reader (captured data + original conn).
|
||||
func (rc *reconstructedConn) Read(p []byte) (n int, err error) {
|
||||
return rc.reader.Read(p)
|
||||
}
|
||||
|
||||
// newReconstructedConn creates a reconstructed connection that replays captured data
|
||||
// before reading from the original connection.
|
||||
func newReconstructedConn(conn net.Conn, capturedData *bytes.Buffer) net.Conn {
|
||||
allReader := io.MultiReader(capturedData, conn)
|
||||
return &reconstructedConn{
|
||||
Conn: conn,
|
||||
reader: allReader,
|
||||
}
|
||||
}
|
206
server/server.go
206
server/server.go
@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -12,13 +11,13 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/conduit/config"
|
||||
"reichard.io/conduit/pkg/maps"
|
||||
"reichard.io/conduit/tunnel"
|
||||
"reichard.io/conduit/types"
|
||||
)
|
||||
|
||||
type InfoResponse struct {
|
||||
@ -31,16 +30,22 @@ type TunnelInfo struct {
|
||||
Target string `json:"target"`
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
host string
|
||||
cfg *config.ServerConfig
|
||||
|
||||
upgrader websocket.Upgrader
|
||||
tunnels *maps.Map[string, *tunnel.Tunnel]
|
||||
type TunnelConnection struct {
|
||||
*websocket.Conn
|
||||
name string
|
||||
streams map[string]chan []byte
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) {
|
||||
type Server struct {
|
||||
host string
|
||||
cfg *config.ServerConfig
|
||||
mu sync.RWMutex
|
||||
|
||||
upgrader websocket.Upgrader
|
||||
tunnels map[string]*TunnelConnection
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
||||
serverURL, err := url.Parse(cfg.ServerAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server address: %v", err)
|
||||
@ -49,10 +54,9 @@ func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) {
|
||||
}
|
||||
|
||||
return &Server{
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
host: serverURL.Host,
|
||||
tunnels: maps.New[string, *tunnel.Tunnel](),
|
||||
tunnels: make(map[string]*TunnelConnection),
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
@ -75,7 +79,7 @@ func (s *Server) Start() error {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error accepting connection")
|
||||
log.Printf("error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -86,12 +90,14 @@ func (s *Server) Start() error {
|
||||
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
||||
// Get Tunnels
|
||||
var allTunnels []TunnelInfo
|
||||
for t, c := range s.tunnels.Entries() {
|
||||
s.mu.RLock()
|
||||
for t, c := range s.tunnels {
|
||||
allTunnels = append(allTunnels, TunnelInfo{
|
||||
Name: t,
|
||||
Target: c.Source(),
|
||||
Target: c.RemoteAddr().String(),
|
||||
})
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// Create Response
|
||||
d, err := json.MarshalIndent(InfoResponse{
|
||||
@ -99,17 +105,72 @@ func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
||||
Version: config.GetVersion(),
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
log.WithError(err).Error("failed to marshal info")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Send Response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(d)
|
||||
}
|
||||
|
||||
func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConnection, dataReader io.Reader) {
|
||||
defer clientConn.Close()
|
||||
|
||||
// Create Identifiers
|
||||
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
||||
responseChan := make(chan []byte, 100)
|
||||
|
||||
// Register Stream
|
||||
s.mu.Lock()
|
||||
if tunnelConn.streams == nil {
|
||||
tunnelConn.streams = make(map[string]chan []byte)
|
||||
}
|
||||
tunnelConn.streams[streamID] = responseChan
|
||||
s.mu.Unlock()
|
||||
|
||||
// Clean Up
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
delete(tunnelConn.streams, streamID)
|
||||
close(responseChan)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Send Close
|
||||
closeMsg := types.Message{
|
||||
Type: types.MessageTypeClose,
|
||||
StreamID: streamID,
|
||||
}
|
||||
_ = tunnelConn.WriteJSON(closeMsg)
|
||||
}()
|
||||
|
||||
// Read & Send Chunks
|
||||
go func() {
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := dataReader.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := tunnelConn.WriteJSON(types.Message{
|
||||
Type: types.MessageTypeData,
|
||||
StreamID: streamID,
|
||||
Data: buffer[:n],
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Return Response Data
|
||||
for data := range responseChan {
|
||||
if _, err := clientConn.Write(data); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
@ -122,7 +183,7 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
bufReader := bufio.NewReader(teeReader)
|
||||
|
||||
// Create HTTP Request & Writer
|
||||
w := &rawHTTPResponseWriter{conn: conn}
|
||||
w := &connResponseWriter{conn: conn}
|
||||
r, err := http.ReadRequest(bufReader)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@ -138,49 +199,30 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
}
|
||||
|
||||
// Extract Subdomain
|
||||
tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
||||
if strings.Count(tunnelName, ".") != 0 {
|
||||
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
||||
if strings.Count(subdomain, ".") != 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
|
||||
return
|
||||
}
|
||||
|
||||
// Get True Host
|
||||
remoteHost := conn.RemoteAddr().String()
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
remoteHost = xff
|
||||
}
|
||||
r.RemoteAddr = remoteHost
|
||||
|
||||
// Handle Control Endpoints
|
||||
if tunnelName == "" {
|
||||
if subdomain == "" {
|
||||
s.handleAsHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Tunnels
|
||||
conduitTunnel, exists := s.tunnels.Get(tunnelName)
|
||||
if !exists {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName)
|
||||
return
|
||||
s.mu.RLock()
|
||||
tunnelConn, exists := s.tunnels[subdomain]
|
||||
s.mu.RUnlock()
|
||||
if exists {
|
||||
log.Infof("relaying %s to tunnel", subdomain)
|
||||
|
||||
// Reconstruct Data & Proxy Connection
|
||||
allReader := io.MultiReader(&capturedData, r.Body)
|
||||
s.proxyRawConnection(conn, tunnelConn, allReader)
|
||||
}
|
||||
|
||||
// Create Stream
|
||||
reconstructedConn := newReconstructedConn(conn, &capturedData)
|
||||
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
||||
tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr)
|
||||
|
||||
// Add Stream
|
||||
if err := conduitTunnel.AddStream(tunnelStream, streamID); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = fmt.Fprintf(w, "failed to add stream: %v", err)
|
||||
log.WithError(err).Error("failed to add stream")
|
||||
return
|
||||
}
|
||||
|
||||
// Start Stream
|
||||
conduitTunnel.StartStream(tunnelStream, streamID)
|
||||
}
|
||||
|
||||
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@ -203,6 +245,40 @@ func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
|
||||
for {
|
||||
var msg types.Message
|
||||
err := tunnel.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if msg.StreamID == "" {
|
||||
log.Infof("tunnel %s missing streamID", tunnel.name)
|
||||
continue
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case types.MessageTypeClose:
|
||||
return
|
||||
case types.MessageTypeData:
|
||||
s.mu.RLock()
|
||||
streamChan, exists := tunnel.streams[msg.StreamID]
|
||||
if !exists {
|
||||
log.Infof("stream %s does not exist", msg.StreamID)
|
||||
s.mu.RUnlock()
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case streamChan <- msg.Data:
|
||||
case <-time.After(time.Second):
|
||||
log.Warnf("stream %s channel full, dropping data", msg.StreamID)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
// Get Tunnel Name
|
||||
tunnelName := r.URL.Query().Get("tunnelName")
|
||||
@ -213,7 +289,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Validate Unique
|
||||
if _, exists := s.tunnels.Get(tunnelName); exists {
|
||||
if _, exists := s.tunnels[tunnelName]; exists {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
_, _ = w.Write([]byte("Tunnel already registered"))
|
||||
return
|
||||
@ -226,14 +302,26 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create Tunnel
|
||||
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
|
||||
s.tunnels.Set(tunnelName, conduitTunnel)
|
||||
// Create & Cache TunnelConnection
|
||||
tunnel := &TunnelConnection{
|
||||
Conn: wsConn,
|
||||
name: tunnelName,
|
||||
streams: make(map[string]chan []byte),
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.tunnels[tunnelName] = tunnel
|
||||
s.mu.Unlock()
|
||||
log.Infof("tunnel established: %s", tunnelName)
|
||||
|
||||
// Start Tunnel - This is blocking
|
||||
conduitTunnel.Start(s.ctx)
|
||||
// Keep connection alive and handle cleanup
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
delete(s.tunnels, tunnelName)
|
||||
s.mu.Unlock()
|
||||
_ = wsConn.Close()
|
||||
log.Infof("tunnel closed: %s", tunnelName)
|
||||
}()
|
||||
|
||||
// Cleanup Tunnel
|
||||
s.tunnels.Delete(tunnelName)
|
||||
_ = wsConn.Close()
|
||||
// Handle tunnel messages
|
||||
s.handleTunnelMessages(tunnel)
|
||||
}
|
||||
|
@ -7,25 +7,25 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var _ http.ResponseWriter = (*rawHTTPResponseWriter)(nil)
|
||||
var _ http.ResponseWriter = (*connResponseWriter)(nil)
|
||||
|
||||
type rawHTTPResponseWriter struct {
|
||||
type connResponseWriter struct {
|
||||
conn net.Conn
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (f *rawHTTPResponseWriter) Header() http.Header {
|
||||
func (f *connResponseWriter) Header() http.Header {
|
||||
if f.header == nil {
|
||||
f.header = make(http.Header)
|
||||
}
|
||||
return f.header
|
||||
}
|
||||
|
||||
func (f *rawHTTPResponseWriter) Write(data []byte) (int, error) {
|
||||
func (f *connResponseWriter) Write(data []byte) (int, error) {
|
||||
return f.conn.Write(data)
|
||||
}
|
||||
|
||||
func (f *rawHTTPResponseWriter) WriteHeader(statusCode int) {
|
||||
func (f *connResponseWriter) WriteHeader(statusCode int) {
|
||||
// Write Status
|
||||
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
|
||||
_, _ = f.conn.Write([]byte(status))
|
||||
@ -41,7 +41,7 @@ func (f *rawHTTPResponseWriter) WriteHeader(statusCode int) {
|
||||
_, _ = f.conn.Write([]byte("\r\n"))
|
||||
}
|
||||
|
||||
func (f *rawHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
func (f *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
// Return Raw Connection & ReadWriter
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
|
||||
return f.conn, rw, nil
|
@ -1,18 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
var recordIDKey = contextKey{}
|
||||
|
||||
func withRecord(ctx context.Context, rec *TunnelRecord) context.Context {
|
||||
return context.WithValue(ctx, recordIDKey, rec)
|
||||
}
|
||||
|
||||
func getRecord(ctx context.Context) (*TunnelRecord, bool) {
|
||||
id, ok := ctx.Value(recordIDKey).(*TunnelRecord)
|
||||
return id, ok
|
||||
}
|
196
store/store.go
196
store/store.go
@ -1,196 +0,0 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueueSize = 100
|
||||
maxQueueSize = 100
|
||||
)
|
||||
|
||||
var ErrRecordNotFound = errors.New("record not found")
|
||||
|
||||
type TunnelStore interface {
|
||||
Get(before time.Time, count int) (results []*TunnelRecord, more bool)
|
||||
RecordTCP()
|
||||
RecordRequest(req *http.Request)
|
||||
RecordResponse(resp *http.Response) error
|
||||
}
|
||||
|
||||
type TunnelRecord struct {
|
||||
ID uuid.UUID
|
||||
Time time.Time
|
||||
URL *url.URL
|
||||
Method string
|
||||
Status int
|
||||
|
||||
RequestHeaders http.Header
|
||||
RequestBodyType string
|
||||
RequestBody []byte
|
||||
|
||||
ResponseHeaders http.Header
|
||||
ResponseBodyType string
|
||||
ResponseBody []byte
|
||||
}
|
||||
|
||||
func NewTunnelStore(queueSize int) TunnelStore {
|
||||
if queueSize <= 0 {
|
||||
queueSize = defaultQueueSize
|
||||
} else if queueSize > maxQueueSize {
|
||||
queueSize = maxQueueSize
|
||||
}
|
||||
|
||||
return &tunnelStoreImpl{queueSize: queueSize}
|
||||
}
|
||||
|
||||
type tunnelStoreImpl struct {
|
||||
orderedRecords []*TunnelRecord
|
||||
queueSize int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *tunnelStoreImpl) Get(before time.Time, count int) ([]*TunnelRecord, bool) {
|
||||
// Find First
|
||||
start := -1
|
||||
for i, r := range s.orderedRecords {
|
||||
if r.Time.Before(before) {
|
||||
start = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Not Found
|
||||
if start == -1 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Subslice Records
|
||||
end := min(start+count, len(s.orderedRecords))
|
||||
results := s.orderedRecords[start:end]
|
||||
more := end < len(s.orderedRecords)
|
||||
|
||||
return results, more
|
||||
}
|
||||
|
||||
func (s *tunnelStoreImpl) RecordRequest(req *http.Request) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
url := *req.URL
|
||||
rec := &TunnelRecord{
|
||||
ID: uuid.New(),
|
||||
Time: time.Now(),
|
||||
URL: &url,
|
||||
Method: req.Method,
|
||||
RequestHeaders: req.Header,
|
||||
RequestBodyType: req.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
if bodyData, err := getRequestBody(req); err == nil {
|
||||
rec.RequestBody = bodyData
|
||||
}
|
||||
|
||||
// Add Record & Truncate
|
||||
s.orderedRecords = append(s.orderedRecords, rec)
|
||||
if len(s.orderedRecords) > s.queueSize {
|
||||
s.orderedRecords = s.orderedRecords[len(s.orderedRecords)-s.queueSize:]
|
||||
}
|
||||
|
||||
*req = *req.WithContext(withRecord(req.Context(), rec))
|
||||
}
|
||||
|
||||
func (s *tunnelStoreImpl) RecordResponse(resp *http.Response) error {
|
||||
rec, found := getRecord(resp.Request.Context())
|
||||
if !found {
|
||||
return ErrRecordNotFound
|
||||
}
|
||||
|
||||
rec.ResponseHeaders = resp.Header
|
||||
rec.ResponseBodyType = resp.Header.Get("Content-Type")
|
||||
|
||||
if bodyData, err := getResponseBody(resp); err == nil {
|
||||
rec.ResponseBody = bodyData
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tunnelStoreImpl) RecordTCP() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// TODO
|
||||
}
|
||||
|
||||
func getRequestBody(req *http.Request) ([]byte, error) {
|
||||
if req.ContentLength == 0 || req.Body == nil || req.Body == http.NoBody {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !isTextContentType(req.Header.Get("Content-Type")) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Read Body
|
||||
bodyBytes, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Restore Body
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
func getResponseBody(resp *http.Response) ([]byte, error) {
|
||||
if resp.ContentLength == 0 || resp.Body == nil || resp.Body == http.NoBody {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !isTextContentType(resp.Header.Get("Content-Type")) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Read Body
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Restore Body
|
||||
resp.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
func isTextContentType(contentType string) bool {
|
||||
mediaType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(mediaType, "text/") {
|
||||
return true
|
||||
}
|
||||
|
||||
switch mediaType {
|
||||
case "application/json":
|
||||
return true
|
||||
case "application/xml":
|
||||
return true
|
||||
case "application/x-www-form-urlencoded":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"reichard.io/conduit/store"
|
||||
)
|
||||
|
||||
type ForwarderType int
|
||||
|
||||
const (
|
||||
ForwarderTCP ForwarderType = iota
|
||||
ForwarderHTTP
|
||||
)
|
||||
|
||||
type Forwarder interface {
|
||||
Type() ForwarderType
|
||||
Initialize() (Stream, error)
|
||||
Start(context.Context) error
|
||||
}
|
||||
|
||||
func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) {
|
||||
// Get Target URL
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get Connection Builder
|
||||
var forwarder Forwarder
|
||||
switch targetURL.Scheme {
|
||||
case "http", "https":
|
||||
forwarder, err = newHTTPForwarder(targetURL, tunnelStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
forwarder = newTCPForwarder(target, tunnelStore)
|
||||
}
|
||||
|
||||
return forwarder, nil
|
||||
}
|
@ -1,132 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"reichard.io/conduit/store"
|
||||
)
|
||||
|
||||
func newHTTPForwarder(targetURL *url.URL, tunnelStore store.TunnelStore) (Forwarder, error) {
|
||||
return &httpConnBuilder{
|
||||
multiConnListener: newMultiConnListener(),
|
||||
tunnelStore: tunnelStore,
|
||||
targetURL: targetURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type httpConnBuilder struct {
|
||||
multiConnListener *multiConnListener
|
||||
tunnelStore store.TunnelStore
|
||||
targetURL *url.URL
|
||||
}
|
||||
|
||||
func (c *httpConnBuilder) Type() ForwarderType {
|
||||
return ForwarderHTTP
|
||||
}
|
||||
|
||||
func (c *httpConnBuilder) Start(ctx context.Context) error {
|
||||
// Create Reverse Proxy Server
|
||||
server := &http.Server{
|
||||
Handler: &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.Host = c.targetURL.Host
|
||||
req.URL.Host = c.targetURL.Host
|
||||
req.URL.Scheme = c.targetURL.Scheme
|
||||
c.tunnelStore.RecordRequest(req)
|
||||
},
|
||||
ModifyResponse: c.tunnelStore.RecordResponse,
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Context & Cleanup
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
server.Shutdown(ctx)
|
||||
c.multiConnListener.Close()
|
||||
}()
|
||||
|
||||
// Start HTTP Proxy
|
||||
if err := server.Serve(c.multiConnListener); err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *httpConnBuilder) Initialize() (Stream, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
|
||||
if err := c.multiConnListener.addConn(serverConn); err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &streamImpl{clientConn, c.targetURL.String()}, nil
|
||||
}
|
||||
|
||||
type multiConnListener struct {
|
||||
connCh chan net.Conn
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newMultiConnListener() *multiConnListener {
|
||||
return &multiConnListener{
|
||||
connCh: make(chan net.Conn, 100),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-l.connCh:
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
return conn, nil
|
||||
case <-l.closed:
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
close(l.closed)
|
||||
// Drain any remaining connections
|
||||
go func() {
|
||||
for conn := range l.connCh {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
close(l.connCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Addr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) addConn(conn net.Conn) error {
|
||||
select {
|
||||
case l.connCh <- conn:
|
||||
return nil
|
||||
case <-l.closed:
|
||||
conn.Close()
|
||||
return fmt.Errorf("listener is closed")
|
||||
default:
|
||||
conn.Close()
|
||||
return fmt.Errorf("connection queue full")
|
||||
}
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
var _ Stream = (*streamImpl)(nil)
|
||||
|
||||
type Stream interface {
|
||||
io.ReadWriteCloser
|
||||
Source() string
|
||||
}
|
||||
|
||||
func NewStream(conn net.Conn, source string) Stream {
|
||||
return &streamImpl{conn, source}
|
||||
}
|
||||
|
||||
type streamImpl struct {
|
||||
net.Conn
|
||||
source string
|
||||
}
|
||||
|
||||
func (s *streamImpl) Source() string {
|
||||
return s.source
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"reichard.io/conduit/store"
|
||||
)
|
||||
|
||||
func newTCPForwarder(target string, tunnelStore store.TunnelStore) Forwarder {
|
||||
return &tcpConnBuilder{
|
||||
target: target,
|
||||
tunnelStore: tunnelStore,
|
||||
}
|
||||
}
|
||||
|
||||
type tcpConnBuilder struct {
|
||||
target string
|
||||
tunnelStore store.TunnelStore
|
||||
}
|
||||
|
||||
func (l *tcpConnBuilder) Type() ForwarderType {
|
||||
return ForwarderTCP
|
||||
}
|
||||
|
||||
func (l *tcpConnBuilder) Initialize() (Stream, error) {
|
||||
conn, err := net.Dial("tcp", l.target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &streamImpl{conn, l.target}, nil
|
||||
}
|
||||
|
||||
func (l *tcpConnBuilder) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
230
tunnel/tunnel.go
230
tunnel/tunnel.go
@ -1,230 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/conduit/config"
|
||||
"reichard.io/conduit/pkg/maps"
|
||||
"reichard.io/conduit/types"
|
||||
)
|
||||
|
||||
// NewServerTunnel creates a new tunnel with name and websocket connection. The tunnel is
|
||||
// generally instantiated after an upgrade request from the server.
|
||||
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
||||
return &Tunnel{
|
||||
name: name,
|
||||
streams: maps.New[string, Stream](),
|
||||
wsConn: wsConn,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClientTunnel creates a new tunnel with the provided configuration and forwarder. A
|
||||
// forwarder is effectively the protocol being forwarded. For example HTTP (Proxy), and TCP.
|
||||
func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, error) {
|
||||
// Parse Server URL
|
||||
serverURL, err := url.Parse(cfg.ServerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse Scheme
|
||||
var wsScheme string
|
||||
switch serverURL.Scheme {
|
||||
case "https":
|
||||
wsScheme = "wss"
|
||||
case "http":
|
||||
wsScheme = "ws"
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
|
||||
}
|
||||
|
||||
// Create Tunnel Name
|
||||
if cfg.TunnelName == "" {
|
||||
cfg.TunnelName = generateTunnelName()
|
||||
log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName)
|
||||
}
|
||||
|
||||
// Connect Server WS
|
||||
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey)
|
||||
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %v", err)
|
||||
}
|
||||
|
||||
return &Tunnel{
|
||||
name: cfg.TunnelName,
|
||||
wsConn: serverConn,
|
||||
streams: maps.New[string, Stream](),
|
||||
forwarder: forwarder,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
ctx context.Context
|
||||
name string
|
||||
wsConn *websocket.Conn
|
||||
streams *maps.Map[string, Stream]
|
||||
forwarder Forwarder
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *Tunnel) Start(ctx context.Context) {
|
||||
log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||
defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||
|
||||
t.ctx = ctx
|
||||
|
||||
// Start Message Receiver
|
||||
for {
|
||||
msg, err := t.readWSWithContext(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate Stream
|
||||
if msg.StreamID == "" {
|
||||
log.Warnf("tunnel %s missing streamID", t.name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Get Stream
|
||||
stream, err := t.getStream(msg.StreamID)
|
||||
if err != nil {
|
||||
if msg.Type != types.MessageTypeClose {
|
||||
log.WithError(err).Errorf("failed to get stream %s", msg.StreamID)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle Messages
|
||||
switch msg.Type {
|
||||
case types.MessageTypeClose:
|
||||
_ = t.closeStream(stream, msg.StreamID)
|
||||
case types.MessageTypeData:
|
||||
_, err = stream.Write(msg.Data)
|
||||
}
|
||||
|
||||
// Log Error
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("failed to handle message %s", msg.StreamID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) readWSWithContext(ctx context.Context) (*types.Message, error) {
|
||||
type result struct {
|
||||
msg *types.Message
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan result, 1)
|
||||
go func() {
|
||||
var msg types.Message
|
||||
err := t.wsConn.ReadJSON(&msg)
|
||||
resultChan <- result{&msg, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-resultChan:
|
||||
return result.msg, result.err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) AddStream(stream Stream, streamID string) error {
|
||||
if t.streams.HasKey(streamID) {
|
||||
return fmt.Errorf("stream %s already exists", streamID)
|
||||
}
|
||||
log.Infof("tunnel %q initiated stream with %s", t.name, stream.Source())
|
||||
t.streams.Set(streamID, stream)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) Source() string {
|
||||
return t.wsConn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (t *Tunnel) StartStream(stream Stream, streamID string) error {
|
||||
// Close Stream
|
||||
defer t.closeStream(stream, streamID)
|
||||
|
||||
// Start Stream
|
||||
for {
|
||||
data, err := t.readStreamWithContext(t.ctx, stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := t.sendWS(&types.Message{
|
||||
Type: types.MessageTypeData,
|
||||
StreamID: streamID,
|
||||
Data: data,
|
||||
SourceAddr: stream.Source(),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) closeStream(stream Stream, streamID string) error {
|
||||
log.Infof("tunnel %q closed stream with %s", t.name, stream.Source())
|
||||
t.streams.Delete(streamID)
|
||||
return stream.Close()
|
||||
}
|
||||
|
||||
func (t *Tunnel) getStream(streamID string) (Stream, error) {
|
||||
// Check Existing Stream
|
||||
if stream, found := t.streams.Get(streamID); found {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// Check Forwarder
|
||||
if t.forwarder == nil {
|
||||
return nil, fmt.Errorf("stream %s does not exist", streamID)
|
||||
}
|
||||
|
||||
// Initialize Forwarder & Add Stream
|
||||
stream, err := t.forwarder.Initialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := t.AddStream(stream, streamID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go t.StartStream(stream, streamID)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) readStreamWithContext(ctx context.Context, stream Stream) ([]byte, error) {
|
||||
type result struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan result, 1)
|
||||
go func() {
|
||||
buffer := make([]byte, 4096)
|
||||
n, err := stream.Read(buffer)
|
||||
resultChan <- result{buffer[:n], err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-resultChan:
|
||||
return result.data, result.err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) sendWS(msg *types.Message) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.wsConn.WriteJSON(msg)
|
||||
}
|
@ -8,8 +8,7 @@ const (
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Type MessageType `json:"type"`
|
||||
StreamID string `json:"stream_id"`
|
||||
SourceAddr string `json:"source_addr"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
Type MessageType `json:"type"`
|
||||
StreamID string `json:"stream_id"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user