feat(tunnel): require explicit target schemes
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@@ -2,7 +2,9 @@ package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"reichard.io/conduit/store"
|
||||
)
|
||||
@@ -21,15 +23,24 @@ type Forwarder interface {
|
||||
}
|
||||
|
||||
func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) {
|
||||
// Only parse as URL for HTTP targets. Bare host:port (e.g., "127.0.0.1:5432")
|
||||
// is not a valid URL and should be treated as a raw TCP target.
|
||||
targetURL, err := url.Parse(target)
|
||||
if err == nil {
|
||||
switch targetURL.Scheme {
|
||||
case "http", "https":
|
||||
return newHTTPForwarder(targetURL, tunnelStore)
|
||||
}
|
||||
if !strings.Contains(target, "://") {
|
||||
return nil, fmt.Errorf("target must include a scheme: tcp://, http://, or https://")
|
||||
}
|
||||
|
||||
return newTCPForwarder(target, tunnelStore), nil
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("target is invalid: %w", err)
|
||||
}
|
||||
if targetURL.Host == "" {
|
||||
return nil, fmt.Errorf("target must include a host")
|
||||
}
|
||||
|
||||
switch targetURL.Scheme {
|
||||
case "http", "https":
|
||||
return newHTTPForwarder(targetURL, tunnelStore)
|
||||
case "tcp":
|
||||
return newTCPForwarder(targetURL.Host, tunnelStore), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported target scheme %q: use tcp://, http://, or https://", targetURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
45
tunnel/forwarder_test.go
Normal file
45
tunnel/forwarder_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"reichard.io/conduit/store"
|
||||
)
|
||||
|
||||
func TestNewForwarderRequiresExplicitScheme(t *testing.T) {
|
||||
_, err := NewForwarder("localhost:8282", store.NewTunnelStore(1))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for target without scheme")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForwarderSupportsExplicitSchemes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
target string
|
||||
forwarderType ForwarderType
|
||||
}{
|
||||
{name: "http", target: "http://localhost:8282", forwarderType: ForwarderHTTP},
|
||||
{name: "https", target: "https://localhost:8282", forwarderType: ForwarderHTTP},
|
||||
{name: "tcp", target: "tcp://localhost:8282", forwarderType: ForwarderTCP},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
forwarder, err := NewForwarder(tt.target, store.NewTunnelStore(1))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if forwarder.Type() != tt.forwarderType {
|
||||
t.Fatalf("expected forwarder type %v, got %v", tt.forwarderType, forwarder.Type())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForwarderRejectsUnsupportedScheme(t *testing.T) {
|
||||
_, err := NewForwarder("udp://localhost:8282", store.NewTunnelStore(1))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unsupported scheme")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user