fix(tunnel): stabilize concurrent stream handling
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@@ -65,7 +65,6 @@ func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, er
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
ctx context.Context
|
||||
name string
|
||||
wsConn *websocket.Conn
|
||||
streams *maps.Map[string, Stream]
|
||||
@@ -78,8 +77,6 @@ func (t *Tunnel) Start(ctx context.Context) {
|
||||
log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||
defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||
|
||||
t.ctx = ctx
|
||||
|
||||
// Start Message Receiver
|
||||
for {
|
||||
msg, err := t.readWSWithContext(ctx)
|
||||
@@ -94,7 +91,7 @@ func (t *Tunnel) Start(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Get Stream
|
||||
stream, err := t.getStream(msg.StreamID, msg.SourceAddr)
|
||||
stream, err := t.getStream(ctx, msg.StreamID, msg.SourceAddr)
|
||||
if err != nil {
|
||||
if msg.Type != types.MessageTypeClose {
|
||||
log.WithError(err).Errorf("failed to get stream %s", msg.StreamID)
|
||||
@@ -151,13 +148,13 @@ func (t *Tunnel) Source() string {
|
||||
return t.wsConn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (t *Tunnel) StartStream(stream Stream, streamID string) error {
|
||||
func (t *Tunnel) StartStream(ctx context.Context, stream Stream, streamID string) error {
|
||||
// Close Stream
|
||||
defer t.closeStream(stream, streamID)
|
||||
|
||||
// Start Stream
|
||||
for {
|
||||
data, err := t.readStreamWithContext(t.ctx, stream)
|
||||
data, err := t.readStreamWithContext(ctx, stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -179,7 +176,7 @@ func (t *Tunnel) closeStream(stream Stream, streamID string) error {
|
||||
return stream.Close()
|
||||
}
|
||||
|
||||
func (t *Tunnel) getStream(streamID, sourceAddress string) (Stream, error) {
|
||||
func (t *Tunnel) getStream(ctx context.Context, streamID, sourceAddress string) (Stream, error) {
|
||||
// Check Existing Stream
|
||||
if stream, found := t.streams.Get(streamID); found {
|
||||
return stream, nil
|
||||
@@ -198,7 +195,7 @@ func (t *Tunnel) getStream(streamID, sourceAddress string) (Stream, error) {
|
||||
if err := t.AddStream(stream, streamID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go t.StartStream(stream, streamID)
|
||||
go t.StartStream(ctx, stream, streamID)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user