From 4d3190480d31761d6d9b68c1a29d635faebb5edf Mon Sep 17 00:00:00 2001 From: n3t1zen Date: Tue, 24 Feb 2026 12:50:25 +0800 Subject: [PATCH] refactor: Enhance MySQL inbound handling with identity provider support --- option/mysql.go | 8 +++++-- protocol/mysql/inbound.go | 44 +++++++++++++++++++-------------------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/option/mysql.go b/option/mysql.go index 7d0e0ee3..6467bf3e 100644 --- a/option/mysql.go +++ b/option/mysql.go @@ -3,11 +3,15 @@ package option type MySQLInboundOptions struct { ListenOptions InboundTLSOptionsContainer - Username string `json:"username,omitempty"` - Password string `json:"password,omitempty"` + Users []MySQLUser `json:"users,omitempty"` Multiplex *InboundMultiplexOptions `json:"multiplex,omitempty"` } +type MySQLUser struct { + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` +} + type MySQLOutboundOptions struct { DialerOptions ServerOptions diff --git a/protocol/mysql/inbound.go b/protocol/mysql/inbound.go index f961a41b..0d94e56c 100644 --- a/protocol/mysql/inbound.go +++ b/protocol/mysql/inbound.go @@ -33,22 +33,23 @@ var _ adapter.TCPInjectableInbound = (*Inbound)(nil) type Inbound struct { inbound.Adapter - router adapter.ConnectionRouterEx - logger logger.ContextLogger - listener *listener.Listener - tlsConfig boxTLS.ServerConfig - username string - password string - mysqlServer *server.Server + router adapter.ConnectionRouterEx + logger logger.ContextLogger + listener *listener.Listener + tlsConfig boxTLS.ServerConfig + identityProvider *server.InMemoryProvider + mysqlServer *server.Server } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLInboundOptions) (adapter.Inbound, error) { inbound := &Inbound{ - Adapter: inbound.NewAdapter(C.TypeMySQL, tag), - router: router, - logger: logger, - username: options.Username, - password: options.Password, + Adapter: inbound.NewAdapter(C.TypeMySQL, tag), + router: router, + logger: logger, + identityProvider: server.NewInMemoryProvider(), + } + for _, user := range options.Users { + inbound.identityProvider.AddUser(user.User, user.Password) } if options.TLS == nil || !options.TLS.Enabled { @@ -112,12 +113,8 @@ func (h *Inbound) Close() error { } func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { - // Create credential provider for this connection - provider := server.NewInMemoryProvider() - provider.AddUser(h.username, h.password) - // Use go-mysql server to perform the MySQL handshake (which negotiates TLS) - mysqlConn, err := h.mysqlServer.NewCustomizedConn(conn, provider, &emptyHandler{}) + mysqlConn, err := h.mysqlServer.NewCustomizedConn(conn, h.identityProvider, &emptyHandler{}) if err != nil { N.CloseOnHandshakeFailure(conn, onClose, err) h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source, ": MySQL handshake")) @@ -131,13 +128,13 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a h.logger.InfoContext(ctx, "MySQL handshake completed from ", metadata.Source) // Handle smux session over the TLS-encrypted connection - err = h.handleMuxSession(ctx, tlsConn, metadata.Source, onClose) + err = h.handleMuxSession(ctx, tlsConn, metadata.Source, onClose, mysqlConn.GetUser()) if err != nil && !E.IsClosed(err) { h.logger.ErrorContext(ctx, E.Cause(err, "process mux session from ", metadata.Source)) } } -func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error { +func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc, user string) error { session, err := smux.Server(conn, smuxConfig()) if err != nil { if onClose != nil { @@ -152,7 +149,7 @@ func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M. if sErr != nil { return sErr } - go h.handleMuxStream(ctx, stream, source) + go h.handleMuxStream(ctx, stream, source, user) } }) group.Cleanup(func() { @@ -164,14 +161,14 @@ func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M. return group.Run(ctx) } -func (h *Inbound) handleMuxStream(ctx context.Context, conn net.Conn, source M.Socksaddr) { - err := h.handleMuxStream0(ctx, conn, source) +func (h *Inbound) handleMuxStream(ctx context.Context, conn net.Conn, source M.Socksaddr, user string) { + err := h.handleMuxStream0(ctx, conn, source, user) if err != nil { h.logger.ErrorContext(ctx, E.Cause(err, "process mux stream")) } } -func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.Socksaddr) error { +func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.Socksaddr, user string) error { // Read destination from the stream header: // 1 byte command (0x01=TCP, 0x03=UDP) // then socks address (using SocksaddrSerializer) @@ -191,6 +188,7 @@ func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M. metadata.Inbound = h.Tag() metadata.InboundType = h.Type() metadata.Source = source + metadata.User = user switch command { case commandTCP: