fix(tunnel): stabilize concurrent stream handling
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@@ -11,7 +11,7 @@ Conduit is a self-hosted tunneling service (Go, single binary). A **server** (`c
|
|||||||
make build_local
|
make build_local
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
make tests # includes coverage
|
make tests # includes race detection + coverage
|
||||||
|
|
||||||
# Lint
|
# Lint
|
||||||
golangci-lint run
|
golangci-lint run
|
||||||
@@ -59,7 +59,7 @@ pkg/maps/map.go — Generic sync.RWMutex-guarded map
|
|||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
|
|
||||||
E2E tests live in `e2e_test.go` at the project root. They spin up real servers, tunnels, and targets on random ports.
|
E2E tests live in `e2e_test.go` at the project root. They spin up real servers, tunnels, and targets on random ports. `make tests` runs with `-race` and coverage enabled.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run all tests
|
# Run all tests
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -30,6 +30,6 @@ clean:
|
|||||||
rm -rf ./build
|
rm -rf ./build
|
||||||
|
|
||||||
tests:
|
tests:
|
||||||
SET_TEST=set_val go test -coverpkg=./... ./... -coverprofile=./cover.out
|
SET_TEST=set_val go test -race -coverpkg=./... ./... -coverprofile=./cover.out
|
||||||
go tool cover -html=./cover.out -o ./cover.html
|
go tool cover -html=./cover.out -o ./cover.html
|
||||||
rm ./cover.out
|
rm ./cover.out
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -35,6 +35,7 @@ type Server struct {
|
|||||||
|
|
||||||
upgrader websocket.Upgrader
|
upgrader websocket.Upgrader
|
||||||
tunnels *maps.Map[string, *tunnel.Tunnel]
|
tunnels *maps.Map[string, *tunnel.Tunnel]
|
||||||
|
streamID atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) {
|
func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) {
|
||||||
@@ -145,7 +146,7 @@ func (s *Server) handleTunnelRequest(w http.ResponseWriter, r *http.Request, tun
|
|||||||
reconstructedConn := newReconstructedConn(conn, &reqBuf, bufrw)
|
reconstructedConn := newReconstructedConn(conn, &reqBuf, bufrw)
|
||||||
|
|
||||||
// Create Stream
|
// Create Stream
|
||||||
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
streamID := fmt.Sprintf("stream_%d", s.streamID.Add(1))
|
||||||
tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr, conduitTunnel.Source())
|
tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr, conduitTunnel.Source())
|
||||||
|
|
||||||
// Add Stream
|
// Add Stream
|
||||||
@@ -156,7 +157,7 @@ func (s *Server) handleTunnelRequest(w http.ResponseWriter, r *http.Request, tun
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start Stream
|
// Start Stream
|
||||||
conduitTunnel.StartStream(tunnelStream, streamID)
|
conduitTunnel.StartStream(s.ctx, tunnelStream, streamID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Tunnel struct {
|
type Tunnel struct {
|
||||||
ctx context.Context
|
|
||||||
name string
|
name string
|
||||||
wsConn *websocket.Conn
|
wsConn *websocket.Conn
|
||||||
streams *maps.Map[string, Stream]
|
streams *maps.Map[string, Stream]
|
||||||
@@ -78,8 +77,6 @@ func (t *Tunnel) Start(ctx context.Context) {
|
|||||||
log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||||
defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||||
|
|
||||||
t.ctx = ctx
|
|
||||||
|
|
||||||
// Start Message Receiver
|
// Start Message Receiver
|
||||||
for {
|
for {
|
||||||
msg, err := t.readWSWithContext(ctx)
|
msg, err := t.readWSWithContext(ctx)
|
||||||
@@ -94,7 +91,7 @@ func (t *Tunnel) Start(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get Stream
|
// Get Stream
|
||||||
stream, err := t.getStream(msg.StreamID, msg.SourceAddr)
|
stream, err := t.getStream(ctx, msg.StreamID, msg.SourceAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if msg.Type != types.MessageTypeClose {
|
if msg.Type != types.MessageTypeClose {
|
||||||
log.WithError(err).Errorf("failed to get stream %s", msg.StreamID)
|
log.WithError(err).Errorf("failed to get stream %s", msg.StreamID)
|
||||||
@@ -151,13 +148,13 @@ func (t *Tunnel) Source() string {
|
|||||||
return t.wsConn.RemoteAddr().String()
|
return t.wsConn.RemoteAddr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) StartStream(stream Stream, streamID string) error {
|
func (t *Tunnel) StartStream(ctx context.Context, stream Stream, streamID string) error {
|
||||||
// Close Stream
|
// Close Stream
|
||||||
defer t.closeStream(stream, streamID)
|
defer t.closeStream(stream, streamID)
|
||||||
|
|
||||||
// Start Stream
|
// Start Stream
|
||||||
for {
|
for {
|
||||||
data, err := t.readStreamWithContext(t.ctx, stream)
|
data, err := t.readStreamWithContext(ctx, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -179,7 +176,7 @@ func (t *Tunnel) closeStream(stream Stream, streamID string) error {
|
|||||||
return stream.Close()
|
return stream.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tunnel) getStream(streamID, sourceAddress string) (Stream, error) {
|
func (t *Tunnel) getStream(ctx context.Context, streamID, sourceAddress string) (Stream, error) {
|
||||||
// Check Existing Stream
|
// Check Existing Stream
|
||||||
if stream, found := t.streams.Get(streamID); found {
|
if stream, found := t.streams.Get(streamID); found {
|
||||||
return stream, nil
|
return stream, nil
|
||||||
@@ -198,7 +195,7 @@ func (t *Tunnel) getStream(streamID, sourceAddress string) (Stream, error) {
|
|||||||
if err := t.AddStream(stream, streamID); err != nil {
|
if err := t.AddStream(stream, streamID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
go t.StartStream(stream, streamID)
|
go t.StartStream(ctx, stream, streamID)
|
||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user