package mysql

import (
	"context"
	"crypto/tls"
	"errors"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/go-mysql-org/go-mysql/client"
	"github.com/sagernet/sing-box/adapter"
	"github.com/sagernet/sing-box/adapter/outbound"
	"github.com/sagernet/sing-box/common/dialer"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/log"
	"github.com/sagernet/sing-box/option"
	"github.com/sagernet/sing/common"
	E "github.com/sagernet/sing/common/exceptions"
	"github.com/sagernet/sing/common/logger"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/uot"
	"github.com/sagernet/smux"
)

// RegisterOutbound registers NewOutbound as the constructor for C.TypeMySQL
// outbounds, configured by option.MySQLOutboundOptions.
func RegisterOutbound(registry *outbound.Registry) {
	outbound.Register[option.MySQLOutboundOptions](registry, C.TypeMySQL, NewOutbound)
}

var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)

// Outbound carries TCP and UDP-over-TCP traffic as smux streams. Each smux
// session runs over a connection that first completes a MySQL client
// handshake with a TLS config applied (see createSession).
type Outbound struct {
	outbound.Adapter
	ctx        context.Context
	logger     logger.ContextLogger
	dialer     N.Dialer
	serverAddr M.Socksaddr
	username   string
	password   string
	tlsConfig  *tls.Config

	// maxConnections is the number of session slots; NewOutbound keeps it >= 1.
	maxConnections int
	// nextSession is a round-robin counter, only accessed atomically.
	nextSession uint32

	// sessionAccess guards sessions.
	sessionAccess sync.Mutex
	sessions      []*muxSession
}

// muxSession pairs an smux session with the connection it runs on.
type muxSession struct {
	session *smux.Session
	conn    net.Conn
}

// closeMuxSession closes both the smux session and its carrier connection,
// ignoring errors. A nil entry is a no-op.
func closeMuxSession(entry *muxSession) {
	if entry == nil {
		return
	}
	_ = common.Close(entry.session, entry.conn)
}

