chore: better source tracking
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
0333680a2b
commit
20c1388cf4
@ -41,5 +41,5 @@ func NewTunnel(cfg *config.ClientConfig) (*tunnel.Tunnel, error) {
|
||||
return nil, fmt.Errorf("failed to connect: %v", err)
|
||||
}
|
||||
|
||||
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverConn)
|
||||
return tunnel.NewClientTunnel(cfg.TunnelName, cfg.TunnelTarget, serverURL, serverConn)
|
||||
}
|
||||
|
@ -135,24 +135,31 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
}
|
||||
|
||||
// Extract Subdomain
|
||||
subdomain := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
||||
if strings.Count(subdomain, ".") != 0 {
|
||||
tunnelName := strings.TrimSuffix(strings.Replace(r.Host, s.host, "", 1), ".")
|
||||
if strings.Count(tunnelName, ".") != 0 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = fmt.Fprintf(w, "cannot tunnel nested subdomains: %s", r.Host)
|
||||
return
|
||||
}
|
||||
|
||||
// Get True Host
|
||||
remoteHost := conn.RemoteAddr().String()
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
remoteHost = xff
|
||||
}
|
||||
r.RemoteAddr = remoteHost
|
||||
|
||||
// Handle Control Endpoints
|
||||
if subdomain == "" {
|
||||
if tunnelName == "" {
|
||||
s.handleAsHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Tunnels
|
||||
conduitTunnel, exists := s.tunnels.Get(subdomain)
|
||||
conduitTunnel, exists := s.tunnels.Get(tunnelName)
|
||||
if !exists {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", subdomain)
|
||||
_, _ = fmt.Fprintf(w, "unknown tunnel: %s", tunnelName)
|
||||
return
|
||||
}
|
||||
|
||||
@ -165,8 +172,8 @@ func (s *Server) handleRawConnection(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("relaying %s to tunnel", subdomain)
|
||||
_ = conduitTunnel.StartStream(streamID)
|
||||
log.Infof("tunnel %q connection from %s", tunnelName, r.RemoteAddr)
|
||||
_ = conduitTunnel.StartStream(streamID, r.RemoteAddr)
|
||||
}
|
||||
|
||||
func (s *Server) handleAsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@ -215,7 +222,7 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
// Create Tunnel
|
||||
conduitTunnel := tunnel.NewServerTunnel(tunnelName, wsConn)
|
||||
s.tunnels.Set(tunnelName, conduitTunnel)
|
||||
log.Infof("tunnel established: %s", tunnelName)
|
||||
log.Infof("tunnel %q created from %s", tunnelName, r.RemoteAddr)
|
||||
|
||||
// Start Tunnel - This is blocking
|
||||
conduitTunnel.Start()
|
||||
@ -223,5 +230,5 @@ func (s *Server) createTunnel(w http.ResponseWriter, r *http.Request) {
|
||||
// Cleanup Tunnel
|
||||
s.tunnels.Delete(tunnelName)
|
||||
_ = wsConn.Close()
|
||||
log.Infof("tunnel closed: %s", tunnelName)
|
||||
log.Infof("tunnel %q closed from %s", tunnelName, r.RemoteAddr)
|
||||
}
|
||||
|
@ -23,22 +23,28 @@ func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientTunnel(name, target string, wsConn *websocket.Conn) (*Tunnel, error) {
|
||||
func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) {
|
||||
// Get Target URL
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Derive Conduit URL
|
||||
conduitURL := *serverURL
|
||||
conduitURL.Host = name + "." + conduitURL.Host
|
||||
|
||||
// Get Connection Builder
|
||||
var connBuilder ConnBuilder
|
||||
switch targetURL.Scheme {
|
||||
case "http", "https":
|
||||
log.Infof("creating HTTP tunnel: %s -> %s", name, target)
|
||||
log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target)
|
||||
connBuilder, err = HTTPConnectionBuilder(targetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
log.Infof("creating TCP tunnel: %s -> %s", name, target)
|
||||
log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target)
|
||||
connBuilder = func() (conn io.ReadWriteCloser, err error) {
|
||||
return net.Dial("tcp", target)
|
||||
}
|
||||
@ -109,7 +115,7 @@ func (t *Tunnel) initStreamConnection(streamID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
go t.StartStream(streamID)
|
||||
go t.StartStream(streamID, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -121,7 +127,7 @@ func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) StartStream(streamID string) error {
|
||||
func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
|
||||
// Get Stream
|
||||
conn, found := t.streams.Get(streamID)
|
||||
if !found {
|
||||
@ -133,6 +139,7 @@ func (t *Tunnel) StartStream(streamID string) error {
|
||||
_ = t.sendWS(&types.Message{
|
||||
Type: types.MessageTypeClose,
|
||||
StreamID: streamID,
|
||||
SourceAddr: sourceAddr,
|
||||
})
|
||||
|
||||
t.CloseStream(streamID)
|
||||
@ -148,8 +155,9 @@ func (t *Tunnel) StartStream(streamID string) error {
|
||||
|
||||
if err := t.sendWS(&types.Message{
|
||||
Type: types.MessageTypeData,
|
||||
Data: buffer[:n],
|
||||
StreamID: streamID,
|
||||
Data: buffer[:n],
|
||||
SourceAddr: sourceAddr,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -10,5 +10,6 @@ const (
|
||||
type Message struct {
|
||||
Type MessageType `json:"type"`
|
||||
StreamID string `json:"stream_id"`
|
||||
SourceAddr string `json:"source_addr"`
|
||||
Data []byte `json:"data,omitempty"`
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user