diff --git a/protocol/mysql/inbound.go b/protocol/mysql/inbound.go index 0d94e56c..dbf9a98e 100644 --- a/protocol/mysql/inbound.go +++ b/protocol/mysql/inbound.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing-box/common/listener" "github.com/sagernet/sing-box/common/mux" boxTLS "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/common/uot" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" @@ -44,7 +45,7 @@ type Inbound struct { func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLInboundOptions) (adapter.Inbound, error) { inbound := &Inbound{ Adapter: inbound.NewAdapter(C.TypeMySQL, tag), - router: router, + router: uot.NewRouter(router, logger), logger: logger, identityProvider: server.NewInMemoryProvider(), } @@ -197,7 +198,7 @@ func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M. h.router.RouteConnectionEx(ctx, conn, metadata, nil) case commandUDP: metadata.Destination = destination - h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) + h.logger.InfoContext(ctx, "inbound UoT packet connection to ", metadata.Destination) h.router.RouteConnectionEx(ctx, conn, metadata, nil) default: return E.New("unknown command ", command) diff --git a/protocol/mysql/outbound.go b/protocol/mysql/outbound.go index c30bdd28..4b60d148 100644 --- a/protocol/mysql/outbound.go +++ b/protocol/mysql/outbound.go @@ -14,11 +14,11 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/uot" "github.com/sagernet/smux" "github.com/go-mysql-org/go-mysql/client" @@ -51,6 +51,13 @@ type muxSession struct { conn net.Conn } +func closeMuxSession(entry *muxSession) { + if entry == nil { + return + } + _ = common.Close(entry.session, entry.conn) +} + func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLOutboundOptions) (adapter.Outbound, error) { outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain()) if err != nil { @@ -153,7 +160,7 @@ func (h *Outbound) getSession(index int) (*smux.Session, error) { return entry.session, nil } if entry != nil { - entry.conn.Close() + closeMuxSession(entry) h.sessions[index] = nil } @@ -171,7 +178,7 @@ func (h *Outbound) getSession(index int) (*smux.Session, error) { h.sessions[index] = nil } h.sessionAccess.Unlock() - conn.Close() + _ = common.Close(session, conn) }(index, entry.session, entry.conn) return entry.session, nil @@ -183,7 +190,7 @@ func (h *Outbound) invalidateSession(index int, session *smux.Session) { if current := h.sessions[index]; current != nil && current.session == session { h.sessions[index] = nil - current.conn.Close() + closeMuxSession(current) } } @@ -234,24 +241,30 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination h.logger.InfoContext(ctx, "outbound connection to ", destination) return h.openStream(ctx, commandTCP, destination) case N.NetworkUDP: - h.logger.InfoContext(ctx, "outbound packet connection to ", destination) - conn, err := h.openStream(ctx, commandUDP, destination) + h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination) + conn, err := h.openStream(ctx, commandUDP, uot.RequestDestination(uot.Version)) if err != nil { return nil, err } - return bufio.NewBindPacketConn(&packetConn{conn}, destination), nil + return uot.NewLazyConn(conn, uot.Request{ + IsConnect: true, + Destination: destination, + }), nil default: return nil, E.New("unsupported network: ", network) } } func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - h.logger.InfoContext(ctx, "outbound packet connection to ", destination) - conn, err := h.openStream(ctx, commandUDP, destination) + h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination) + conn, err := h.openStream(ctx, commandUDP, uot.RequestDestination(uot.Version)) if err != nil { return nil, err } - return &packetConn{conn}, nil + return uot.NewLazyConn(conn, uot.Request{ + IsConnect: false, + Destination: destination, + }), nil } func (h *Outbound) InterfaceUpdated() { @@ -280,17 +293,3 @@ func (h *Outbound) Close() error { } return err } - -// packetConn wraps a net.Conn as a net.PacketConn for UDP-over-TCP -type packetConn struct { - net.Conn -} - -func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, err = c.Conn.Read(p) - return -} - -func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return c.Conn.Write(p) -}