// NewOutbound builds a MySQL outbound from options.
//
// Defaults applied: port 3306 when unset, username "root" when empty, and a
// single session slot unless multiplexing is enabled with MaxConnections > 1.
// When options.TLS is absent or disabled, certificate verification is skipped.
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLOutboundOptions) (adapter.Outbound, error) {
	outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
	if err != nil {
		return nil, err
	}
	outbound := &Outbound{
		Adapter:        outbound.NewAdapterWithDialerOptions(C.TypeMySQL, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
		ctx:            ctx,
		logger:         logger,
		dialer:         outboundDialer,
		serverAddr:     options.ServerOptions.Build(),
		username:       options.Username,
		password:       options.Password,
		maxConnections: 1,
	}
	if options.Multiplex != nil && options.Multiplex.Enabled && options.Multiplex.MaxConnections > 1 {
		outbound.maxConnections = options.Multiplex.MaxConnections
	}
	outbound.sessions = make([]*muxSession, outbound.maxConnections)
	if outbound.serverAddr.Port == 0 {
		outbound.serverAddr.Port = 3306
	}
	if outbound.username == "" {
		outbound.username = "root"
	}
	// Build TLS config for MySQL client handshake
	if options.TLS != nil && options.TLS.Enabled {
		outbound.tlsConfig = &tls.Config{
			InsecureSkipVerify: options.TLS.Insecure,
			ServerName:         options.TLS.ServerName,
		}
		if outbound.tlsConfig.ServerName == "" {
			outbound.tlsConfig.ServerName = options.Server
		}
	} else {
		// Default: use insecure TLS (since this is for tunneling, not real MySQL)
		outbound.tlsConfig = &tls.Config{
			InsecureSkipVerify: true,
		}
	}
	return outbound, nil
}

// createSession dials the server, performs the MySQL client handshake with
// h.tlsConfig, and starts an smux client session on the resulting connection.
//
// It is called from getSession with sessionAccess held, so the handshake is
// bounded by a deadline to keep an unresponsive server from blocking every
// other dial, InterfaceUpdated and Close.
func (h *Outbound) createSession() (*muxSession, error) {
	const handshakeTimeout = 15 * time.Second

	h.logger.InfoContext(h.ctx, "creating smux session")
	conn, err := h.dialer.DialContext(h.ctx, N.NetworkTCP, h.serverAddr)
	if err != nil {
		return nil, E.Cause(err, "dial server")
	}
	// Best effort: some wrapped connections may not support deadlines; in that
	// case the handshake simply stays unbounded as before.
	_ = conn.SetDeadline(time.Now().Add(handshakeTimeout))

	mysqlConn, err := client.ConnectWithDialer(
		h.ctx,
		"tcp",
		h.serverAddr.String(),
		h.username,
		h.password,
		"",
		func(ctx context.Context, network, address string) (net.Conn, error) {
			// Hand the already-established connection to the MySQL client.
			return conn, nil
		},
		func(c *client.Conn) error {
			c.SetTLSConfig(h.tlsConfig)
			return nil
		},
	)
	if err != nil {
		conn.Close()
		return nil, E.Cause(err, "MySQL handshake")
	}

	// The connection the MySQL client ended up speaking over (expected to be
	// the TLS-wrapped one after the handshake) carries the smux session.
	// NOTE(review): go-mysql reads packets through a buffered reader; any bytes
	// the server sends right after the handshake OK packet could be stranded
	// there instead of reaching smux — confirm the server stays silent until
	// the first stream is opened.
	tlsConn := mysqlConn.Conn.Conn
	// Handshake done: clear the deadline so the long-lived session is not cut.
	_ = conn.SetDeadline(time.Time{})

	session, err := smux.Client(tlsConn, smuxConfig())
	if err != nil {
		tlsConn.Close()
		return nil, E.Cause(err, "create mux session")
	}
	return &muxSession{session: session, conn: tlsConn}, nil
}

// getSession returns the live smux session in slot index, replacing a missing
// or closed one with a freshly created session. A watcher goroutine is started
// for each new session; it clears the slot (if still owned by that session)
// and closes the session and connection once the session's CloseChan fires.
func (h *Outbound) getSession(index int) (*smux.Session, error) {
	h.sessionAccess.Lock()
	defer h.sessionAccess.Unlock()

	entry := h.sessions[index]
	if entry != nil && !entry.session.IsClosed() {
		return entry.session, nil
	}
	if entry != nil {
		closeMuxSession(entry)
		h.sessions[index] = nil
	}

	entry, err := h.createSession()
	if err != nil {
		return nil, err
	}
	h.sessions[index] = entry

	go func(index int, session *smux.Session, conn net.Conn) {
		// When session is closed, clean up
		<-session.CloseChan()
		h.sessionAccess.Lock()
		if current := h.sessions[index]; current != nil && current.session == session {
			h.sessions[index] = nil
		}
		h.sessionAccess.Unlock()
		_ = common.Close(session, conn)
	}(index, entry.session, entry.conn)

	return entry.session, nil
}

// invalidateSession closes and clears slot index, but only if it still holds
// session (a concurrent caller may already have replaced it).
func (h *Outbound) invalidateSession(index int, session *smux.Session) {
	h.sessionAccess.Lock()
	defer h.sessionAccess.Unlock()
	if current := h.sessions[index]; current != nil && current.session == session {
		h.sessions[index] = nil
		closeMuxSession(current)
	}
}

// openStream opens a new smux stream and writes its header: the command byte
// followed by destination in SocksaddrSerializer format.
//
// Slots are tried round-robin starting from the next counter value; each slot
// is tried once. If every slot fails, the last error is returned wrapped as
// "open mux stream". ctx is currently unused: sessions are created with h.ctx
// because they outlive any single caller.
func (h *Outbound) openStream(ctx context.Context, command byte, destination M.Socksaddr) (net.Conn, error) {
	// Reduce in uint32 before converting to int: on 32-bit platforms
	// int(uint32) becomes negative once the counter passes MaxInt32, which
	// would yield a negative slot index and panic in getSession.
	slots := uint32(h.maxConnections)
	start := int((atomic.AddUint32(&h.nextSession, 1) - 1) % slots)

	var lastErr error
	for i := 0; i < h.maxConnections; i++ {
		index := (start + i) % h.maxConnections
		session, err := h.getSession(index)
		if err != nil {
			lastErr = err
			continue
		}
		stream, err := session.OpenStream()
		if err != nil {
			h.invalidateSession(index, session)
			lastErr = err
			continue
		}
		// Write stream header: command + destination
		_, err = stream.Write([]byte{command})
		if err != nil {
			stream.Close()
			lastErr = E.Cause(err, "write stream header command")
			continue
		}
		err = M.SocksaddrSerializer.WriteAddrPort(stream, destination)
		if err != nil {
			stream.Close()
			lastErr = E.Cause(err, "write stream header destination")
			continue
		}
		return stream, nil
	}
	if lastErr == nil {
		lastErr = E.New("open mux stream")
	}
	return nil, E.Cause(lastErr, "open mux stream")
}

// DialContext opens a TCP stream to destination, or for UDP a UoT stream
// wrapped in a connected lazy UoT conn. Other networks are rejected.
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	switch N.NetworkName(network) {
	case N.NetworkTCP:
		h.logger.InfoContext(ctx, "outbound connection to ", destination)
		return h.openStream(ctx, commandTCP, destination)
	case N.NetworkUDP:
		h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination)
		conn, err := h.openStream(ctx, commandUDP, uot.RequestDestination(uot.Version))
		if err != nil {
			return nil, err
		}
		return uot.NewLazyConn(conn, uot.Request{
			IsConnect:   true,
			Destination: destination,
		}), nil
	default:
		return nil, E.New("unsupported network: ", network)
	}
}

// ListenPacket opens a UoT stream and returns it as an unconnected lazy UoT
// packet conn.
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination)
	conn, err := h.openStream(ctx, commandUDP, uot.RequestDestination(uot.Version))
	if err != nil {
		return nil, err
	}
	return uot.NewLazyConn(conn, uot.Request{
		IsConnect:   false,
		Destination: destination,
	}), nil
}

// InterfaceUpdated drops every cached session so that later dials reconnect
// over the new network interface. Close errors are ignored (best effort).
func (h *Outbound) InterfaceUpdated() {
	h.sessionAccess.Lock()
	defer h.sessionAccess.Unlock()
	for i, entry := range h.sessions {
		if entry == nil {
			continue
		}
		closeMuxSession(entry)
		h.sessions[i] = nil
	}
}

// Close closes every cached session and its connection and clears the slots.
//
// smux.Session.Close already closes the carrier connection, so the explicit
// conn.Close reports net.ErrClosed in the normal case (and Session.Close
// reports io.ErrClosedPipe if the session was already closed); those are
// expected and not surfaced. Any other failures are aggregated.
func (h *Outbound) Close() error {
	h.sessionAccess.Lock()
	defer h.sessionAccess.Unlock()
	var errs []error
	for i, entry := range h.sessions {
		if entry == nil {
			continue
		}
		h.sessions[i] = nil
		if err := entry.session.Close(); err != nil && !errors.Is(err, io.ErrClosedPipe) {
			errs = append(errs, err)
		}
		if err := entry.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			errs = append(errs, err)
		}
	}
	return E.Errors(errs...)
}