Fix shadowtls server

This commit is contained in:
世界
2022-11-21 15:39:34 +08:00
parent 8b7fe20b7f
commit c16e4316d6
4 changed files with 31 additions and 27 deletions

View File

@@ -29,6 +29,7 @@ type ShadowTLS struct {
handshakeAddr M.Socksaddr
v2 bool
password string
fallbackAfter int
}
func NewShadowTLS(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSInboundOptions) (*ShadowTLS, error) {
@@ -52,6 +53,11 @@ func NewShadowTLS(ctx context.Context, router adapter.Router, logger log.Context
case 1:
case 2:
inbound.v2 = true
if options.FallbackAfter == nil {
inbound.fallbackAfter = 2
} else {
inbound.fallbackAfter = *options.FallbackAfter
}
default:
return nil, E.New("unknown shadowtls protocol version: ", options.Version)
}
@@ -85,7 +91,7 @@ func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata a
hashConn := shadowtls.NewHashWriteConn(conn, s.password)
go bufio.Copy(hashConn, handshakeConn)
var request *buf.Buffer
request, err = s.copyUntilHandshakeFinishedV2(handshakeConn, conn, hashConn)
request, err = s.copyUntilHandshakeFinishedV2(handshakeConn, conn, hashConn, s.fallbackAfter)
if err == nil {
handshakeConn.Close()
return s.newConnection(ctx, bufio.NewCachedConn(shadowtls.NewConn(conn), request), metadata)
@@ -129,10 +135,10 @@ func (s *ShadowTLS) copyUntilHandshakeFinished(dst io.Writer, src io.Reader) err
}
}
func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn) (*buf.Buffer, error) {
func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, hash *shadowtls.HashWriteConn, fallbackAfter int) (*buf.Buffer, error) {
const applicationData = 0x17
var tlsHdr [5]byte
var doFallback bool
var applicationDataCount int
for {
_, err := io.ReadFull(src, tlsHdr[:])
if err != nil {
@@ -152,13 +158,14 @@ func (s *ShadowTLS) copyUntilHandshakeFinishedV2(dst net.Conn, src io.Reader, ha
}
_, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), data))
data.Release()
doFallback = true
applicationDataCount++
} else {
_, err = io.Copy(dst, io.MultiReader(bytes.NewReader(tlsHdr[:]), io.LimitReader(src, int64(length))))
}
if err != nil {
return nil, err
} else if doFallback {
}
if applicationDataCount > fallbackAfter {
return nil, os.ErrPermission
}
}