From 03eb444cddadfc91f12817da615b06bc0b1ef59b Mon Sep 17 00:00:00 2001 From: n3t1zen Date: Tue, 24 Feb 2026 16:58:04 +0800 Subject: [PATCH] fix: improve multiplexing for mysql protocol --- protocol/mysql/outbound.go | 198 +++++++++++++++++++++---------------- 1 file changed, 113 insertions(+), 85 deletions(-) diff --git a/protocol/mysql/outbound.go b/protocol/mysql/outbound.go index 89c320e5..c30bdd28 100644 --- a/protocol/mysql/outbound.go +++ b/protocol/mysql/outbound.go @@ -4,8 +4,8 @@ import ( "context" "crypto/tls" "net" - "os" "sync" + "sync/atomic" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/outbound" @@ -32,17 +32,23 @@ var _ adapter.InterfaceUpdateListener = (*Outbound)(nil) type Outbound struct { outbound.Adapter - ctx context.Context - logger logger.ContextLogger - dialer N.Dialer - serverAddr M.Socksaddr - username string - password string - tlsConfig *tls.Config + ctx context.Context + logger logger.ContextLogger + dialer N.Dialer + serverAddr M.Socksaddr + username string + password string + tlsConfig *tls.Config + maxConnections int + nextSession uint32 sessionAccess sync.Mutex - session *smux.Session - sessionConn net.Conn + sessions []*muxSession +} + +type muxSession struct { + session *smux.Session + conn net.Conn } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLOutboundOptions) (adapter.Outbound, error) { @@ -52,15 +58,21 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeMySQL, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), - ctx: ctx, - logger: logger, - dialer: outboundDialer, - serverAddr: options.ServerOptions.Build(), - username: options.Username, - password: options.Password, + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeMySQL, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), + ctx: ctx, + logger: logger, + dialer: outboundDialer, + serverAddr: options.ServerOptions.Build(), + username: options.Username, + password: options.Password, + maxConnections: 1, } + if options.Multiplex != nil && options.Multiplex.Enabled && options.Multiplex.MaxConnections > 1 { + outbound.maxConnections = options.Multiplex.MaxConnections + } + outbound.sessions = make([]*muxSession, outbound.maxConnections) + if outbound.serverAddr.Port == 0 { outbound.serverAddr.Port = 3306 } @@ -88,14 +100,8 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return outbound, nil } -func (h *Outbound) getSession() (*smux.Session, error) { - h.sessionAccess.Lock() - defer h.sessionAccess.Unlock() - - if h.session != nil && !h.session.IsClosed() { - return h.session, nil - } - +func (h *Outbound) createSession() (*muxSession, error) { + h.logger.InfoContext(h.ctx, "creating smux session") // Dial TCP connection to server conn, err := h.dialer.DialContext(h.ctx, N.NetworkTCP, h.serverAddr) if err != nil { @@ -135,66 +141,91 @@ func (h *Outbound) getSession() (*smux.Session, error) { return nil, E.Cause(err, "create mux session") } - h.session = session - h.sessionConn = tlsConn - - go func() { - // When session is closed, clean up - <-session.CloseChan() - h.sessionAccess.Lock() - if h.session == session { - h.session = nil - h.sessionConn = nil - } - h.sessionAccess.Unlock() - tlsConn.Close() - }() - - return session, nil + return &muxSession{session: session, conn: tlsConn}, nil } -func (h *Outbound) openStream(ctx context.Context, command byte, destination M.Socksaddr) (net.Conn, error) { - session, err := h.getSession() +func (h *Outbound) getSession(index int) (*smux.Session, error) { + h.sessionAccess.Lock() + defer h.sessionAccess.Unlock() + + entry := h.sessions[index] + if entry != nil && !entry.session.IsClosed() { + return entry.session, nil + } + if entry != nil { + entry.conn.Close() + h.sessions[index] = nil + } + + entry, err := h.createSession() if err != nil { return nil, err } + h.sessions[index] = entry - stream, err := session.OpenStream() - if err != nil { - // Session might be stale, try once more with a new session + go func(index int, session *smux.Session, conn net.Conn) { + // When session is closed, clean up + <-session.CloseChan() h.sessionAccess.Lock() - if h.session == session { - h.session = nil - if h.sessionConn != nil { - h.sessionConn.Close() - h.sessionConn = nil - } + if current := h.sessions[index]; current != nil && current.session == session { + h.sessions[index] = nil } h.sessionAccess.Unlock() + conn.Close() + }(index, entry.session, entry.conn) - session, err = h.getSession() + return entry.session, nil +} + +func (h *Outbound) invalidateSession(index int, session *smux.Session) { + h.sessionAccess.Lock() + defer h.sessionAccess.Unlock() + + if current := h.sessions[index]; current != nil && current.session == session { + h.sessions[index] = nil + current.conn.Close() + } +} + +func (h *Outbound) openStream(ctx context.Context, command byte, destination M.Socksaddr) (net.Conn, error) { + _ = ctx + start := int(atomic.AddUint32(&h.nextSession, 1)-1) % h.maxConnections + var lastErr error + for i := 0; i < h.maxConnections; i++ { + index := (start + i) % h.maxConnections + session, err := h.getSession(index) if err != nil { - return nil, err + lastErr = err + continue } - stream, err = session.OpenStream() + + stream, err := session.OpenStream() if err != nil { - return nil, E.Cause(err, "open mux stream") + h.invalidateSession(index, session) + lastErr = err + continue } - } - // Write stream header: command + destination - _, err = stream.Write([]byte{command}) - if err != nil { - stream.Close() - return nil, E.Cause(err, "write stream header command") - } - err = M.SocksaddrSerializer.WriteAddrPort(stream, destination) - if err != nil { - stream.Close() - return nil, E.Cause(err, "write stream header destination") - } + // Write stream header: command + destination + _, err = stream.Write([]byte{command}) + if err != nil { + stream.Close() + lastErr = E.Cause(err, "write stream header command") + continue + } + err = M.SocksaddrSerializer.WriteAddrPort(stream, destination) + if err != nil { + stream.Close() + lastErr = E.Cause(err, "write stream header destination") + continue + } - return stream, nil + return stream, nil + } + if lastErr == nil { + lastErr = E.New("open mux stream") + } + return nil, E.Cause(lastErr, "open mux stream") } func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { @@ -226,13 +257,13 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n func (h *Outbound) InterfaceUpdated() { h.sessionAccess.Lock() defer h.sessionAccess.Unlock() - if h.session != nil { - h.session.Close() - h.session = nil - } - if h.sessionConn != nil { - h.sessionConn.Close() - h.sessionConn = nil + for i, session := range h.sessions { + if session == nil { + continue + } + session.session.Close() + session.conn.Close() + h.sessions[i] = nil } } @@ -240,13 +271,12 @@ func (h *Outbound) Close() error { h.sessionAccess.Lock() defer h.sessionAccess.Unlock() var err error - if h.session != nil { - err = h.session.Close() - h.session = nil - } - if h.sessionConn != nil { - common.Close(h.sessionConn) - h.sessionConn = nil + for i, session := range h.sessions { + if session == nil { + continue + } + err = common.Close(session.session, session.conn) + h.sessions[i] = nil } return err } @@ -264,5 +294,3 @@ func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return c.Conn.Write(p) } - -var _ = os.ErrInvalid // keep import