refactor: Enhance MySQL inbound handling with identity provider support
This commit is contained in:
@@ -33,22 +33,23 @@ var _ adapter.TCPInjectableInbound = (*Inbound)(nil)
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
router adapter.ConnectionRouterEx
|
||||
logger logger.ContextLogger
|
||||
listener *listener.Listener
|
||||
tlsConfig boxTLS.ServerConfig
|
||||
username string
|
||||
password string
|
||||
mysqlServer *server.Server
|
||||
router adapter.ConnectionRouterEx
|
||||
logger logger.ContextLogger
|
||||
listener *listener.Listener
|
||||
tlsConfig boxTLS.ServerConfig
|
||||
identityProvider *server.InMemoryProvider
|
||||
mysqlServer *server.Server
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLInboundOptions) (adapter.Inbound, error) {
|
||||
inbound := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeMySQL, tag),
|
||||
router: router,
|
||||
logger: logger,
|
||||
username: options.Username,
|
||||
password: options.Password,
|
||||
Adapter: inbound.NewAdapter(C.TypeMySQL, tag),
|
||||
router: router,
|
||||
logger: logger,
|
||||
identityProvider: server.NewInMemoryProvider(),
|
||||
}
|
||||
for _, user := range options.Users {
|
||||
inbound.identityProvider.AddUser(user.User, user.Password)
|
||||
}
|
||||
|
||||
if options.TLS == nil || !options.TLS.Enabled {
|
||||
@@ -112,12 +113,8 @@ func (h *Inbound) Close() error {
|
||||
}
|
||||
|
||||
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
// Create credential provider for this connection
|
||||
provider := server.NewInMemoryProvider()
|
||||
provider.AddUser(h.username, h.password)
|
||||
|
||||
// Use go-mysql server to perform the MySQL handshake (which negotiates TLS)
|
||||
mysqlConn, err := h.mysqlServer.NewCustomizedConn(conn, provider, &emptyHandler{})
|
||||
mysqlConn, err := h.mysqlServer.NewCustomizedConn(conn, h.identityProvider, &emptyHandler{})
|
||||
if err != nil {
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source, ": MySQL handshake"))
|
||||
@@ -131,13 +128,13 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
|
||||
h.logger.InfoContext(ctx, "MySQL handshake completed from ", metadata.Source)
|
||||
|
||||
// Handle smux session over the TLS-encrypted connection
|
||||
err = h.handleMuxSession(ctx, tlsConn, metadata.Source, onClose)
|
||||
err = h.handleMuxSession(ctx, tlsConn, metadata.Source, onClose, mysqlConn.GetUser())
|
||||
if err != nil && !E.IsClosed(err) {
|
||||
h.logger.ErrorContext(ctx, E.Cause(err, "process mux session from ", metadata.Source))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error {
|
||||
func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc, user string) error {
|
||||
session, err := smux.Server(conn, smuxConfig())
|
||||
if err != nil {
|
||||
if onClose != nil {
|
||||
@@ -152,7 +149,7 @@ func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.
|
||||
if sErr != nil {
|
||||
return sErr
|
||||
}
|
||||
go h.handleMuxStream(ctx, stream, source)
|
||||
go h.handleMuxStream(ctx, stream, source, user)
|
||||
}
|
||||
})
|
||||
group.Cleanup(func() {
|
||||
@@ -164,14 +161,14 @@ func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.
|
||||
return group.Run(ctx)
|
||||
}
|
||||
|
||||
func (h *Inbound) handleMuxStream(ctx context.Context, conn net.Conn, source M.Socksaddr) {
|
||||
err := h.handleMuxStream0(ctx, conn, source)
|
||||
func (h *Inbound) handleMuxStream(ctx context.Context, conn net.Conn, source M.Socksaddr, user string) {
|
||||
err := h.handleMuxStream0(ctx, conn, source, user)
|
||||
if err != nil {
|
||||
h.logger.ErrorContext(ctx, E.Cause(err, "process mux stream"))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.Socksaddr) error {
|
||||
func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.Socksaddr, user string) error {
|
||||
// Read destination from the stream header:
|
||||
// 1 byte command (0x01=TCP, 0x03=UDP)
|
||||
// then socks address (using SocksaddrSerializer)
|
||||
@@ -191,6 +188,7 @@ func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.
|
||||
metadata.Inbound = h.Tag()
|
||||
metadata.InboundType = h.Type()
|
||||
metadata.Source = source
|
||||
metadata.User = user
|
||||
|
||||
switch command {
|
||||
case commandTCP:
|
||||
|
||||
Reference in New Issue
Block a user