config & auth
This commit is contained in:
@@ -13,23 +13,26 @@ import (
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/conduit/config"
|
||||
"reichard.io/conduit/types"
|
||||
)
|
||||
|
||||
type TunnelConnection struct {
|
||||
*websocket.Conn
|
||||
name string
|
||||
streams map[string]chan []byte // StreamID -> data channel
|
||||
streams map[string]chan []byte
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
tunnels map[string]*TunnelConnection
|
||||
upgrader websocket.Upgrader
|
||||
cfg *config.ServerConfig
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewServer() *Server {
|
||||
func NewServer(cfg *config.ServerConfig) *Server {
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
tunnels: make(map[string]*TunnelConnection),
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
@@ -39,20 +42,21 @@ func NewServer() *Server {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Start(addr string) error {
|
||||
// Raw TCP listener instead of http.ListenAndServe
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
func (s *Server) Start() error {
|
||||
// Raw TCP Listener - This is necessary so we can conditionally either relay
|
||||
// the raw TCP connection, or handle conduit control server API requests.
|
||||
listener, err := net.Listen("tcp", s.cfg.BindAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
log.Infof("Conduit server listening on %s", addr)
|
||||
|
||||
// Start Listening
|
||||
log.Infof("conduit server listening on %s", s.cfg.BindAddress)
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
log.Printf("error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -76,7 +80,7 @@ func (s *Server) extractSubdomain(peakReader io.Reader) string {
|
||||
|
||||
// Extract Subdomain
|
||||
parts := strings.Split(host, ".")
|
||||
if len(parts) >= 1 {
|
||||
if len(parts) > 1 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
@@ -152,9 +156,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
|
||||
}
|
||||
}
|
||||
|
||||
// peakData limits how much we read as we only need to determine
|
||||
// the host to figure out whether we should proxy or not.
|
||||
func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
|
||||
func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) {
|
||||
peek := make([]byte, 8192)
|
||||
n, err := conn.Read(peek)
|
||||
if err != nil {
|
||||
@@ -163,13 +165,13 @@ func (s *Server) peekData(conn net.Conn) (io.Reader, io.Reader, error) {
|
||||
|
||||
peekedData := peek[:n]
|
||||
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
|
||||
|
||||
return bytes.NewReader(peekedData), combinedReader, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Detect Tunnel
|
||||
peakReader, allReader, _ := s.peekData(conn)
|
||||
if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
|
||||
s.mu.RLock()
|
||||
@@ -177,25 +179,32 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
s.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
log.Infof("Relaying %s to tunnel", subdomain)
|
||||
|
||||
log.Infof("relaying %s to tunnel", subdomain)
|
||||
s.proxyRawConnection(conn, tunnelConn, allReader)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, handle as control server (recreate HTTP request and use net/http)
|
||||
// Control Endpoints
|
||||
s.handleAsHTTP(conn, allReader)
|
||||
}
|
||||
|
||||
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
|
||||
// Create HTTP Request & Writer
|
||||
w := &connResponseWriter{conn: conn}
|
||||
r, err := http.ReadRequest(bufio.NewReader(allReader))
|
||||
if err != nil {
|
||||
_, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Authorize Control Endpoints
|
||||
apiKey := r.URL.Query().Get("apiKey")
|
||||
if apiKey != s.cfg.APIKey {
|
||||
log.Error("unauthorized client")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w := &connResponseWriter{conn: conn}
|
||||
|
||||
// Handle Control Endpoints
|
||||
switch r.URL.Path {
|
||||
@@ -217,7 +226,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
|
||||
}
|
||||
|
||||
if msg.StreamID == "" {
|
||||
log.Infof("Tunnel %s missing streamID", tunnel.name)
|
||||
log.Infof("tunnel %s missing streamID", tunnel.name)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -228,7 +237,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
|
||||
s.mu.RLock()
|
||||
streamChan, exists := tunnel.streams[msg.StreamID]
|
||||
if !exists {
|
||||
log.Infof("Stream %s does not exist", msg.StreamID)
|
||||
log.Infof("stream %s does not exist", msg.StreamID)
|
||||
s.mu.RUnlock()
|
||||
continue
|
||||
}
|
||||
@@ -236,7 +245,7 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
|
||||
select {
|
||||
case streamChan <- msg.Data:
|
||||
case <-time.After(time.Second):
|
||||
log.Infof("Stream %s channel full, dropping data", msg.StreamID)
|
||||
log.Warnf("stream %s channel full, dropping data", msg.StreamID)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
@@ -261,7 +270,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
// Upgrade Connection
|
||||
wsConn, err := s.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Errorf("WebSocket upgrade failed: %v", err)
|
||||
log.Errorf("websocket upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -274,7 +283,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Lock()
|
||||
s.tunnels[tunnelName] = tunnel
|
||||
s.mu.Unlock()
|
||||
log.Infof("Tunnel established: %s", tunnelName)
|
||||
log.Infof("tunnel established: %s", tunnelName)
|
||||
|
||||
// Keep connection alive and handle cleanup
|
||||
defer func() {
|
||||
@@ -282,7 +291,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
delete(s.tunnels, tunnelName)
|
||||
s.mu.Unlock()
|
||||
_ = wsConn.Close()
|
||||
log.Infof("Tunnel closed: %s", tunnelName)
|
||||
log.Infof("tunnel closed: %s", tunnelName)
|
||||
}()
|
||||
|
||||
// Handle tunnel messages
|
||||
|
||||
Reference in New Issue
Block a user