chore: tunnel recorder & slight refactor
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
43
tunnel/forwarder.go
Normal file
43
tunnel/forwarder.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"reichard.io/conduit/store"
|
||||
)
|
||||
|
||||
type ForwarderType int
|
||||
|
||||
const (
|
||||
ForwarderTCP ForwarderType = iota
|
||||
ForwarderHTTP
|
||||
)
|
||||
|
||||
type Forwarder interface {
|
||||
Type() ForwarderType
|
||||
Initialize() (Stream, error)
|
||||
Start(context.Context) error
|
||||
}
|
||||
|
||||
func NewForwarder(target string, tunnelStore store.TunnelStore) (Forwarder, error) {
|
||||
// Get Target URL
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get Connection Builder
|
||||
var forwarder Forwarder
|
||||
switch targetURL.Scheme {
|
||||
case "http", "https":
|
||||
forwarder, err = newHTTPForwarder(targetURL, tunnelStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
forwarder = newTCPForwarder(target, tunnelStore)
|
||||
}
|
||||
|
||||
return forwarder, nil
|
||||
}
|
||||
104
tunnel/http.go
104
tunnel/http.go
@@ -1,104 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func HTTPConnectionBuilder(targetURL *url.URL) (ConnBuilder, error) {
|
||||
multiConnListener := newMultiConnListener()
|
||||
|
||||
// Create Reverse Proxy
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.Host = targetURL.Host
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
|
||||
// Start HTTP Proxy
|
||||
go func() {
|
||||
defer multiConnListener.Close()
|
||||
_ = http.Serve(multiConnListener, proxy)
|
||||
}()
|
||||
|
||||
// Return Connection Builder
|
||||
return func() (conn io.ReadWriteCloser, err error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
|
||||
if err := multiConnListener.addConn(serverConn); err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return clientConn, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
type multiConnListener struct {
|
||||
connCh chan net.Conn
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newMultiConnListener() *multiConnListener {
|
||||
return &multiConnListener{
|
||||
connCh: make(chan net.Conn, 100),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-l.connCh:
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
return conn, nil
|
||||
case <-l.closed:
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
close(l.closed)
|
||||
// Drain any remaining connections
|
||||
go func() {
|
||||
for conn := range l.connCh {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
close(l.connCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Addr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) addConn(conn net.Conn) error {
|
||||
select {
|
||||
case l.connCh <- conn:
|
||||
return nil
|
||||
case <-l.closed:
|
||||
conn.Close()
|
||||
return fmt.Errorf("listener is closed")
|
||||
default:
|
||||
conn.Close()
|
||||
return fmt.Errorf("connection queue full")
|
||||
}
|
||||
}
|
||||
132
tunnel/http_forwarder.go
Normal file
132
tunnel/http_forwarder.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"reichard.io/conduit/store"
|
||||
)
|
||||
|
||||
func newHTTPForwarder(targetURL *url.URL, tunnelStore store.TunnelStore) (Forwarder, error) {
|
||||
return &httpConnBuilder{
|
||||
multiConnListener: newMultiConnListener(),
|
||||
tunnelStore: tunnelStore,
|
||||
targetURL: targetURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type httpConnBuilder struct {
|
||||
multiConnListener *multiConnListener
|
||||
tunnelStore store.TunnelStore
|
||||
targetURL *url.URL
|
||||
}
|
||||
|
||||
func (c *httpConnBuilder) Type() ForwarderType {
|
||||
return ForwarderHTTP
|
||||
}
|
||||
|
||||
func (c *httpConnBuilder) Start(ctx context.Context) error {
|
||||
// Create Reverse Proxy Server
|
||||
server := &http.Server{
|
||||
Handler: &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.Host = c.targetURL.Host
|
||||
req.URL.Host = c.targetURL.Host
|
||||
req.URL.Scheme = c.targetURL.Scheme
|
||||
c.tunnelStore.RecordRequest(req)
|
||||
},
|
||||
ModifyResponse: c.tunnelStore.RecordResponse,
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, fmt.Sprintf("Proxy error: %v", err), http.StatusBadGateway)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Context & Cleanup
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
server.Shutdown(ctx)
|
||||
c.multiConnListener.Close()
|
||||
}()
|
||||
|
||||
// Start HTTP Proxy
|
||||
if err := server.Serve(c.multiConnListener); err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *httpConnBuilder) Initialize() (Stream, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
|
||||
if err := c.multiConnListener.addConn(serverConn); err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &streamImpl{clientConn, c.targetURL.String()}, nil
|
||||
}
|
||||
|
||||
type multiConnListener struct {
|
||||
connCh chan net.Conn
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newMultiConnListener() *multiConnListener {
|
||||
return &multiConnListener{
|
||||
connCh: make(chan net.Conn, 100),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-l.connCh:
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
return conn, nil
|
||||
case <-l.closed:
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
close(l.closed)
|
||||
// Drain any remaining connections
|
||||
go func() {
|
||||
for conn := range l.connCh {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
close(l.connCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *multiConnListener) Addr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
|
||||
}
|
||||
|
||||
func (l *multiConnListener) addConn(conn net.Conn) error {
|
||||
select {
|
||||
case l.connCh <- conn:
|
||||
return nil
|
||||
case <-l.closed:
|
||||
conn.Close()
|
||||
return fmt.Errorf("listener is closed")
|
||||
default:
|
||||
conn.Close()
|
||||
return fmt.Errorf("connection queue full")
|
||||
}
|
||||
}
|
||||
23
tunnel/name.go
Normal file
23
tunnel/name.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
var colors = []string{
|
||||
"red", "blue", "green", "yellow", "purple", "orange",
|
||||
"pink", "brown", "black", "white", "gray", "cyan",
|
||||
}
|
||||
|
||||
var animals = []string{
|
||||
"cat", "dog", "bird", "fish", "lion", "tiger",
|
||||
"bear", "wolf", "fox", "deer", "rabbit", "mouse",
|
||||
}
|
||||
|
||||
func generateTunnelName() string {
|
||||
color := colors[rand.Intn(len(colors))]
|
||||
animal := animals[rand.Intn(len(animals))]
|
||||
number := rand.Intn(900) + 100
|
||||
return fmt.Sprintf("%s-%s-%d", color, animal, number)
|
||||
}
|
||||
26
tunnel/stream.go
Normal file
26
tunnel/stream.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
var _ Stream = (*streamImpl)(nil)
|
||||
|
||||
type Stream interface {
|
||||
io.ReadWriteCloser
|
||||
Source() string
|
||||
}
|
||||
|
||||
func NewStream(conn net.Conn, source string) Stream {
|
||||
return &streamImpl{conn, source}
|
||||
}
|
||||
|
||||
type streamImpl struct {
|
||||
net.Conn
|
||||
source string
|
||||
}
|
||||
|
||||
func (s *streamImpl) Source() string {
|
||||
return s.source
|
||||
}
|
||||
37
tunnel/tcp_forwarder.go
Normal file
37
tunnel/tcp_forwarder.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"reichard.io/conduit/store"
|
||||
)
|
||||
|
||||
func newTCPForwarder(target string, tunnelStore store.TunnelStore) Forwarder {
|
||||
return &tcpConnBuilder{
|
||||
target: target,
|
||||
tunnelStore: tunnelStore,
|
||||
}
|
||||
}
|
||||
|
||||
type tcpConnBuilder struct {
|
||||
target string
|
||||
tunnelStore store.TunnelStore
|
||||
}
|
||||
|
||||
func (l *tcpConnBuilder) Type() ForwarderType {
|
||||
return ForwarderTCP
|
||||
}
|
||||
|
||||
func (l *tcpConnBuilder) Initialize() (Stream, error) {
|
||||
conn, err := net.Dial("tcp", l.target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &streamImpl{conn, l.target}, nil
|
||||
}
|
||||
|
||||
func (l *tcpConnBuilder) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
266
tunnel/tunnel.go
266
tunnel/tunnel.go
@@ -1,76 +1,88 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"reichard.io/conduit/config"
|
||||
"reichard.io/conduit/pkg/maps"
|
||||
"reichard.io/conduit/types"
|
||||
)
|
||||
|
||||
type ConnBuilder func() (conn io.ReadWriteCloser, err error)
|
||||
|
||||
// NewServerTunnel creates a new tunnel with name and websocket connection. The tunnel is
|
||||
// generally instantiated after an upgrade request from the server.
|
||||
func NewServerTunnel(name string, wsConn *websocket.Conn) *Tunnel {
|
||||
return &Tunnel{
|
||||
name: name,
|
||||
streams: maps.New[string, io.ReadWriteCloser](),
|
||||
streams: maps.New[string, Stream](),
|
||||
wsConn: wsConn,
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientTunnel(name, target string, serverURL *url.URL, wsConn *websocket.Conn) (*Tunnel, error) {
|
||||
// Get Target URL
|
||||
targetURL, err := url.Parse(target)
|
||||
// NewClientTunnel creates a new tunnel with the provided configuration and forwarder. A
|
||||
// forwarder is effectively the protocol being forwarded. For example HTTP (Proxy), and TCP.
|
||||
func NewClientTunnel(cfg *config.ClientConfig, forwarder Forwarder) (*Tunnel, error) {
|
||||
// Parse Server URL
|
||||
serverURL, err := url.Parse(cfg.ServerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Derive Conduit URL
|
||||
conduitURL := *serverURL
|
||||
conduitURL.Host = name + "." + conduitURL.Host
|
||||
|
||||
// Get Connection Builder
|
||||
var connBuilder ConnBuilder
|
||||
switch targetURL.Scheme {
|
||||
case "http", "https":
|
||||
log.Infof("creating HTTP tunnel: %s -> %s", conduitURL.String(), target)
|
||||
connBuilder, err = HTTPConnectionBuilder(targetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Parse Scheme
|
||||
var wsScheme string
|
||||
switch serverURL.Scheme {
|
||||
case "https":
|
||||
wsScheme = "wss"
|
||||
case "http":
|
||||
wsScheme = "ws"
|
||||
default:
|
||||
log.Infof("creating TCP tunnel: %s -> %s", conduitURL.String(), target)
|
||||
connBuilder = func() (conn io.ReadWriteCloser, err error) {
|
||||
return net.Dial("tcp", target)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported scheme: %s", serverURL.Scheme)
|
||||
}
|
||||
|
||||
// Create Tunnel Name
|
||||
if cfg.TunnelName == "" {
|
||||
cfg.TunnelName = generateTunnelName()
|
||||
log.Infof("tunnel name not provided; generated: %s", cfg.TunnelName)
|
||||
}
|
||||
|
||||
// Connect Server WS
|
||||
wsURL := fmt.Sprintf("%s://%s/_conduit/tunnel?tunnelName=%s&apiKey=%s", wsScheme, serverURL.Host, cfg.TunnelName, cfg.APIKey)
|
||||
serverConn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect: %v", err)
|
||||
}
|
||||
|
||||
return &Tunnel{
|
||||
name: name,
|
||||
wsConn: wsConn,
|
||||
streams: maps.New[string, io.ReadWriteCloser](),
|
||||
connBuilder: connBuilder,
|
||||
name: cfg.TunnelName,
|
||||
wsConn: serverConn,
|
||||
streams: maps.New[string, Stream](),
|
||||
forwarder: forwarder,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
name string
|
||||
wsConn *websocket.Conn
|
||||
streams *maps.Map[string, io.ReadWriteCloser]
|
||||
connBuilder ConnBuilder
|
||||
ctx context.Context
|
||||
name string
|
||||
wsConn *websocket.Conn
|
||||
streams *maps.Map[string, Stream]
|
||||
forwarder Forwarder
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *Tunnel) Start() {
|
||||
func (t *Tunnel) Start(ctx context.Context) {
|
||||
log.Infof("initiated tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||
defer log.Infof("closed tunnel %q with %s", t.name, t.wsConn.RemoteAddr().String())
|
||||
|
||||
t.ctx = ctx
|
||||
|
||||
// Start Message Receiver
|
||||
for {
|
||||
var msg types.Message
|
||||
err := t.wsConn.ReadJSON(&msg)
|
||||
msg, err := t.readWSWithContext(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -81,105 +93,57 @@ func (t *Tunnel) Start() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Ensure Stream
|
||||
if err := t.initStreamConnection(msg.StreamID); err != nil {
|
||||
log.WithError(err).Errorf("failed to initialize stream %s connection", t.name)
|
||||
// Get Stream
|
||||
stream, err := t.getStream(msg.StreamID)
|
||||
if err != nil {
|
||||
if msg.Type != types.MessageTypeClose {
|
||||
log.WithError(err).Errorf("failed to get stream %s", msg.StreamID)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle Messages
|
||||
switch msg.Type {
|
||||
case types.MessageTypeClose:
|
||||
_ = t.CloseStream(msg.StreamID)
|
||||
_ = t.closeStream(stream, msg.StreamID)
|
||||
case types.MessageTypeData:
|
||||
_ = t.WriteStream(msg.StreamID, msg.Data)
|
||||
_, err = stream.Write(msg.Data)
|
||||
}
|
||||
|
||||
// Log Error
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("failed to handle message %s", msg.StreamID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) initStreamConnection(streamID string) error {
|
||||
if t.connBuilder == nil {
|
||||
return nil
|
||||
func (t *Tunnel) readWSWithContext(ctx context.Context) (*types.Message, error) {
|
||||
type result struct {
|
||||
msg *types.Message
|
||||
err error
|
||||
}
|
||||
|
||||
if _, found := t.streams.Get(streamID); found {
|
||||
return nil
|
||||
}
|
||||
resultChan := make(chan result, 1)
|
||||
go func() {
|
||||
var msg types.Message
|
||||
err := t.wsConn.ReadJSON(&msg)
|
||||
resultChan <- result{&msg, err}
|
||||
}()
|
||||
|
||||
conn, err := t.connBuilder()
|
||||
if err != nil {
|
||||
return err
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-resultChan:
|
||||
return result.msg, result.err
|
||||
}
|
||||
|
||||
if err := t.AddStream(streamID, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go t.StartStream(streamID, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) AddStream(streamID string, conn io.ReadWriteCloser) error {
|
||||
func (t *Tunnel) AddStream(stream Stream, streamID string) error {
|
||||
if t.streams.HasKey(streamID) {
|
||||
return fmt.Errorf("stream %s already exists", streamID)
|
||||
}
|
||||
t.streams.Set(streamID, conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) StartStream(streamID string, sourceAddr string) error {
|
||||
// Get Stream
|
||||
conn, found := t.streams.Get(streamID)
|
||||
if !found {
|
||||
return fmt.Errorf("stream %s does not exist", streamID)
|
||||
}
|
||||
|
||||
// Close Stream
|
||||
defer func() {
|
||||
_ = t.sendWS(&types.Message{
|
||||
Type: types.MessageTypeClose,
|
||||
StreamID: streamID,
|
||||
SourceAddr: sourceAddr,
|
||||
})
|
||||
|
||||
t.CloseStream(streamID)
|
||||
}()
|
||||
|
||||
// Start Stream
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := t.sendWS(&types.Message{
|
||||
Type: types.MessageTypeData,
|
||||
StreamID: streamID,
|
||||
Data: buffer[:n],
|
||||
SourceAddr: sourceAddr,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) WriteStream(streamID string, data []byte) error {
|
||||
// Get Stream
|
||||
conn, found := t.streams.Get(streamID)
|
||||
if !found {
|
||||
return fmt.Errorf("stream %s does not exist", streamID)
|
||||
}
|
||||
|
||||
_, err := conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tunnel) CloseStream(streamID string) error {
|
||||
if conn, ok := t.streams.Get(streamID); ok {
|
||||
t.streams.Delete(streamID)
|
||||
return conn.Close()
|
||||
}
|
||||
log.Infof("tunnel %q initiated stream with %s", t.name, stream.Source())
|
||||
t.streams.Set(streamID, stream)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -187,6 +151,78 @@ func (t *Tunnel) Source() string {
|
||||
return t.wsConn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (t *Tunnel) StartStream(stream Stream, streamID string) error {
|
||||
// Close Stream
|
||||
defer t.closeStream(stream, streamID)
|
||||
|
||||
// Start Stream
|
||||
for {
|
||||
data, err := t.readStreamWithContext(t.ctx, stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := t.sendWS(&types.Message{
|
||||
Type: types.MessageTypeData,
|
||||
StreamID: streamID,
|
||||
Data: data,
|
||||
SourceAddr: stream.Source(),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) closeStream(stream Stream, streamID string) error {
|
||||
log.Infof("tunnel %q closed stream with %s", t.name, stream.Source())
|
||||
t.streams.Delete(streamID)
|
||||
return stream.Close()
|
||||
}
|
||||
|
||||
func (t *Tunnel) getStream(streamID string) (Stream, error) {
|
||||
// Check Existing Stream
|
||||
if stream, found := t.streams.Get(streamID); found {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// Check Forwarder
|
||||
if t.forwarder == nil {
|
||||
return nil, fmt.Errorf("stream %s does not exist", streamID)
|
||||
}
|
||||
|
||||
// Initialize Forwarder & Add Stream
|
||||
stream, err := t.forwarder.Initialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := t.AddStream(stream, streamID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go t.StartStream(stream, streamID)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) readStreamWithContext(ctx context.Context, stream Stream) ([]byte, error) {
|
||||
type result struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan result, 1)
|
||||
go func() {
|
||||
buffer := make([]byte, 4096)
|
||||
n, err := stream.Read(buffer)
|
||||
resultChan <- result{buffer[:n], err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-resultChan:
|
||||
return result.data, result.err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) sendWS(msg *types.Message) error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user