refactor: Enhance MySQL inbound handling with identity provider support

This commit is contained in:
n3t1zen
2026-02-24 12:50:25 +08:00
parent bab5784ce5
commit 4d3190480d
2 changed files with 27 additions and 25 deletions

View File

@@ -3,11 +3,15 @@ package option
type MySQLInboundOptions struct {
ListenOptions
InboundTLSOptionsContainer
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Users []MySQLUser `json:"users,omitempty"`
Multiplex *InboundMultiplexOptions `json:"multiplex,omitempty"`
}
type MySQLUser struct {
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"`
}
type MySQLOutboundOptions struct {
DialerOptions
ServerOptions

View File

@@ -33,22 +33,23 @@ var _ adapter.TCPInjectableInbound = (*Inbound)(nil)
type Inbound struct {
inbound.Adapter
router adapter.ConnectionRouterEx
logger logger.ContextLogger
listener *listener.Listener
tlsConfig boxTLS.ServerConfig
username string
password string
mysqlServer *server.Server
router adapter.ConnectionRouterEx
logger logger.ContextLogger
listener *listener.Listener
tlsConfig boxTLS.ServerConfig
identityProvider *server.InMemoryProvider
mysqlServer *server.Server
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLInboundOptions) (adapter.Inbound, error) {
inbound := &Inbound{
Adapter: inbound.NewAdapter(C.TypeMySQL, tag),
router: router,
logger: logger,
username: options.Username,
password: options.Password,
Adapter: inbound.NewAdapter(C.TypeMySQL, tag),
router: router,
logger: logger,
identityProvider: server.NewInMemoryProvider(),
}
for _, user := range options.Users {
inbound.identityProvider.AddUser(user.User, user.Password)
}
if options.TLS == nil || !options.TLS.Enabled {
@@ -112,12 +113,8 @@ func (h *Inbound) Close() error {
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
// Create credential provider for this connection
provider := server.NewInMemoryProvider()
provider.AddUser(h.username, h.password)
// Use go-mysql server to perform the MySQL handshake (which negotiates TLS)
mysqlConn, err := h.mysqlServer.NewCustomizedConn(conn, provider, &emptyHandler{})
mysqlConn, err := h.mysqlServer.NewCustomizedConn(conn, h.identityProvider, &emptyHandler{})
if err != nil {
N.CloseOnHandshakeFailure(conn, onClose, err)
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source, ": MySQL handshake"))
@@ -131,13 +128,13 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
h.logger.InfoContext(ctx, "MySQL handshake completed from ", metadata.Source)
// Handle smux session over the TLS-encrypted connection
err = h.handleMuxSession(ctx, tlsConn, metadata.Source, onClose)
err = h.handleMuxSession(ctx, tlsConn, metadata.Source, onClose, mysqlConn.GetUser())
if err != nil && !E.IsClosed(err) {
h.logger.ErrorContext(ctx, E.Cause(err, "process mux session from ", metadata.Source))
}
}
func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error {
func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc, user string) error {
session, err := smux.Server(conn, smuxConfig())
if err != nil {
if onClose != nil {
@@ -152,7 +149,7 @@ func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.
if sErr != nil {
return sErr
}
go h.handleMuxStream(ctx, stream, source)
go h.handleMuxStream(ctx, stream, source, user)
}
})
group.Cleanup(func() {
@@ -164,14 +161,14 @@ func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.
return group.Run(ctx)
}
func (h *Inbound) handleMuxStream(ctx context.Context, conn net.Conn, source M.Socksaddr) {
err := h.handleMuxStream0(ctx, conn, source)
func (h *Inbound) handleMuxStream(ctx context.Context, conn net.Conn, source M.Socksaddr, user string) {
err := h.handleMuxStream0(ctx, conn, source, user)
if err != nil {
h.logger.ErrorContext(ctx, E.Cause(err, "process mux stream"))
}
}
func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.Socksaddr) error {
func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.Socksaddr, user string) error {
// Read destination from the stream header:
// 1 byte command (0x01=TCP, 0x03=UDP)
// then socks address (using SocksaddrSerializer)
@@ -191,6 +188,7 @@ func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
metadata.Source = source
metadata.User = user
switch command {
case commandTCP: