chore: move to sync map

This commit is contained in:
2025-09-23 09:04:06 -04:00
parent de23b3e815
commit 0333680a2b
3 changed files with 73 additions and 45 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"reichard.io/conduit/pkg/maps"
"reichard.io/conduit/types"
)
@@ -17,8 +18,8 @@ type ConnBuilder func() (conn io.ReadWriteCloser, err error)
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
return &Tunnel{
name: name,
streams: maps.New[string, io.ReadWriteCloser](),
wsConn: wsConn,
streams: make(map[string]io.ReadWriteCloser),
}
}
@@ -46,7 +47,7 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro
return &Tunnel{
name: name,
wsConn: wsConn,
streams: make(map[string]io.ReadWriteCloser),
streams: maps.New[string, io.ReadWriteCloser](),
connBuilder: connBuilder,
}, nil
}
@@ -54,10 +55,10 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro
type Tunnel struct {
name string
wsConn *websocket.Conn
streams map[string]io.ReadWriteCloser
streams *maps.Map[string, io.ReadWriteCloser]
connBuilder ConnBuilder
wsMu, streamsMu sync.Mutex
mu sync.Mutex
}
func (t *Tunnel) Start() {
@@ -95,7 +96,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
return nil
}
if _, found := t.getStream(streamID); found {
if _, found := t.streams.Get(streamID); found {
return nil
}
@@ -113,19 +114,16 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
}
func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
t.streamsMu.Lock()
defer t.streamsMu.Unlock()
if _, found := t.streams[streamID]; found {
if t.streams.HasKey(streamID) {
return fmt.Errorf("stream %s already exists", streamID)
}
t.streams[streamID] = conn
t.streams.Set(streamID, conn)
return nil
}
func (t *Tunnel) StartStream(streamID string) error {
// Get Stream
conn, found := t.getStream(streamID)
conn, found := t.streams.Get(streamID)
if !found {
return fmt.Errorf("stream %s does not exist", streamID)
}
@@ -160,7 +158,7 @@ func (t *Tunnel) StartStream(streamID string) error {
func (t *Tunnel) WriteStream(streamID string, data []byte) error {
// Get Stream
conn, found := t.getStream(streamID)
conn, found := t.streams.Get(streamID)
if !found {
return fmt.Errorf("stream %s does not exist", streamID)
}
@@ -170,10 +168,8 @@ func (t *Tunnel) WriteStream(streamID string, data []byte) error {
}
func (t *Tunnel) CloseStream(streamID string) error {
t.streamsMu.Lock()
defer t.streamsMu.Unlock()
if conn, ok := t.streams[streamID]; ok {
delete(t.streams, streamID)
if conn, ok := t.streams.Get(streamID); ok {
t.streams.Delete(streamID)
return conn.Close()
}
return nil
@@ -184,17 +180,7 @@ func (t *Tunnel) Source() string {
}
func (t *Tunnel) sendWS(msg *types.Message) error {
t.wsMu.Lock()
defer t.wsMu.Unlock()
t.mu.Lock()
defer t.mu.Unlock()
return t.wsConn.WriteJSON(msg)
}
func (t *Tunnel) getStream(streamID string) (io.ReadWriteCloser, bool) {
t.streamsMu.Lock()
defer t.streamsMu.Unlock()
if conn, ok := t.streams[streamID]; ok {
return conn, true
}
return nil, false
}