refactor: Enhance MySQL inbound handling with identity provider support
This commit is contained in:
@@ -3,11 +3,15 @@ package option
|
|||||||
type MySQLInboundOptions struct {
|
type MySQLInboundOptions struct {
|
||||||
ListenOptions
|
ListenOptions
|
||||||
InboundTLSOptionsContainer
|
InboundTLSOptionsContainer
|
||||||
Username string `json:"username,omitempty"`
|
Users []MySQLUser `json:"users,omitempty"`
|
||||||
Password string `json:"password,omitempty"`
|
|
||||||
Multiplex *InboundMultiplexOptions `json:"multiplex,omitempty"`
|
Multiplex *InboundMultiplexOptions `json:"multiplex,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MySQLUser struct {
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
Password string `json:"password,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type MySQLOutboundOptions struct {
|
type MySQLOutboundOptions struct {
|
||||||
DialerOptions
|
DialerOptions
|
||||||
ServerOptions
|
ServerOptions
|
||||||
|
|||||||
@@ -33,22 +33,23 @@ var _ adapter.TCPInjectableInbound = (*Inbound)(nil)
|
|||||||
|
|
||||||
type Inbound struct {
|
type Inbound struct {
|
||||||
inbound.Adapter
|
inbound.Adapter
|
||||||
router adapter.ConnectionRouterEx
|
router adapter.ConnectionRouterEx
|
||||||
logger logger.ContextLogger
|
logger logger.ContextLogger
|
||||||
listener *listener.Listener
|
listener *listener.Listener
|
||||||
tlsConfig boxTLS.ServerConfig
|
tlsConfig boxTLS.ServerConfig
|
||||||
username string
|
identityProvider *server.InMemoryProvider
|
||||||
password string
|
mysqlServer *server.Server
|
||||||
mysqlServer *server.Server
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLInboundOptions) (adapter.Inbound, error) {
|
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MySQLInboundOptions) (adapter.Inbound, error) {
|
||||||
inbound := &Inbound{
|
inbound := &Inbound{
|
||||||
Adapter: inbound.NewAdapter(C.TypeMySQL, tag),
|
Adapter: inbound.NewAdapter(C.TypeMySQL, tag),
|
||||||
router: router,
|
router: router,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
username: options.Username,
|
identityProvider: server.NewInMemoryProvider(),
|
||||||
password: options.Password,
|
}
|
||||||
|
for _, user := range options.Users {
|
||||||
|
inbound.identityProvider.AddUser(user.User, user.Password)
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.TLS == nil || !options.TLS.Enabled {
|
if options.TLS == nil || !options.TLS.Enabled {
|
||||||
@@ -112,12 +113,8 @@ func (h *Inbound) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||||
// Create credential provider for this connection
|
|
||||||
provider := server.NewInMemoryProvider()
|
|
||||||
provider.AddUser(h.username, h.password)
|
|
||||||
|
|
||||||
// Use go-mysql server to perform the MySQL handshake (which negotiates TLS)
|
// Use go-mysql server to perform the MySQL handshake (which negotiates TLS)
|
||||||
mysqlConn, err := h.mysqlServer.NewCustomizedConn(conn, provider, &emptyHandler{})
|
mysqlConn, err := h.mysqlServer.NewCustomizedConn(conn, h.identityProvider, &emptyHandler{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||||
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source, ": MySQL handshake"))
|
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source, ": MySQL handshake"))
|
||||||
@@ -131,13 +128,13 @@ func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata a
|
|||||||
h.logger.InfoContext(ctx, "MySQL handshake completed from ", metadata.Source)
|
h.logger.InfoContext(ctx, "MySQL handshake completed from ", metadata.Source)
|
||||||
|
|
||||||
// Handle smux session over the TLS-encrypted connection
|
// Handle smux session over the TLS-encrypted connection
|
||||||
err = h.handleMuxSession(ctx, tlsConn, metadata.Source, onClose)
|
err = h.handleMuxSession(ctx, tlsConn, metadata.Source, onClose, mysqlConn.GetUser())
|
||||||
if err != nil && !E.IsClosed(err) {
|
if err != nil && !E.IsClosed(err) {
|
||||||
h.logger.ErrorContext(ctx, E.Cause(err, "process mux session from ", metadata.Source))
|
h.logger.ErrorContext(ctx, E.Cause(err, "process mux session from ", metadata.Source))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error {
|
func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc, user string) error {
|
||||||
session, err := smux.Server(conn, smuxConfig())
|
session, err := smux.Server(conn, smuxConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if onClose != nil {
|
if onClose != nil {
|
||||||
@@ -152,7 +149,7 @@ func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.
|
|||||||
if sErr != nil {
|
if sErr != nil {
|
||||||
return sErr
|
return sErr
|
||||||
}
|
}
|
||||||
go h.handleMuxStream(ctx, stream, source)
|
go h.handleMuxStream(ctx, stream, source, user)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
group.Cleanup(func() {
|
group.Cleanup(func() {
|
||||||
@@ -164,14 +161,14 @@ func (h *Inbound) handleMuxSession(ctx context.Context, conn net.Conn, source M.
|
|||||||
return group.Run(ctx)
|
return group.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) handleMuxStream(ctx context.Context, conn net.Conn, source M.Socksaddr) {
|
func (h *Inbound) handleMuxStream(ctx context.Context, conn net.Conn, source M.Socksaddr, user string) {
|
||||||
err := h.handleMuxStream0(ctx, conn, source)
|
err := h.handleMuxStream0(ctx, conn, source, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.ErrorContext(ctx, E.Cause(err, "process mux stream"))
|
h.logger.ErrorContext(ctx, E.Cause(err, "process mux stream"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.Socksaddr) error {
|
func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.Socksaddr, user string) error {
|
||||||
// Read destination from the stream header:
|
// Read destination from the stream header:
|
||||||
// 1 byte command (0x01=TCP, 0x03=UDP)
|
// 1 byte command (0x01=TCP, 0x03=UDP)
|
||||||
// then socks address (using SocksaddrSerializer)
|
// then socks address (using SocksaddrSerializer)
|
||||||
@@ -191,6 +188,7 @@ func (h *Inbound) handleMuxStream0(ctx context.Context, conn net.Conn, source M.
|
|||||||
metadata.Inbound = h.Tag()
|
metadata.Inbound = h.Tag()
|
||||||
metadata.InboundType = h.Type()
|
metadata.InboundType = h.Type()
|
||||||
metadata.Source = source
|
metadata.Source = source
|
||||||
|
metadata.User = user
|
||||||
|
|
||||||
switch command {
|
switch command {
|
||||||
case commandTCP:
|
case commandTCP:
|
||||||
|
|||||||
Reference in New Issue
Block a user