chore: move to sync map
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/conduit/pkg/maps"
|
||||
"reichard.io/conduit/types"
|
||||
)
|
||||
|
||||
@@ -17,8 +18,8 @@ type ConnBuilder func() (conn io.ReadWriteCloser, err error)
|
||||
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
||||
return &Tunnel{
|
||||
name: name,
|
||||
streams: maps.New[string, io.ReadWriteCloser](),
|
||||
wsConn: wsConn,
|
||||
streams: make(map[string]io.ReadWriteCloser),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +47,7 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro
|
||||
return &Tunnel{
|
||||
name: name,
|
||||
wsConn: wsConn,
|
||||
streams: make(map[string]io.ReadWriteCloser),
|
||||
streams: maps.New[string, io.ReadWriteCloser](),
|
||||
connBuilder: connBuilder,
|
||||
}, nil
|
||||
}
|
||||
@@ -54,10 +55,10 @@ func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, erro
|
||||
type Tunnel struct {
|
||||
name string
|
||||
wsConn *websocket.Conn
|
||||
streams map[string]io.ReadWriteCloser
|
||||
streams *maps.Map[string, io.ReadWriteCloser]
|
||||
connBuilder ConnBuilder
|
||||
|
||||
wsMu, streamsMu sync.Mutex
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *Tunnel) Start() {
|
||||
@@ -95,7 +96,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, found := t.getStream(streamID); found {
|
||||
if _, found := t.streams.Get(streamID); found {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -113,19 +114,16 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
|
||||
}
|
||||
|
||||
func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
|
||||
t.streamsMu.Lock()
|
||||
defer t.streamsMu.Unlock()
|
||||
|
||||
if _, found := t.streams[streamID]; found {
|
||||
if t.streams.HasKey(streamID) {
|
||||
return fmt.Errorf("stream %s already exists", streamID)
|
||||
}
|
||||
t.streams[streamID] = conn
|
||||
t.streams.Set(streamID, conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) StartStream(streamID string) error {
|
||||
// Get Stream
|
||||
conn, found := t.getStream(streamID)
|
||||
conn, found := t.streams.Get(streamID)
|
||||
if !found {
|
||||
return fmt.Errorf("stream %s does not exist", streamID)
|
||||
}
|
||||
@@ -160,7 +158,7 @@ func (t *Tunnel) StartStream(streamID string) error {
|
||||
|
||||
func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
||||
// Get Stream
|
||||
conn, found := t.getStream(streamID)
|
||||
conn, found := t.streams.Get(streamID)
|
||||
if !found {
|
||||
return fmt.Errorf("stream %s does not exist", streamID)
|
||||
}
|
||||
@@ -170,10 +168,8 @@ func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
||||
}
|
||||
|
||||
func (t *Tunnel) CloseStream(streamID string) error {
|
||||
t.streamsMu.Lock()
|
||||
defer t.streamsMu.Unlock()
|
||||
if conn, ok := t.streams[streamID]; ok {
|
||||
delete(t.streams, streamID)
|
||||
if conn, ok := t.streams.Get(streamID); ok {
|
||||
t.streams.Delete(streamID)
|
||||
return conn.Close()
|
||||
}
|
||||
return nil
|
||||
@@ -184,17 +180,7 @@ func (t *Tunnel) Source() string {
|
||||
}
|
||||
|
||||
func (t *Tunnel) sendWS(msg *types.Message) error {
|
||||
t.wsMu.Lock()
|
||||
defer t.wsMu.Unlock()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.wsConn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
func (t *Tunnel) getStream(streamID string) (io.ReadWriteCloser, bool) {
|
||||
t.streamsMu.Lock()
|
||||
defer t.streamsMu.Unlock()
|
||||
|
||||
if conn, ok := t.streams[streamID]; ok {
|
||||
return conn, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user