chore: move to sync map
This commit is contained in:
parent
de23b3e815
commit
0333680a2b
51
pkg/maps/map.go
Normal file
51
pkg/maps/map.go
Normal file
@ -0,0 +1,51 @@
|
||||
package maps
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Map[K comparable, V any] struct {
|
||||
items map[K]V
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func New[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{items: make(map[K]V)}
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
v, ok := m.items[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Set(key K, value V) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.items[key] = value
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Delete(key K) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.items, key)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) HasKey(key K) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, ok := m.items[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Entries() iter.Seq2[K, V] {
|
||||
return func(yield func(K, V) bool) {
|
||||
for k, v := range m.items {
|
||||
if !yield(k, v) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -11,12 +11,12 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/conduit/config"
|
||||
"reichard.io/conduit/pkg/maps"
|
||||
"reichard.io/conduit/tunnel"
|
||||
)
|
||||
|
||||
@ -33,10 +33,9 @@ type TunnelInfo struct {
|
||||
type Server struct {
|
||||
host string
|
||||
cfg *config.ServerConfig
|
||||
mu sync.RWMutex
|
||||
|
||||
upgrader websocket.Upgrader
|
||||
tunnels map[string]*tunnel.Tunnel
|
||||
tunnels *maps.Map[string, *tunnel.Tunnel]
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
||||
@ -50,7 +49,7 @@ func NewServer(cfg *config.ServerConfig) (*Server, error) {
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
host: serverURL.Host,
|
||||
tunnels: make(map[string]*tunnel.Tunnel),
|
||||
tunnels: maps.New[string, *tunnel.Tunnel](),
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
@ -84,14 +83,12 @@ func (s *Server) Start() error {
|
||||
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
||||
// Get Tunnels
|
||||
var allTunnels []TunnelInfo
|
||||
s.mu.RLock()
|
||||
for t, c := range s.tunnels {
|
||||
for t, c := range s.tunnels.Entries() {
|
||||
allTunnels = append(allTunnels, TunnelInfo{
|
||||
Name: t,
|
||||
Target: c.Source(),
|
||||
})
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// Create Response
|
||||
d, err := json.MarshalIndent(InfoResponse{
|
||||
@ -152,9 +149,7 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
}
|
||||
|
||||
// Handle Tunnels
|
||||
s.mu.RLock()
|
||||
conduitTunnel, exists := s.tunnels[subdomain]
|
||||
s.mu.RUnlock()
|
||||
conduitTunnel, exists := s.tunnels.Get(subdomain)
|
||||
if !exists {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain)
|
||||
@ -204,7 +199,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Validate Unique
|
||||
if _, exists := s.tunnels[tunnelName]; exists {
|
||||
if _, exists := s.tunnels.Get(tunnelName); exists {
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
_, _ = w.Write([]byte("Tunnel already registered"))
|
||||
return
|
||||
@ -219,18 +214,14 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Create Tunnel
|
||||
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
|
||||
s.mu.Lock()
|
||||
s.tunnels[tunnelName] = conduitTunnel
|
||||
s.mu.Unlock()
|
||||
s.tunnels.Set(tunnelName, conduitTunnel)
|
||||
log.Infof("tunnel established: %s", tunnelName)
|
||||
|
||||
// Start Tunnel - This is blocking
|
||||
conduitTunnel.Start()
|
||||
|
||||
// Cleanup Tunnel
|
||||
s.mu.Lock()
|
||||
delete(s.tunnels, tunnelName)
|
||||
s.mu.Unlock()
|
||||
s.tunnels.Delete(tunnelName)
|
||||
_ = wsConn.Close()
|
||||
log.Infof("tunnel closed: %s", tunnelName)
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/conduit/pkg/maps"
|
||||
"reichard.io/conduit/types"
|
||||
)
|
||||
|
||||
@ -17,8 +18,8 @@ type ConnBuilder func() (conn io.ReadWriteCloser, err error)
|
||||
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
||||
return &Tunnel{
|
||||
name: name,
|
||||
streams: maps.New[string, io.ReadWriteCloser](),
|
||||
wsConn: wsConn,
|
||||
streams: make(map[string]io.ReadWriteCloser),
|
||||
}
|
||||
}
|
||||
|
||||
@ -46,7 +47,7 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro
|
||||
return &Tunnel{
|
||||
name: name,
|
||||
wsConn: wsConn,
|
||||
streams: make(map[string]io.ReadWriteCloser),
|
||||
streams: maps.New[string, io.ReadWriteCloser](),
|
||||
connBuilder: connBuilder,
|
||||
}, nil
|
||||
}
|
||||
@ -54,10 +55,10 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro
|
||||
type Tunnel struct {
|
||||
name string
|
||||
wsConn *websocket.Conn
|
||||
streams map[string]io.ReadWriteCloser
|
||||
streams *maps.Map[string, io.ReadWriteCloser]
|
||||
connBuilder ConnBuilder
|
||||
|
||||
wsMu, streamsMu sync.Mutex
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *Tunnel) Start() {
|
||||
@ -95,7 +96,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, found := t.getStream(streamID); found {
|
||||
if _, found := t.streams.Get(streamID); found {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -113,19 +114,16 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
|
||||
}
|
||||
|
||||
func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
|
||||
t.streamsMu.Lock()
|
||||
defer t.streamsMu.Unlock()
|
||||
|
||||
if _, found := t.streams[streamID]; found {
|
||||
if t.streams.HasKey(streamID) {
|
||||
return fmt.Errorf("stream %s already exists", streamID)
|
||||
}
|
||||
t.streams[streamID] = conn
|
||||
t.streams.Set(streamID, conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) StartStream(streamID string) error {
|
||||
// Get Stream
|
||||
conn, found := t.getStream(streamID)
|
||||
conn, found := t.streams.Get(streamID)
|
||||
if !found {
|
||||
return fmt.Errorf("stream %s does not exist", streamID)
|
||||
}
|
||||
@ -160,7 +158,7 @@ func (t *Tunnel) StartStream(streamID string) error {
|
||||
|
||||
func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
||||
// Get Stream
|
||||
conn, found := t.getStream(streamID)
|
||||
conn, found := t.streams.Get(streamID)
|
||||
if !found {
|
||||
return fmt.Errorf("stream %s does not exist", streamID)
|
||||
}
|
||||
@ -170,10 +168,8 @@ func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
||||
}
|
||||
|
||||
func (t *Tunnel) CloseStream(streamID string) error {
|
||||
t.streamsMu.Lock()
|
||||
defer t.streamsMu.Unlock()
|
||||
if conn, ok := t.streams[streamID]; ok {
|
||||
delete(t.streams, streamID)
|
||||
if conn, ok := t.streams.Get(streamID); ok {
|
||||
t.streams.Delete(streamID)
|
||||
return conn.Close()
|
||||
}
|
||||
return nil
|
||||
@ -184,17 +180,7 @@ func (t *Tunnel) Source() string {
|
||||
}
|
||||
|
||||
func (t *Tunnel) sendWS(msg *types.Message) error {
|
||||
t.wsMu.Lock()
|
||||
defer t.wsMu.Unlock()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.wsConn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
func (t *Tunnel) getStream(streamID string) (io.ReadWriteCloser, bool) {
|
||||
t.streamsMu.Lock()
|
||||
defer t.streamsMu.Unlock()
|
||||
|
||||
if conn, ok := t.streams[streamID]; ok {
|
||||
return conn, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user