clean
This commit is contained in:
parent
2fba07f4b3
commit
08e1191ba3
@ -19,8 +19,13 @@ var serveCmd = &cobra.Command{
|
||||
log.Fatal("failed to get server config:", err)
|
||||
}
|
||||
|
||||
// Create Server
|
||||
srv, err := server.NewServer(cfg)
|
||||
if err != nil {
|
||||
log.Fatal("failed to create server:", err)
|
||||
}
|
||||
|
||||
// Start Server
|
||||
srv := server.NewServer(cfg)
|
||||
if err := srv.Start(); err != nil {
|
||||
log.Fatal("failed to start server:", err)
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
@ -17,13 +19,19 @@ type ConfigDef struct {
|
||||
}
|
||||
|
||||
type BaseConfig struct {
|
||||
ServerAddress string `json:"address" description:"Conduit server address" default:"http://localhost:8080"`
|
||||
ServerAddress string `json:"server" description:"Conduit server address" default:"http://localhost:8080"`
|
||||
APIKey string `json:"api_key" description:"API Key for the conduit API"`
|
||||
}
|
||||
|
||||
func (c *BaseConfig) Validate() error {
|
||||
if c.APIKey == "" {
|
||||
return fmt.Errorf("api_key is required")
|
||||
return errors.New("api_key is required")
|
||||
}
|
||||
if c.ServerAddress == "" {
|
||||
return errors.New("server is required")
|
||||
}
|
||||
if _, err := url.Parse(c.ServerAddress); err != nil {
|
||||
return fmt.Errorf("server is invalid: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -78,7 +86,7 @@ func GetClientConfig(cmdFlags *pflag.FlagSet) (*ClientConfig, error) {
|
||||
|
||||
cfg := &ClientConfig{
|
||||
BaseConfig: BaseConfig{
|
||||
ServerAddress: cfgValues["address"],
|
||||
ServerAddress: cfgValues["server"],
|
||||
APIKey: cfgValues["api_key"],
|
||||
},
|
||||
TunnelName: cfgValues["name"],
|
||||
|
118
server/server.go
118
server/server.go
@ -3,10 +3,12 @@ package server
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -24,22 +26,32 @@ type TunnelConnection struct {
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
tunnels map[string]*TunnelConnection
|
||||
upgrader websocket.Upgrader
|
||||
host string
|
||||
cfg *config.ServerConfig
|
||||
mu sync.RWMutex
|
||||
|
||||
upgrader websocket.Upgrader
|
||||
tunnels map[string]*TunnelConnection
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
||||
serverURL, err := url.Parse(cfg.ServerAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server address: %v", err)
|
||||
} else if serverURL.Host == "" {
|
||||
return nil, errors.New("invalid server address")
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.ServerConfig) *Server {
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
host: serverURL.Host,
|
||||
tunnels: make(map[string]*TunnelConnection),
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
@ -64,29 +76,6 @@ func (s *Server) Start() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) extractSubdomain(peakReader io.Reader) string {
|
||||
// Read Request
|
||||
req, err := http.ReadRequest(bufio.NewReader(peakReader))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// Extract Host
|
||||
host := req.Host
|
||||
if idx := strings.Index(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// Extract Subdomain
|
||||
parts := strings.Split(host, ".")
|
||||
if len(parts) > 1 {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) getStatus(w http.ResponseWriter, _ *http.Request) {
|
||||
s.mu.RLock()
|
||||
count := len(s.tunnels)
|
||||
@ -156,48 +145,61 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) peekData(conn net.Conn) (peekReader io.Reader, allReader io.Reader, err error) {
|
||||
peek := make([]byte, 8192)
|
||||
n, err := conn.Read(peek)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peekedData := peek[:n]
|
||||
combinedReader := io.MultiReader(bytes.NewReader(peekedData), conn)
|
||||
return bytes.NewReader(peekedData), combinedReader, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Detect Tunnel
|
||||
peakReader, allReader, _ := s.peekData(conn)
|
||||
if subdomain := s.extractSubdomain(peakReader); subdomain != "" {
|
||||
s.mu.RLock()
|
||||
tunnelConn, exists := s.tunnels[subdomain]
|
||||
s.mu.RUnlock()
|
||||
// Capture Consumed Data - When determining where to route the request, we
|
||||
// have to read the host headers. This requires reading from the buffer, so
|
||||
// if we later decide to tunnel the TCP connection we need to reconstruct the
|
||||
// data from the buffer.
|
||||
var capturedData bytes.Buffer
|
||||
teeReader := io.TeeReader(conn, &capturedData)
|
||||
bufReader := bufio.NewReader(teeReader)
|
||||
|
||||
if exists {
|
||||
log.Infof("relaying %s to tunnel", subdomain)
|
||||
s.proxyRawConnection(conn, tunnelConn, allReader)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Control Endpoints
|
||||
s.handleAsHTTP(conn, allReader)
|
||||
}
|
||||
|
||||
func (s *Server) handleAsHTTP(conn net.Conn, allReader io.Reader) {
|
||||
// Create HTTP Request & Writer
|
||||
w := &connResponseWriter{conn: conn}
|
||||
r, err := http.ReadRequest(bufio.NewReader(allReader))
|
||||
r, err := http.ReadRequest(bufReader)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
// Validate Host
|
||||
if !strings.Contains(r.Host, s.host) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = fmt.Fprintf(w, "unknown host: %s", r.Host)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract Subdomain
|
||||
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
||||
if strings.Count(subdomain, ".") != 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Control Endpoints
|
||||
if subdomain == "" {
|
||||
s.handleAsHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Tunnels
|
||||
s.mu.RLock()
|
||||
tunnelConn, exists := s.tunnels[subdomain]
|
||||
s.mu.RUnlock()
|
||||
if exists {
|
||||
log.Infof("relaying %s to tunnel", subdomain)
|
||||
|
||||
// Reconstruct Data & Proxy Connection
|
||||
allReader := io.MultiReader(&capturedData, r.Body)
|
||||
s.proxyRawConnection(conn, tunnelConn, allReader)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Authorize Control Endpoints
|
||||
apiKey := r.URL.Query().Get("apiKey")
|
||||
if apiKey != s.cfg.APIKey {
|
||||
|
Loading…
x
Reference in New Issue
Block a user