Improve HTTPS DNS transport
This commit is contained in:
@@ -53,26 +53,48 @@ func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, e
|
|||||||
return tlsConn, nil
|
return tlsConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Dialer struct {
|
type Dialer interface {
|
||||||
|
N.Dialer
|
||||||
|
DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultDialer struct {
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
config Config
|
config Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDialer(dialer N.Dialer, config Config) N.Dialer {
|
func NewDialer(dialer N.Dialer, config Config) Dialer {
|
||||||
return &Dialer{dialer, config}
|
return &defaultDialer{dialer, config}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Dialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
func (d *defaultDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
if network != N.NetworkTCP {
|
if N.NetworkName(network) != N.NetworkTCP {
|
||||||
return nil, os.ErrInvalid
|
return nil, os.ErrInvalid
|
||||||
}
|
}
|
||||||
conn, err := d.dialer.DialContext(ctx, network, destination)
|
return d.DialTLSContext(ctx, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *defaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
|
return nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *defaultDialer) DialTLSContext(ctx context.Context, destination M.Socksaddr) (Conn, error) {
|
||||||
|
return d.dialContext(ctx, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *defaultDialer) dialContext(ctx context.Context, destination M.Socksaddr) (Conn, error) {
|
||||||
|
conn, err := d.dialer.DialContext(ctx, N.NetworkTCP, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return ClientHandshake(ctx, conn, d.config)
|
tlsConn, err := aTLS.ClientHandshake(ctx, conn, d.config)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tlsConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Dialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
func (d *defaultDialer) Upstream() any {
|
||||||
return nil, os.ErrInvalid
|
return d.dialer
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ import (
|
|||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
aTLS "github.com/sagernet/sing/common/tls"
|
|
||||||
sHTTP "github.com/sagernet/sing/protocol/http"
|
sHTTP "github.com/sagernet/sing/protocol/http"
|
||||||
|
|
||||||
mDNS "github.com/miekg/dns"
|
mDNS "github.com/miekg/dns"
|
||||||
@@ -47,7 +46,7 @@ type HTTPSTransport struct {
|
|||||||
destination *url.URL
|
destination *url.URL
|
||||||
headers http.Header
|
headers http.Header
|
||||||
transportAccess sync.Mutex
|
transportAccess sync.Mutex
|
||||||
transport *http.Transport
|
transport *HTTPSTransportWrapper
|
||||||
transportResetAt time.Time
|
transportResetAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,11 +61,8 @@ func NewHTTPS(ctx context.Context, logger log.ContextLogger, tag string, options
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if common.Error(tlsConfig.Config()) == nil && !common.Contains(tlsConfig.NextProtos(), http2.NextProtoTLS) {
|
if len(tlsConfig.NextProtos()) == 0 {
|
||||||
tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), http2.NextProtoTLS))
|
tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"})
|
||||||
}
|
|
||||||
if !common.Contains(tlsConfig.NextProtos(), "http/1.1") {
|
|
||||||
tlsConfig.SetNextProtos(append(tlsConfig.NextProtos(), "http/1.1"))
|
|
||||||
}
|
}
|
||||||
headers := options.Headers.Build()
|
headers := options.Headers.Build()
|
||||||
host := headers.Get("Host")
|
host := headers.Get("Host")
|
||||||
@@ -124,37 +120,13 @@ func NewHTTPSRaw(
|
|||||||
serverAddr M.Socksaddr,
|
serverAddr M.Socksaddr,
|
||||||
tlsConfig tls.Config,
|
tlsConfig tls.Config,
|
||||||
) *HTTPSTransport {
|
) *HTTPSTransport {
|
||||||
var transport *http.Transport
|
|
||||||
if tlsConfig != nil {
|
|
||||||
transport = &http.Transport{
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
tcpConn, hErr := dialer.DialContext(ctx, network, serverAddr)
|
|
||||||
if hErr != nil {
|
|
||||||
return nil, hErr
|
|
||||||
}
|
|
||||||
tlsConn, hErr := aTLS.ClientHandshake(ctx, tcpConn, tlsConfig)
|
|
||||||
if hErr != nil {
|
|
||||||
tcpConn.Close()
|
|
||||||
return nil, hErr
|
|
||||||
}
|
|
||||||
return tlsConn, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.DialContext(ctx, network, serverAddr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &HTTPSTransport{
|
return &HTTPSTransport{
|
||||||
TransportAdapter: adapter,
|
TransportAdapter: adapter,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
destination: destination,
|
destination: destination,
|
||||||
headers: headers,
|
headers: headers,
|
||||||
transport: transport,
|
transport: NewHTTPSTransportWrapper(tls.NewDialer(dialer, tlsConfig), serverAddr),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
80
dns/transport/https_transport.go
Normal file
80
dns/transport/https_transport.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errFallback = E.New("fallback to HTTP/1.1")
|
||||||
|
|
||||||
|
type HTTPSTransportWrapper struct {
|
||||||
|
http2Transport *http2.Transport
|
||||||
|
httpTransport *http.Transport
|
||||||
|
fallback *atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHTTPSTransportWrapper(dialer tls.Dialer, serverAddr M.Socksaddr) *HTTPSTransportWrapper {
|
||||||
|
var fallback atomic.Bool
|
||||||
|
return &HTTPSTransportWrapper{
|
||||||
|
http2Transport: &http2.Transport{
|
||||||
|
DialTLSContext: func(ctx context.Context, _, _ string, _ *tls.STDConfig) (net.Conn, error) {
|
||||||
|
tlsConn, err := dialer.DialTLSContext(ctx, serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
state := tlsConn.ConnectionState()
|
||||||
|
if state.NegotiatedProtocol == http2.NextProtoTLS {
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
|
tlsConn.Close()
|
||||||
|
fallback.Store(true)
|
||||||
|
return nil, errFallback
|
||||||
|
},
|
||||||
|
},
|
||||||
|
httpTransport: &http.Transport{
|
||||||
|
DialTLSContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
return dialer.DialTLSContext(ctx, serverAddr)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
fallback: &fallback,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HTTPSTransportWrapper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
if h.fallback.Load() {
|
||||||
|
return h.httpTransport.RoundTrip(request)
|
||||||
|
} else {
|
||||||
|
response, err := h.http2Transport.RoundTrip(request)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errFallback) {
|
||||||
|
return h.httpTransport.RoundTrip(request)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HTTPSTransportWrapper) CloseIdleConnections() {
|
||||||
|
h.http2Transport.CloseIdleConnections()
|
||||||
|
h.httpTransport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HTTPSTransportWrapper) Clone() *HTTPSTransportWrapper {
|
||||||
|
return &HTTPSTransportWrapper{
|
||||||
|
httpTransport: h.httpTransport,
|
||||||
|
http2Transport: &http2.Transport{
|
||||||
|
DialTLSContext: h.http2Transport.DialTLSContext,
|
||||||
|
},
|
||||||
|
fallback: h.fallback,
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user