config & auth

This commit is contained in:
2025-09-20 16:14:10 -04:00
parent 5d9684b27e
commit 2fba07f4b3
6 changed files with 390 additions and 192 deletions

View File

@@ -13,23 +13,26 @@ import (
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"reichard.io/conduit/config"
"reichard.io/conduit/types"
)
type TunnelConnection struct {
*websocket.Conn
name string
streams map[string]chan []byte // StreamID -> data channel
streams map[string]chan []byte
}
type Server struct {
tunnels map[string]*TunnelConnection
upgrader websocket.Upgrader
cfg *config.ServerConfig
mu sync.RWMutex
}
func NewServer() *Server {
func NewServer(cfg *config.ServerConfig) *Server {
return &Server{
cfg: cfg,
tunnels: make(map[string]*TunnelConnection),
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
@@ -39,20 +42,21 @@ func NewServer() *Server {
}
}
func (s *Server) Start(addr string) error {
// Raw TCP listener instead of http.ListenAndServe
listener, err := net.Listen("tcp", addr)
func (s *Server) Start() error {
// Raw TCP Listener - This is necessary so we can conditionally either relay
// the raw TCP connection, or handle conduit control server API requests.
listener, err := net.Listen("tcp", s.cfg.BindAddress)
if err != nil {
return err
}
defer listener.Close()
log.Infof("Conduit server listening on %s", addr)
// Start Listening
log.Infof("conduit server listening on %s", s.cfg.BindAddress)
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Error accepting connection: %v", err)
log.Printf("error accepting connection: %v", err)
continue
}
@@ -76,7 +80,7 @@ func (s *Server) extractSubdomain(peakReader io.Reader) string {
// Extract Subdomain
parts := strings.Split(host, ".")
if len(parts) >= 1 {
if len(parts) > 1 {
return parts[0]
}
@@ -152,9 +156,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
}
}
// peakData limits how much we read as we only need to determine
// the host to figure out whether we should proxy or not.
func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) {
peek := make([]byte, 8192)
n, err := conn.Read(peek)
if err != nil {
@@ -163,13 +165,13 @@ func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
peekedData := peek[:n]
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
return bytes.NewReader(peekedData), combinedReader, nil
}
func (s *Server) handleRawConnection(conn net.Conn) {
defer conn.Close()
// Detect Tunnel
peakReader, allReader, _ := s.peekData(conn)
if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
s.mu.RLock()
@@ -177,25 +179,32 @@ func (s *Server) handleRawConnection(conn net.Conn) {
s.mu.RUnlock()
if exists {
log.Infof("Relaying %s to tunnel", subdomain)
log.Infof("relaying %s to tunnel", subdomain)
s.proxyRawConnection(conn, tunnelConn, allReader)
return
}
return
}
// Otherwise, handle as control server (recreate HTTP request and use net/http)
// Control Endpoints
s.handleAsHTTP(conn, allReader)
}
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
// Create HTTP Request & Writer
w := &connResponseWriter{conn: conn}
r, err := http.ReadRequest(bufio.NewReader(allReader))
if err != nil {
_, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
w.WriteHeader(http.StatusBadRequest)
return
}
// Authorize Control Endpoints
apiKey := r.URL.Query().Get("apiKey")
if apiKey != s.cfg.APIKey {
log.Error("unauthorized client")
w.WriteHeader(http.StatusUnauthorized)
return
}
w := &connResponseWriter{conn: conn}
// Handle Control Endpoints
switch r.URL.Path {
@@ -217,7 +226,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
}
if msg.StreamID == "" {
log.Infof("Tunnel %s missing streamID", tunnel.name)
log.Infof("tunnel %s missing streamID", tunnel.name)
continue
}
@@ -228,7 +237,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
s.mu.RLock()
streamChan, exists := tunnel.streams[msg.StreamID]
if !exists {
log.Infof("Stream %s does not exist", msg.StreamID)
log.Infof("stream %s does not exist", msg.StreamID)
s.mu.RUnlock()
continue
}
@@ -236,7 +245,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
select {
case streamChan <- msg.Data:
case <-time.After(time.Second):
log.Infof("Stream %s channel full, dropping data", msg.StreamID)
log.Warnf("stream %s channel full, dropping data", msg.StreamID)
}
s.mu.RUnlock()
}
@@ -261,7 +270,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
// Upgrade Connection
wsConn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Errorf("WebSocket upgrade failed: %v", err)
log.Errorf("websocket upgrade failed: %v", err)
return
}
@@ -274,7 +283,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
s.mu.Lock()
s.tunnels[tunnelName] = tunnel
s.mu.Unlock()
log.Infof("Tunnel established: %s", tunnelName)
log.Infof("tunnel established: %s", tunnelName)
// Keep connection alive and handle cleanup
defer func() {
@@ -282,7 +291,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
delete(s.tunnels, tunnelName)
s.mu.Unlock()
_ = wsConn.Close()
log.Infof("Tunnel closed: %s", tunnelName)
log.Infof("tunnel closed: %s", tunnelName)
}()
// Handle tunnel messages