diff --git a/pkg/maps/map.go b/pkg/maps/map.go new file mode 100644 index 0000000..3cbedb7 --- /dev/null +++ b/pkg/maps/map.go @@ -0,0 +1,51 @@ +package maps + +import ( + "iter" + "sync" +) + +type Map[K comparable, V any] struct { + items map[K]V + mu sync.RWMutex +} + +func New[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{items: make(map[K]V)} +} + +func (m *Map[K, V]) Get(key K) (V, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + v, ok := m.items[key] + return v, ok +} + +func (m *Map[K, V]) Set(key K, value V) { + m.mu.Lock() + defer m.mu.Unlock() + m.items[key] = value +} + +func (m *Map[K, V]) Delete(key K) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.items, key) +} + +func (m *Map[K, V]) HasKey(key K) bool { + m.mu.RLock() + defer m.mu.RUnlock() + _, ok := m.items[key] + return ok +} + +func (m *Map[K, V]) Entries() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + for k, v := range m.items { + if !yield(k, v) { + return + } + } + } +} diff --git a/server/server.go b/server/server.go index 85412b2..57abd20 100644 --- a/server/server.go +++ b/server/server.go @@ -11,12 +11,12 @@ import ( "net/http" "net/url" "strings" - "sync" "time" "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "reichard.io/conduit/config" + "reichard.io/conduit/pkg/maps" "reichard.io/conduit/tunnel" ) @@ -33,10 +33,9 @@ type TunnelInfo struct { type Server struct { host string cfg *config.ServerConfig - mu sync.RWMutex upgrader websocket.Upgrader - tunnels map[string]*tunnel.Tunnel + tunnels *maps.Map[string, *tunnel.Tunnel] } func NewServer(cfg *config.ServerConfig) (*Server, error) { @@ -50,7 +49,7 @@ func NewServer(cfg *config.ServerConfig) (*Server, error) { return &Server{ cfg: cfg, host: serverURL.Host, - tunnels: make(map[string]*tunnel.Tunnel), + tunnels: maps.New[string, *tunnel.Tunnel](), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -84,14 +83,12 @@ func (s *Server) Start() error { func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) { // Get Tunnels var allTunnels []TunnelInfo - s.mu.RLock() - for t, c := range s.tunnels { + for t, c := range s.tunnels.Entries() { allTunnels = append(allTunnels, TunnelInfo{ Name: t, Target: c.Source(), }) } - s.mu.RUnlock() // Create Response d, err := json.MarshalIndent(InfoResponse{ @@ -152,9 +149,7 @@ func (s *Server) handleRawConnection(conn net.Conn) { } // Handle Tunnels - s.mu.RLock() - conduitTunnel, exists := s.tunnels[subdomain] - s.mu.RUnlock() + conduitTunnel, exists := s.tunnels.Get(subdomain) if !exists { w.WriteHeader(http.StatusNotFound) _, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain) @@ -204,7 +199,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { } // Validate Unique - if _, exists := s.tunnels[tunnelName]; exists { + if _, exists := s.tunnels.Get(tunnelName); exists { w.WriteHeader(http.StatusConflict) _, _ = w.Write([]byte("Tunnel already registered")) return @@ -219,18 +214,14 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) { // Create Tunnel conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn) - s.mu.Lock() - s.tunnels[tunnelName] = conduitTunnel - s.mu.Unlock() + s.tunnels.Set(tunnelName, conduitTunnel) log.Infof("tunnel established: %s", tunnelName) // Start Tunnel - This is blocking conduitTunnel.Start() // Cleanup Tunnel - s.mu.Lock() - delete(s.tunnels, tunnelName) - s.mu.Unlock() + s.tunnels.Delete(tunnelName) _ = wsConn.Close() log.Infof("tunnel closed: %s", tunnelName) } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 19d2893..829d48b 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -9,6 +9,7 @@ import ( "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" + "reichard.io/conduit/pkg/maps" "reichard.io/conduit/types" ) @@ -17,8 +18,8 @@ type ConnBuilder func() (conn io.ReadWriteCloser, err error) func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel { return &Tunnel{ name: name, + streams: maps.New[string, io.ReadWriteCloser](), wsConn: wsConn, - streams: make(map[string]io.ReadWriteCloser), } } @@ -46,7 +47,7 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro return &Tunnel{ name: name, wsConn: wsConn, - streams: make(map[string]io.ReadWriteCloser), + streams: maps.New[string, io.ReadWriteCloser](), connBuilder: connBuilder, }, nil } @@ -54,10 +55,10 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro type Tunnel struct { name string wsConn *websocket.Conn - streams map[string]io.ReadWriteCloser + streams *maps.Map[string, io.ReadWriteCloser] connBuilder ConnBuilder - wsMu, streamsMu sync.Mutex + mu sync.Mutex } func (t *Tunnel) Start() { @@ -95,7 +96,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error { return nil } - if _, found := t.getStream(streamID); found { + if _, found := t.streams.Get(streamID); found { return nil } @@ -113,19 +114,16 @@ func (t *Tunnel) initStreamConnection(streamID string) error { } func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error { - t.streamsMu.Lock() - defer t.streamsMu.Unlock() - - if _, found := t.streams[streamID]; found { + if t.streams.HasKey(streamID) { return fmt.Errorf("stream %s already exists", streamID) } - t.streams[streamID] = conn + t.streams.Set(streamID, conn) return nil } func (t *Tunnel) StartStream(streamID string) error { // Get Stream - conn, found := t.getStream(streamID) + conn, found := t.streams.Get(streamID) if !found { return fmt.Errorf("stream %s does not exist", streamID) } @@ -160,7 +158,7 @@ func (t *Tunnel) StartStream(streamID string) error { func (t *Tunnel) WriteStream(streamID string, data []byte) error { // Get Stream - conn, found := t.getStream(streamID) + conn, found := t.streams.Get(streamID) if !found { return fmt.Errorf("stream %s does not exist", streamID) } @@ -170,10 +168,8 @@ func (t *Tunnel) WriteStream(streamID string, data []byte) error { } func (t *Tunnel) CloseStream(streamID string) error { - t.streamsMu.Lock() - defer t.streamsMu.Unlock() - if conn, ok := t.streams[streamID]; ok { - delete(t.streams, streamID) + if conn, ok := t.streams.Get(streamID); ok { + t.streams.Delete(streamID) return conn.Close() } return nil @@ -184,17 +180,7 @@ func (t *Tunnel) Source() string { } func (t *Tunnel) sendWS(msg *types.Message) error { - t.wsMu.Lock() - defer t.wsMu.Unlock() + t.mu.Lock() + defer t.mu.Unlock() return t.wsConn.WriteJSON(msg) } - -func (t *Tunnel) getStream(streamID string) (io.ReadWriteCloser, bool) { - t.streamsMu.Lock() - defer t.streamsMu.Unlock() - - if conn, ok := t.streams[streamID]; ok { - return conn, true - } - return nil, false -}