Compare commits
2 Commits
de23b3e815
...
20c1388cf4
Author | SHA1 | Date | |
---|---|---|---|
20c1388cf4 | |||
0333680a2b |
@ -41,5 +41,5 @@ func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) {
|
|||||||
return nil, fmt.Errorf("failed to connect: %v", err)
|
return nil, fmt.Errorf("failed to connect: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn)
|
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverURL, serverConn)
|
||||||
}
|
}
|
||||||
|
51
pkg/maps/map.go
Normal file
51
pkg/maps/map.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package maps
|
||||||
|
|
||||||
|
import (
|
||||||
|
"iter"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Map[K comparable, V any] struct {
|
||||||
|
items map[K]V
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func New[K comparable, V any]() *Map[K, V] {
|
||||||
|
return &Map[K, V]{items: make(map[K]V)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
v, ok := m.items[key]
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) Set(key K, value V) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.items[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) Delete(key K) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.items, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) HasKey(key K) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
_, ok := m.items[key]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Map[K, V]) Entries() iter.Seq2[K, V] {
|
||||||
|
return func(yield func(K, V) bool) {
|
||||||
|
for k, v := range m.items {
|
||||||
|
if !yield(k, v) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -11,12 +11,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"reichard.io/conduit/config"
|
"reichard.io/conduit/config"
|
||||||
|
"reichard.io/conduit/pkg/maps"
|
||||||
"reichard.io/conduit/tunnel"
|
"reichard.io/conduit/tunnel"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,10 +33,9 @@ type TunnelInfo struct {
|
|||||||
type Server struct {
|
type Server struct {
|
||||||
host string
|
host string
|
||||||
cfg *config.ServerConfig
|
cfg *config.ServerConfig
|
||||||
mu sync.RWMutex
|
|
||||||
|
|
||||||
upgrader websocket.Upgrader
|
upgrader websocket.Upgrader
|
||||||
tunnels map[string]*tunnel.Tunnel
|
tunnels *maps.Map[string, *tunnel.Tunnel]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
||||||
@ -50,7 +49,7 @@ func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
|||||||
return &Server{
|
return &Server{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
host: serverURL.Host,
|
host: serverURL.Host,
|
||||||
tunnels: make(map[string]*tunnel.Tunnel),
|
tunnels: maps.New[string, *tunnel.Tunnel](),
|
||||||
upgrader: websocket.Upgrader{
|
upgrader: websocket.Upgrader{
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
return true
|
return true
|
||||||
@ -84,14 +83,12 @@ func (s *Server) Start() error {
|
|||||||
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
||||||
// Get Tunnels
|
// Get Tunnels
|
||||||
var allTunnels []TunnelInfo
|
var allTunnels []TunnelInfo
|
||||||
s.mu.RLock()
|
for t, c := range s.tunnels.Entries() {
|
||||||
for t, c := range s.tunnels {
|
|
||||||
allTunnels = append(allTunnels, TunnelInfo{
|
allTunnels = append(allTunnels, TunnelInfo{
|
||||||
Name: t,
|
Name: t,
|
||||||
Target: c.Source(),
|
Target: c.Source(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
s.mu.RUnlock()
|
|
||||||
|
|
||||||
// Create Response
|
// Create Response
|
||||||
d, err := json.MarshalIndent(InfoResponse{
|
d, err := json.MarshalIndent(InfoResponse{
|
||||||
@ -138,26 +135,31 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extract Subdomain
|
// Extract Subdomain
|
||||||
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
||||||
if strings.Count(subdomain, ".") != 0 {
|
if strings.Count(tunnelName, ".") != 0 {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
|
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get True Host
|
||||||
|
remoteHost := conn.RemoteAddr().String()
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
remoteHost = xff
|
||||||
|
}
|
||||||
|
r.RemoteAddr = remoteHost
|
||||||
|
|
||||||
// Handle Control Endpoints
|
// Handle Control Endpoints
|
||||||
if subdomain == "" {
|
if tunnelName == "" {
|
||||||
s.handleAsHTTP(w, r)
|
s.handleAsHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle Tunnels
|
// Handle Tunnels
|
||||||
s.mu.RLock()
|
conduitTunnel, exists := s.tunnels.Get(tunnelName)
|
||||||
conduitTunnel, exists := s.tunnels[subdomain]
|
|
||||||
s.mu.RUnlock()
|
|
||||||
if !exists {
|
if !exists {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain)
|
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,8 +172,8 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("relaying %s to tunnel", subdomain)
|
log.Infof("tunnel %q connection from %s", tunnelName, r.RemoteAddr)
|
||||||
_ = conduitTunnel.StartStream(streamID)
|
_ = conduitTunnel.StartStream(streamID, r.RemoteAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -204,7 +206,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate Unique
|
// Validate Unique
|
||||||
if _, exists := s.tunnels[tunnelName]; exists {
|
if _, exists := s.tunnels.Get(tunnelName); exists {
|
||||||
w.WriteHeader(http.StatusConflict)
|
w.WriteHeader(http.StatusConflict)
|
||||||
_, _ = w.Write([]byte("Tunnel already registered"))
|
_, _ = w.Write([]byte("Tunnel already registered"))
|
||||||
return
|
return
|
||||||
@ -219,18 +221,14 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Create Tunnel
|
// Create Tunnel
|
||||||
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
|
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
|
||||||
s.mu.Lock()
|
s.tunnels.Set(tunnelName, conduitTunnel)
|
||||||
s.tunnels[tunnelName] = conduitTunnel
|
log.Infof("tunnel %q created from %s", tunnelName, r.RemoteAddr)
|
||||||
s.mu.Unlock()
|
|
||||||
log.Infof("tunnel established: %s", tunnelName)
|
|
||||||
|
|
||||||
// Start Tunnel - This is blocking
|
// Start Tunnel - This is blocking
|
||||||
conduitTunnel.Start()
|
conduitTunnel.Start()
|
||||||
|
|
||||||
// Cleanup Tunnel
|
// Cleanup Tunnel
|
||||||
s.mu.Lock()
|
s.tunnels.Delete(tunnelName)
|
||||||
delete(s.tunnels, tunnelName)
|
|
||||||
s.mu.Unlock()
|
|
||||||
_ = wsConn.Close()
|
_ = wsConn.Close()
|
||||||
log.Infof("tunnel closed: %s", tunnelName)
|
log.Infof("tunnel %q closed from %s", tunnelName, r.RemoteAddr)
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"reichard.io/conduit/pkg/maps"
|
||||||
"reichard.io/conduit/types"
|
"reichard.io/conduit/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,27 +18,33 @@ type ConnBuilder func() (conn io.ReadWriteCloser, err error)
|
|||||||
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
||||||
return &Tunnel{
|
return &Tunnel{
|
||||||
name: name,
|
name: name,
|
||||||
|
streams: maps.New[string, io.ReadWriteCloser](),
|
||||||
wsConn: wsConn,
|
wsConn: wsConn,
|
||||||
streams: make(map[string]io.ReadWriteCloser),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, error) {
|
func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) {
|
||||||
|
// Get Target URL
|
||||||
targetURL, err := url.Parse(target)
|
targetURL, err := url.Parse(target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Derive Conduit URL
|
||||||
|
conduitURL := *serverURL
|
||||||
|
conduitURL.Host = name + "." + conduitURL.Host
|
||||||
|
|
||||||
|
// Get Connection Builder
|
||||||
var connBuilder ConnBuilder
|
var connBuilder ConnBuilder
|
||||||
switch targetURL.Scheme {
|
switch targetURL.Scheme {
|
||||||
case "http", "https":
|
case "http", "https":
|
||||||
log.Infof("creating HTTP tunnel: %s -> %s", name, target)
|
log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target)
|
||||||
connBuilder, err = HTTPConnectionBuilder(targetURL)
|
connBuilder, err = HTTPConnectionBuilder(targetURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
log.Infof("creating TCP tunnel: %s -> %s", name, target)
|
log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target)
|
||||||
connBuilder = func() (conn io.ReadWriteCloser, err error) {
|
connBuilder = func() (conn io.ReadWriteCloser, err error) {
|
||||||
return net.Dial("tcp", target)
|
return net.Dial("tcp", target)
|
||||||
}
|
}
|
||||||
@ -46,7 +53,7 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro
|
|||||||
return &Tunnel{
|
return &Tunnel{
|
||||||
name: name,
|
name: name,
|
||||||
wsConn: wsConn,
|
wsConn: wsConn,
|
||||||
streams: make(map[string]io.ReadWriteCloser),
|
streams: maps.New[string, io.ReadWriteCloser](),
|
||||||
connBuilder: connBuilder,
|
connBuilder: connBuilder,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -54,10 +61,10 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro
|
|||||||
type Tunnel struct {
|
type Tunnel struct {
|
||||||
name string
|
name string
|
||||||
wsConn *websocket.Conn
|
wsConn *websocket.Conn
|
||||||
streams map[string]io.ReadWriteCloser
|
streams *maps.Map[string, io.ReadWriteCloser]
|
||||||
connBuilder ConnBuilder
|
connBuilder ConnBuilder
|
||||||
|
|
||||||
wsMu, streamsMu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) Start() {
|
func (t *Tunnel) Start() {
|
||||||
@ -95,7 +102,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, found := t.getStream(streamID); found {
|
if _, found := t.streams.Get(streamID); found {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,24 +115,21 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
go t.StartStream(streamID)
|
go t.StartStream(streamID, "")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
|
func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
|
||||||
t.streamsMu.Lock()
|
if t.streams.HasKey(streamID) {
|
||||||
defer t.streamsMu.Unlock()
|
|
||||||
|
|
||||||
if _, found := t.streams[streamID]; found {
|
|
||||||
return fmt.Errorf("stream %s already exists", streamID)
|
return fmt.Errorf("stream %s already exists", streamID)
|
||||||
}
|
}
|
||||||
t.streams[streamID] = conn
|
t.streams.Set(streamID, conn)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) StartStream(streamID string) error {
|
func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
|
||||||
// Get Stream
|
// Get Stream
|
||||||
conn, found := t.getStream(streamID)
|
conn, found := t.streams.Get(streamID)
|
||||||
if !found {
|
if !found {
|
||||||
return fmt.Errorf("stream %s does not exist", streamID)
|
return fmt.Errorf("stream %s does not exist", streamID)
|
||||||
}
|
}
|
||||||
@ -133,8 +137,9 @@ func (t *Tunnel) StartStream(streamID string) error {
|
|||||||
// Close Stream
|
// Close Stream
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = t.sendWS(&types.Message{
|
_ = t.sendWS(&types.Message{
|
||||||
Type: types.MessageTypeClose,
|
Type: types.MessageTypeClose,
|
||||||
StreamID: streamID,
|
StreamID: streamID,
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
})
|
})
|
||||||
|
|
||||||
t.CloseStream(streamID)
|
t.CloseStream(streamID)
|
||||||
@ -149,9 +154,10 @@ func (t *Tunnel) StartStream(streamID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := t.sendWS(&types.Message{
|
if err := t.sendWS(&types.Message{
|
||||||
Type: types.MessageTypeData,
|
Type: types.MessageTypeData,
|
||||||
Data: buffer[:n],
|
StreamID: streamID,
|
||||||
StreamID: streamID,
|
Data: buffer[:n],
|
||||||
|
SourceAddr: sourceAddr,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -160,7 +166,7 @@ func (t *Tunnel) StartStream(streamID string) error {
|
|||||||
|
|
||||||
func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
||||||
// Get Stream
|
// Get Stream
|
||||||
conn, found := t.getStream(streamID)
|
conn, found := t.streams.Get(streamID)
|
||||||
if !found {
|
if !found {
|
||||||
return fmt.Errorf("stream %s does not exist", streamID)
|
return fmt.Errorf("stream %s does not exist", streamID)
|
||||||
}
|
}
|
||||||
@ -170,10 +176,8 @@ func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) CloseStream(streamID string) error {
|
func (t *Tunnel) CloseStream(streamID string) error {
|
||||||
t.streamsMu.Lock()
|
if conn, ok := t.streams.Get(streamID); ok {
|
||||||
defer t.streamsMu.Unlock()
|
t.streams.Delete(streamID)
|
||||||
if conn, ok := t.streams[streamID]; ok {
|
|
||||||
delete(t.streams, streamID)
|
|
||||||
return conn.Close()
|
return conn.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -184,17 +188,7 @@ func (t *Tunnel) Source() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) sendWS(msg *types.Message) error {
|
func (t *Tunnel) sendWS(msg *types.Message) error {
|
||||||
t.wsMu.Lock()
|
t.mu.Lock()
|
||||||
defer t.wsMu.Unlock()
|
defer t.mu.Unlock()
|
||||||
return t.wsConn.WriteJSON(msg)
|
return t.wsConn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) getStream(streamID string) (io.ReadWriteCloser, bool) {
|
|
||||||
t.streamsMu.Lock()
|
|
||||||
defer t.streamsMu.Unlock()
|
|
||||||
|
|
||||||
if conn, ok := t.streams[streamID]; ok {
|
|
||||||
return conn, true
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
@ -8,7 +8,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Type MessageType `json:"type"`
|
Type MessageType `json:"type"`
|
||||||
StreamID string `json:"stream_id"`
|
StreamID string `json:"stream_id"`
|
||||||
Data []byte `json:"data,omitempty"`
|
SourceAddr string `json:"source_addr"`
|
||||||
|
Data []byte `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user