clean
This commit is contained in:
parent
7f8fb011ce
commit
4ba4fe381f
19
cmd/link.go
19
cmd/link.go
@ -9,14 +9,9 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"reichard.io/conduit/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
StreamID string `json:"stream_id"`
|
|
||||||
Data []byte `json:"data,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var serverAddr string
|
var serverAddr string
|
||||||
|
|
||||||
var linkCmd = &cobra.Command{
|
var linkCmd = &cobra.Command{
|
||||||
@ -69,7 +64,7 @@ func NewTunnel(tunnelName, tunnelTarget, serverAddress string) (*Tunnel, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect Server WS
|
// Connect Server WS
|
||||||
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?vhost=%s", wsScheme, serverURL.Host, tunnelName)
|
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s", wsScheme, serverURL.Host, tunnelName)
|
||||||
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect: %v", err)
|
return nil, fmt.Errorf("failed to connect: %v", err)
|
||||||
@ -100,7 +95,7 @@ func (t *Tunnel) Start() error {
|
|||||||
// Handle Messages
|
// Handle Messages
|
||||||
for {
|
for {
|
||||||
// Read Message
|
// Read Message
|
||||||
var msg TunnelMessage
|
var msg types.Message
|
||||||
err := t.serverConn.ReadJSON(&msg)
|
err := t.serverConn.ReadJSON(&msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error reading from tunnel: %v", err)
|
log.Errorf("Error reading from tunnel: %v", err)
|
||||||
@ -108,7 +103,7 @@ func (t *Tunnel) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case "data":
|
case types.MessageTypeData:
|
||||||
localConn, err := t.getLocalConn(msg.StreamID)
|
localConn, err := t.getLocalConn(msg.StreamID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get local connection: %v", err)
|
log.Errorf("Failed to get local connection: %v", err)
|
||||||
@ -124,7 +119,7 @@ func (t *Tunnel) Start() error {
|
|||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
case "close":
|
case types.MessageTypeClose:
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
if localConn, exists := t.localConns[msg.StreamID]; exists {
|
if localConn, exists := t.localConns[msg.StreamID]; exists {
|
||||||
localConn.Close()
|
localConn.Close()
|
||||||
@ -176,8 +171,8 @@ func (t *Tunnel) startResponseRelay(streamID string, lConn net.Conn) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
response := TunnelMessage{
|
response := types.Message{
|
||||||
Type: "data",
|
Type: types.MessageTypeData,
|
||||||
StreamID: streamID,
|
StreamID: streamID,
|
||||||
Data: buffer[:n],
|
Data: buffer[:n],
|
||||||
}
|
}
|
||||||
|
1
config/config.go
Normal file
1
config/config.go
Normal file
@ -0,0 +1 @@
|
|||||||
|
package config
|
@ -12,20 +12,15 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"reichard.io/conduit/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelConnection struct {
|
type TunnelConnection struct {
|
||||||
*websocket.Conn
|
*websocket.Conn
|
||||||
vhost string
|
name string
|
||||||
streams map[string]chan []byte // StreamID -> data channel
|
streams map[string]chan []byte // StreamID -> data channel
|
||||||
}
|
}
|
||||||
|
|
||||||
type TunnelMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
StreamID string `json:"stream_id"`
|
|
||||||
Data []byte `json:"data,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
tunnels map[string]*TunnelConnection
|
tunnels map[string]*TunnelConnection
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@ -77,7 +72,7 @@ func (s *Server) extractSubdomain(host string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleStatus(conn net.Conn) {
|
func (s *Server) getStatus(conn net.Conn) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
count := len(s.tunnels)
|
count := len(s.tunnels)
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
@ -113,7 +108,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
|
|||||||
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
streamID := fmt.Sprintf("stream_%d", time.Now().UnixNano())
|
||||||
|
|
||||||
// Send initial data with stream ID
|
// Send initial data with stream ID
|
||||||
msg := TunnelMessage{
|
msg := types.Message{
|
||||||
Type: "data",
|
Type: "data",
|
||||||
StreamID: streamID,
|
StreamID: streamID,
|
||||||
Data: initialData,
|
Data: initialData,
|
||||||
@ -143,7 +138,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Send close message
|
// Send close message
|
||||||
closeMsg := TunnelMessage{
|
closeMsg := types.Message{
|
||||||
Type: "close",
|
Type: "close",
|
||||||
StreamID: streamID,
|
StreamID: streamID,
|
||||||
}
|
}
|
||||||
@ -159,7 +154,7 @@ func (s *Server) proxyRawConnection(clientConn net.Conn, tunnelConn *TunnelConne
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := TunnelMessage{
|
msg := types.Message{
|
||||||
Type: "data",
|
Type: "data",
|
||||||
StreamID: streamID,
|
StreamID: streamID,
|
||||||
Data: buffer[:n],
|
Data: buffer[:n],
|
||||||
@ -215,25 +210,24 @@ func (s *Server) handleAsHTTP(conn net.Conn, initialData []byte) {
|
|||||||
reader := bufio.NewReader(bytes.NewReader(initialData))
|
reader := bufio.NewReader(bytes.NewReader(initialData))
|
||||||
req, err := http.ReadRequest(reader)
|
req, err := http.ReadRequest(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
_, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle Control Endpoints
|
// Handle Control Endpoints
|
||||||
switch req.URL.Path {
|
switch req.URL.Path {
|
||||||
case "/_conduit/tunnel":
|
case "/_conduit/tunnel":
|
||||||
s.handleTunnelUpgrade(conn, req)
|
s.createTunnel(conn, req)
|
||||||
return
|
|
||||||
case "/_conduit/status":
|
case "/_conduit/status":
|
||||||
s.handleStatus(conn)
|
s.getStatus(conn)
|
||||||
default:
|
default:
|
||||||
conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n"))
|
_, _ = conn.Write([]byte("HTTP/1.1 404 Not Found\r\n\r\n"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
|
func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
|
||||||
for {
|
for {
|
||||||
var msg TunnelMessage
|
var msg types.Message
|
||||||
err := tunnel.ReadJSON(&msg)
|
err := tunnel.ReadJSON(&msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
@ -253,84 +247,85 @@ func (s *Server) handleTunnelMessages(tunnel *TunnelConnection) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (s *Server) handleTunnelUpgrade(conn net.Conn, req *http.Request) {
|
func (s *Server) createTunnel(conn net.Conn, req *http.Request) {
|
||||||
vhost := req.URL.Query().Get("vhost")
|
// Get Tunnel Name
|
||||||
if vhost == "" {
|
tunnelName := req.URL.Query().Get("tunnelName")
|
||||||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing vhost parameter"))
|
if tunnelName == "" {
|
||||||
|
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\nMissing tunnelName parameter"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a fake ResponseWriter that writes to our raw connection
|
// Validate Unique
|
||||||
fakeWriter := &fakeResponseWriter{conn: conn}
|
if _, exists := s.tunnels[tunnelName]; exists {
|
||||||
|
conn.Write([]byte("HTTP/1.1 409 Conflict\r\n\r\nTunnel already registered"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Use the upgrader
|
// Upgrade Connection
|
||||||
wsConn, err := s.upgrader.Upgrade(fakeWriter, req, nil)
|
wsConn, err := s.upgrader.Upgrade(&rawResponseWriter{conn: conn}, req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("WebSocket upgrade failed: %v", err)
|
log.Errorf("WebSocket upgrade failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create TunnelConnection
|
// Create & Cache TunnelConnection
|
||||||
tunnel := &TunnelConnection{
|
tunnel := &TunnelConnection{
|
||||||
Conn: wsConn,
|
Conn: wsConn,
|
||||||
vhost: vhost,
|
name: tunnelName,
|
||||||
streams: make(map[string]chan []byte),
|
streams: make(map[string]chan []byte),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.tunnels[vhost] = tunnel
|
s.tunnels[tunnelName] = tunnel
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
log.Infof("Tunnel established: %s", tunnelName)
|
||||||
log.Infof("Tunnel established: %s", vhost)
|
|
||||||
|
|
||||||
// Keep connection alive and handle cleanup
|
// Keep connection alive and handle cleanup
|
||||||
defer func() {
|
defer func() {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
delete(s.tunnels, vhost)
|
delete(s.tunnels, tunnelName)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
wsConn.Close()
|
wsConn.Close()
|
||||||
log.Infof("Tunnel closed: %s", vhost)
|
log.Infof("Tunnel closed: %s", tunnelName)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Handle tunnel messages
|
// Handle tunnel messages
|
||||||
s.handleTunnelMessages(tunnel)
|
s.handleTunnelMessages(tunnel)
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeResponseWriter struct {
|
type rawResponseWriter struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
header http.Header
|
header http.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fakeResponseWriter) Header() http.Header {
|
func (f *rawResponseWriter) Header() http.Header {
|
||||||
if f.header == nil {
|
if f.header == nil {
|
||||||
f.header = make(http.Header)
|
f.header = make(http.Header)
|
||||||
}
|
}
|
||||||
return f.header
|
return f.header
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fakeResponseWriter) Write(data []byte) (int, error) {
|
func (f *rawResponseWriter) Write(data []byte) (int, error) {
|
||||||
return f.conn.Write(data)
|
return f.conn.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fakeResponseWriter) WriteHeader(statusCode int) {
|
func (f *rawResponseWriter) WriteHeader(statusCode int) {
|
||||||
// Write HTTP status line
|
// Write Status
|
||||||
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
|
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
|
||||||
f.conn.Write([]byte(status))
|
_, _ = f.conn.Write([]byte(status))
|
||||||
|
|
||||||
// Write headers
|
// Write Headers
|
||||||
for key, values := range f.header {
|
for key, values := range f.header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
f.conn.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value)))
|
_, _ = fmt.Fprintf(f.conn, "%s: %s\r\n", key, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// End headers
|
// End Headers
|
||||||
f.conn.Write([]byte("\r\n"))
|
_, _ = f.conn.Write([]byte("\r\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implement http.Hijacker
|
func (f *rawResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
func (f *fakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
// Return Raw Connection & ReadWriter
|
||||||
// Return the raw connection and create a ReadWriter for it
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
|
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
|
||||||
return f.conn, rw, nil
|
return f.conn, rw, nil
|
||||||
}
|
}
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
type tunnelWriter struct {
|
|
||||||
conn net.Conn
|
|
||||||
header http.Header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tunnelWriter) Header() http.Header {
|
|
||||||
if f.header == nil {
|
|
||||||
f.header = make(http.Header)
|
|
||||||
}
|
|
||||||
return f.header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tunnelWriter) Write(data []byte) (int, error) {
|
|
||||||
return f.conn.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tunnelWriter) WriteHeader(statusCode int) {
|
|
||||||
// Write HTTP status line
|
|
||||||
status := fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
|
|
||||||
f.conn.Write([]byte(status))
|
|
||||||
|
|
||||||
// Write headers
|
|
||||||
for key, values := range f.header {
|
|
||||||
for _, value := range values {
|
|
||||||
f.conn.Write([]byte(fmt.Sprintf("%s: %s\r\n", key, value)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// End headers
|
|
||||||
f.conn.Write([]byte("\r\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *tunnelWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(f.conn), bufio.NewWriter(f.conn))
|
|
||||||
return f.conn, rw, nil
|
|
||||||
}
|
|
14
types/message.go
Normal file
14
types/message.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type MessageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessageTypeData MessageType = "data"
|
||||||
|
MessageTypeClose MessageType = "close"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Type MessageType `json:"type"`
|
||||||
|
StreamID string `json:"stream_id"`
|
||||||
|
Data []byte `json:"data,omitempty"`
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user