fix(tunnel): stabilize concurrent stream handling
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@@ -11,7 +11,7 @@ Conduit is a self-hosted tunneling service (Go, single binary). A **server** (`c
|
||||
make build_local
|
||||
|
||||
# Run tests
|
||||
make tests # includes coverage
|
||||
make tests # includes race detection + coverage
|
||||
|
||||
# Lint
|
||||
golangci-lint run
|
||||
@@ -59,7 +59,7 @@ pkg/maps/map.go — Generic sync.RWMutex-guarded map
|
||||
|
||||
## Testing
|
||||
|
||||
E2E tests live in `e2e_test.go` at the project root. They spin up real servers, tunnels, and targets on random ports.
|
||||
E2E tests live in `e2e_test.go` at the project root. They spin up real servers, tunnels, and targets on random ports. `make tests` runs with `-race` and coverage enabled.
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
|
||||
2
Makefile
2
Makefile
@@ -30,6 +30,6 @@ clean:
|
||||
rm -rf ./build
|
||||
|
||||
tests:
|
||||
SET_TEST=set_val go test -coverpkg=./... ./... -coverprofile=./cover.out
|
||||
SET_TEST=set_val go test -race -coverpkg=./... ./... -coverprofile=./cover.out
|
||||
go tool cover -html=./cover.out -o ./cover.html
|
||||
rm ./cover.out
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -35,6 +35,7 @@ type Server struct {
|
||||
|
||||
upgrader websocket.Upgrader
|
||||
tunnels *maps.Map[string, *tunnel.Tunnel]
|
||||
streamID atomic.Uint64
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) {
|
||||
@@ -145,7 +146,7 @@ func (s *Server) handleTunnelRequest(w http.ResponseWriter, r *http.Request, tun
|
||||
reconstructedConn := newReconstructedConn(conn, &reqBuf, bufrw)
|
||||
|
||||
// Create Stream
|
||||
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
||||
streamID := fmt.Sprintf("stream_%d", s.streamID.Add(1))
|
||||
tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr, conduitTunnel.Source())
|
||||
|
||||
// Add Stream
|
||||
@@ -156,7 +157,7 @@ func (s *Server) handleTunnelRequest(w http.ResponseWriter, r *http.Request, tun
|
||||
}
|
||||
|
||||
// Start Stream
|
||||
conduitTunnel.StartStream(tunnelStream, streamID)
|
||||
conduitTunnel.StartStream(s.ctx, tunnelStream, streamID)
|
||||
}
|
||||
|
||||
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
||||
|
||||
@@ -65,7 +65,6 @@ func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, er
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
ctx context.Context
|
||||
name string
|
||||
wsConn *websocket.Conn
|
||||
streams *maps.Map[string, Stream]
|
||||
@@ -78,8 +77,6 @@ func (t *Tunnel) Start(ctx context.Context) {
|
||||
log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||
defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||
|
||||
t.ctx = ctx
|
||||
|
||||
// Start Message Receiver
|
||||
for {
|
||||
msg, err := t.readWSWithContext(ctx)
|
||||
@@ -94,7 +91,7 @@ func (t *Tunnel) Start(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Get Stream
|
||||
stream, err := t.getStream(msg.StreamID, msg.SourceAddr)
|
||||
stream, err := t.getStream(ctx, msg.StreamID, msg.SourceAddr)
|
||||
if err != nil {
|
||||
if msg.Type != types.MessageTypeClose {
|
||||
log.WithError(err).Errorf("failed to get stream %s", msg.StreamID)
|
||||
@@ -151,13 +148,13 @@ func (t *Tunnel) Source() string {
|
||||
return t.wsConn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (t *Tunnel) StartStream(stream Stream, streamID string) error {
|
||||
func (t *Tunnel) StartStream(ctx context.Context, stream Stream, streamID string) error {
|
||||
// Close Stream
|
||||
defer t.closeStream(stream, streamID)
|
||||
|
||||
// Start Stream
|
||||
for {
|
||||
data, err := t.readStreamWithContext(t.ctx, stream)
|
||||
data, err := t.readStreamWithContext(ctx, stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -179,7 +176,7 @@ func (t *Tunnel) closeStream(stream Stream, streamID string) error {
|
||||
return stream.Close()
|
||||
}
|
||||
|
||||
func (t *Tunnel) getStream(streamID, sourceAddress string) (Stream, error) {
|
||||
func (t *Tunnel) getStream(ctx context.Context, streamID, sourceAddress string) (Stream, error) {
|
||||
// Check Existing Stream
|
||||
if stream, found := t.streams.Get(streamID); found {
|
||||
return stream, nil
|
||||
@@ -198,7 +195,7 @@ func (t *Tunnel) getStream(streamID, sourceAddress string) (Stream, error) {
|
||||
if err := t.AddStream(stream, streamID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go t.StartStream(stream, streamID)
|
||||
go t.StartStream(ctx, stream, streamID)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user