diff --git a/AGENTS.md b/AGENTS.md index 5de3dbb..4db31cc 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,7 +11,7 @@ Conduit is a self-hosted tunneling service (Go, single binary). A **server** (`c make build_local # Run tests -make tests # includes coverage +make tests # includes race detection + coverage # Lint golangci-lint run @@ -59,7 +59,7 @@ pkg/maps/map.go — Generic sync.RWMutex-guarded map ## Testing -E2E tests live in `e2e_test.go` at the project root. They spin up real servers, tunnels, and targets on random ports. +E2E tests live in `e2e_test.go` at the project root. They spin up real servers, tunnels, and targets on random ports. `make tests` runs with `-race` and coverage enabled. ```bash # Run all tests diff --git a/Makefile b/Makefile index 191502d..650f784 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,6 @@ clean: rm -rf ./build tests: - SET_TEST=set_val go test -coverpkg=./... ./... -coverprofile=./cover.out + SET_TEST=set_val go test -race -coverpkg=./... ./... -coverprofile=./cover.out go tool cover -html=./cover.out -o ./cover.html rm ./cover.out diff --git a/server/server.go b/server/server.go index 9be4608..e4e756d 100644 --- a/server/server.go +++ b/server/server.go @@ -9,7 +9,7 @@ import ( "net/http" "net/url" "strings" - "time" + "sync/atomic" "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" @@ -35,6 +35,7 @@ type Server struct { upgrader websocket.Upgrader tunnels *maps.Map[string, *tunnel.Tunnel] + streamID atomic.Uint64 } func NewServer(ctx context.Context, cfg *config.ServerConfig) (*Server, error) { @@ -145,7 +146,7 @@ func (s *Server) handleTunnelRequest(w http.ResponseWriter, r *http.Request, tun reconstructedConn := newReconstructedConn(conn, &reqBuf, bufrw) // Create Stream - streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano()) + streamID := fmt.Sprintf("stream_%d", s.streamID.Add(1)) tunnelStream := tunnel.NewStream(reconstructedConn, r.RemoteAddr, conduitTunnel.Source()) // Add Stream @@ -156,7 +157,7 @@ func (s *Server) handleTunnelRequest(w http.ResponseWriter, r *http.Request, tun } // Start Stream - conduitTunnel.StartStream(tunnelStream, streamID) + conduitTunnel.StartStream(s.ctx, tunnelStream, streamID) } func (s *Server) getInfo(w http.ResponseWriter, _ *http.Request) { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 59a28df..d3b953e 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -65,7 +65,6 @@ func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, er } type Tunnel struct { - ctx context.Context name string wsConn *websocket.Conn streams *maps.Map[string, Stream] @@ -78,8 +77,6 @@ func (t *Tunnel) Start(ctx context.Context) { log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String()) defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String()) - t.ctx = ctx - // Start Message Receiver for { msg, err := t.readWSWithContext(ctx) @@ -94,7 +91,7 @@ func (t *Tunnel) Start(ctx context.Context) { } // Get Stream - stream, err := t.getStream(msg.StreamID, msg.SourceAddr) + stream, err := t.getStream(ctx, msg.StreamID, msg.SourceAddr) if err != nil { if msg.Type != types.MessageTypeClose { log.WithError(err).Errorf("failed to get stream %s", msg.StreamID) @@ -151,13 +148,13 @@ func (t *Tunnel) Source() string { return t.wsConn.RemoteAddr().String() } -func (t *Tunnel) StartStream(stream Stream, streamID string) error { +func (t *Tunnel) StartStream(ctx context.Context, stream Stream, streamID string) error { // Close Stream defer t.closeStream(stream, streamID) // Start Stream for { - data, err := t.readStreamWithContext(t.ctx, stream) + data, err := t.readStreamWithContext(ctx, stream) if err != nil { return err } @@ -179,7 +176,7 @@ func (t *Tunnel) closeStream(stream Stream, streamID string) error { return stream.Close() } -func (t *Tunnel) getStream(streamID, sourceAddress string) (Stream, error) { +func (t *Tunnel) getStream(ctx context.Context, streamID, sourceAddress string) (Stream, error) { // Check Existing Stream if stream, found := t.streams.Get(streamID); found { return stream, nil @@ -198,7 +195,7 @@ func (t *Tunnel) getStream(streamID, sourceAddress string) (Stream, error) { if err := t.AddStream(stream, streamID); err != nil { return nil, err } - go t.StartStream(stream, streamID) + go t.StartStream(ctx, stream, streamID) return stream, nil }