Compare commits

..

No commits in common. "20c1388cf46e1f56f813ac419827fb30b81ec670" and "de23b3e815a691f07c13ae28909c55ddd9f9d174" have entirely different histories.

5 changed files with 67 additions and 111 deletions

View File

@ -41,5 +41,5 @@ func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) {
return nil, fmt.Errorf("failed to connect: %v", err) return nil, fmt.Errorf("failed to connect: %v", err)
} }
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverURL, serverConn) return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn)
} }

View File

@ -1,51 +0,0 @@
package maps
import (
"iter"
"sync"
)
type Map[K comparable, V any] struct {
items map[K]V
mu sync.RWMutex
}
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{items: make(map[K]V)}
}
func (m *Map[K, V]) Get(key K) (V, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
v, ok := m.items[key]
return v, ok
}
func (m *Map[K, V]) Set(key K, value V) {
m.mu.Lock()
defer m.mu.Unlock()
m.items[key] = value
}
func (m *Map[K, V]) Delete(key K) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.items, key)
}
func (m *Map[K, V]) HasKey(key K) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, ok := m.items[key]
return ok
}
func (m *Map[K, V]) Entries() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
for k, v := range m.items {
if !yield(k, v) {
return
}
}
}
}

View File

@ -11,12 +11,12 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/conduit/config" "reichard.io/conduit/config"
"reichard.io/conduit/pkg/maps"
"reichard.io/conduit/tunnel" "reichard.io/conduit/tunnel"
) )
@ -33,9 +33,10 @@ type TunnelInfo struct {
type Server struct { type Server struct {
host string host string
cfg *config.ServerConfig cfg *config.ServerConfig
mu sync.RWMutex
upgrader websocket.Upgrader upgrader websocket.Upgrader
tunnels *maps.Map[string, *tunnel.Tunnel] tunnels map[string]*tunnel.Tunnel
} }
func NewServer(cfg *config.ServerConfig) (*Server, error) { func NewServer(cfg *config.ServerConfig) (*Server, error) {
@ -49,7 +50,7 @@ func NewServer(cfg *config.ServerConfig) (*Server, error) {
return &Server{ return &Server{
cfg: cfg, cfg: cfg,
host: serverURL.Host, host: serverURL.Host,
tunnels: maps.New[string, *tunnel.Tunnel](), tunnels: make(map[string]*tunnel.Tunnel),
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true return true
@ -83,12 +84,14 @@ func (s *Server) Start() error {
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) { func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
// Get Tunnels // Get Tunnels
var allTunnels []TunnelInfo var allTunnels []TunnelInfo
for t, c := range s.tunnels.Entries() { s.mu.RLock()
for t, c := range s.tunnels {
allTunnels = append(allTunnels, TunnelInfo{ allTunnels = append(allTunnels, TunnelInfo{
Name: t, Name: t,
Target: c.Source(), Target: c.Source(),
}) })
} }
s.mu.RUnlock()
// Create Response // Create Response
d, err := json.MarshalIndent(InfoResponse{ d, err := json.MarshalIndent(InfoResponse{
@ -135,31 +138,26 @@ func (s *Server) handleRawConnection(conn net.Conn) {
} }
// Extract Subdomain // Extract Subdomain
tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".") subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
if strings.Count(tunnelName, ".") != 0 { if strings.Count(subdomain, ".") != 0 {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host) _, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
return return
} }
// Get True Host
remoteHost := conn.RemoteAddr().String()
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
remoteHost = xff
}
r.RemoteAddr = remoteHost
// Handle Control Endpoints // Handle Control Endpoints
if tunnelName == "" { if subdomain == "" {
s.handleAsHTTP(w, r) s.handleAsHTTP(w, r)
return return
} }
// Handle Tunnels // Handle Tunnels
conduitTunnel, exists := s.tunnels.Get(tunnelName) s.mu.RLock()
conduitTunnel, exists := s.tunnels[subdomain]
s.mu.RUnlock()
if !exists { if !exists {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName) _, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain)
return return
} }
@ -172,8 +170,8 @@ func (s *Server) handleRawConnection(conn net.Conn) {
return return
} }
log.Infof("tunnel %q connection from %s", tunnelName, r.RemoteAddr) log.Infof("relaying %s to tunnel", subdomain)
_ = conduitTunnel.StartStream(streamID, r.RemoteAddr) _ = conduitTunnel.StartStream(streamID)
} }
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
@ -206,7 +204,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
} }
// Validate Unique // Validate Unique
if _, exists := s.tunnels.Get(tunnelName); exists { if _, exists := s.tunnels[tunnelName]; exists {
w.WriteHeader(http.StatusConflict) w.WriteHeader(http.StatusConflict)
_, _ = w.Write([]byte("Tunnel already registered")) _, _ = w.Write([]byte("Tunnel already registered"))
return return
@ -221,14 +219,18 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Create Tunnel // Create Tunnel
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn) conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
s.tunnels.Set(tunnelName, conduitTunnel) s.mu.Lock()
log.Infof("tunnel %q created from %s", tunnelName, r.RemoteAddr) s.tunnels[tunnelName] = conduitTunnel
s.mu.Unlock()
log.Infof("tunnel established: %s", tunnelName)
// Start Tunnel - This is blocking // Start Tunnel - This is blocking
conduitTunnel.Start() conduitTunnel.Start()
// Cleanup Tunnel // Cleanup Tunnel
s.tunnels.Delete(tunnelName) s.mu.Lock()
delete(s.tunnels, tunnelName)
s.mu.Unlock()
_ = wsConn.Close() _ = wsConn.Close()
log.Infof("tunnel %q closed from %s", tunnelName, r.RemoteAddr) log.Infof("tunnel closed: %s", tunnelName)
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reichard.io/conduit/pkg/maps"
"reichard.io/conduit/types" "reichard.io/conduit/types"
) )
@ -18,33 +17,27 @@ type ConnBuilder func() (conn io.ReadWriteCloser, err error)
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel { func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
return &Tunnel{ return &Tunnel{
name: name, name: name,
streams: maps.New[string, io.ReadWriteCloser](),
wsConn: wsConn, wsConn: wsConn,
streams: make(map[string]io.ReadWriteCloser),
} }
} }
func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) { func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, error) {
// Get Target URL
targetURL, err := url.Parse(target) targetURL, err := url.Parse(target)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Derive Conduit URL
conduitURL := *serverURL
conduitURL.Host = name + "." + conduitURL.Host
// Get Connection Builder
var connBuilder ConnBuilder var connBuilder ConnBuilder
switch targetURL.Scheme { switch targetURL.Scheme {
case "http", "https": case "http", "https":
log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target) log.Infof("creating HTTP tunnel: %s -> %s", name, target)
connBuilder, err = HTTPConnectionBuilder(targetURL) connBuilder, err = HTTPConnectionBuilder(targetURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
default: default:
log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target) log.Infof("creating TCP tunnel: %s -> %s", name, target)
connBuilder = func() (conn io.ReadWriteCloser, err error) { connBuilder = func() (conn io.ReadWriteCloser, err error) {
return net.Dial("tcp", target) return net.Dial("tcp", target)
} }
@ -53,7 +46,7 @@ func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.
return &Tunnel{ return &Tunnel{
name: name, name: name,
wsConn: wsConn, wsConn: wsConn,
streams: maps.New[string, io.ReadWriteCloser](), streams: make(map[string]io.ReadWriteCloser),
connBuilder: connBuilder, connBuilder: connBuilder,
}, nil }, nil
} }
@ -61,10 +54,10 @@ func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.
type Tunnel struct { type Tunnel struct {
name string name string
wsConn *websocket.Conn wsConn *websocket.Conn
streams *maps.Map[string, io.ReadWriteCloser] streams map[string]io.ReadWriteCloser
connBuilder ConnBuilder connBuilder ConnBuilder
mu sync.Mutex wsMu, streamsMu sync.Mutex
} }
func (t *Tunnel) Start() { func (t *Tunnel) Start() {
@ -102,7 +95,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
return nil return nil
} }
if _, found := t.streams.Get(streamID); found { if _, found := t.getStream(streamID); found {
return nil return nil
} }
@ -115,21 +108,24 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
return err return err
} }
go t.StartStream(streamID, "") go t.StartStream(streamID)
return nil return nil
} }
func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error { func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
if t.streams.HasKey(streamID) { t.streamsMu.Lock()
defer t.streamsMu.Unlock()
if _, found := t.streams[streamID]; found {
return fmt.Errorf("stream %s already exists", streamID) return fmt.Errorf("stream %s already exists", streamID)
} }
t.streams.Set(streamID, conn) t.streams[streamID] = conn
return nil return nil
} }
func (t *Tunnel) StartStream(streamID string, sourceAddr string) error { func (t *Tunnel) StartStream(streamID string) error {
// Get Stream // Get Stream
conn, found := t.streams.Get(streamID) conn, found := t.getStream(streamID)
if !found { if !found {
return fmt.Errorf("stream %s does not exist", streamID) return fmt.Errorf("stream %s does not exist", streamID)
} }
@ -139,7 +135,6 @@ func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
_ = t.sendWS(&types.Message{ _ = t.sendWS(&types.Message{
Type: types.MessageTypeClose, Type: types.MessageTypeClose,
StreamID: streamID, StreamID: streamID,
SourceAddr: sourceAddr,
}) })
t.CloseStream(streamID) t.CloseStream(streamID)
@ -155,9 +150,8 @@ func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
if err := t.sendWS(&types.Message{ if err := t.sendWS(&types.Message{
Type: types.MessageTypeData, Type: types.MessageTypeData,
StreamID: streamID,
Data: buffer[:n], Data: buffer[:n],
SourceAddr: sourceAddr, StreamID: streamID,
}); err != nil { }); err != nil {
return err return err
} }
@ -166,7 +160,7 @@ func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
func (t *Tunnel) WriteStream(streamID string, data []byte) error { func (t *Tunnel) WriteStream(streamID string, data []byte) error {
// Get Stream // Get Stream
conn, found := t.streams.Get(streamID) conn, found := t.getStream(streamID)
if !found { if !found {
return fmt.Errorf("stream %s does not exist", streamID) return fmt.Errorf("stream %s does not exist", streamID)
} }
@ -176,8 +170,10 @@ func (t *Tunnel) WriteStream(streamID string, data []byte) error {
} }
func (t *Tunnel) CloseStream(streamID string) error { func (t *Tunnel) CloseStream(streamID string) error {
if conn, ok := t.streams.Get(streamID); ok { t.streamsMu.Lock()
t.streams.Delete(streamID) defer t.streamsMu.Unlock()
if conn, ok := t.streams[streamID]; ok {
delete(t.streams, streamID)
return conn.Close() return conn.Close()
} }
return nil return nil
@ -188,7 +184,17 @@ func (t *Tunnel) Source() string {
} }
func (t *Tunnel) sendWS(msg *types.Message) error { func (t *Tunnel) sendWS(msg *types.Message) error {
t.mu.Lock() t.wsMu.Lock()
defer t.mu.Unlock() defer t.wsMu.Unlock()
return t.wsConn.WriteJSON(msg) return t.wsConn.WriteJSON(msg)
} }
func (t *Tunnel) getStream(streamID string) (io.ReadWriteCloser, bool) {
t.streamsMu.Lock()
defer t.streamsMu.Unlock()
if conn, ok := t.streams[streamID]; ok {
return conn, true
}
return nil, false
}

View File

@ -10,6 +10,5 @@ const (
type Message struct { type Message struct {
Type MessageType `json:"type"` Type MessageType `json:"type"`
StreamID string `json:"stream_id"` StreamID string `json:"stream_id"`
SourceAddr string `json:"source_addr"`
Data []byte `json:"data,omitempty"` Data []byte `json:"data,omitempty"`
} }