MUX: Prevent goroutine leak (#5110)

This commit is contained in:
patterniha
2025-09-10 02:33:19 +02:00
committed by GitHub
parent ce5c51d3ba
commit 9f5dcb1591
7 changed files with 109 additions and 33 deletions

View File

@@ -219,14 +219,16 @@ func (m *ClientWorker) WaitClosed() <-chan struct{} {
return m.done.Wait()
}
func (m *ClientWorker) GetTimer() *time.Ticker {
return m.timer
func (m *ClientWorker) Close() error {
return m.done.Close()
}
func (m *ClientWorker) monitor() {
defer m.timer.Stop()
for {
checkSize := m.sessionManager.Size()
checkCount := m.sessionManager.Count()
select {
case <-m.done.Wait():
m.sessionManager.Close()
@@ -234,8 +236,7 @@ func (m *ClientWorker) monitor() {
common.Interrupt(m.link.Reader)
return
case <-m.timer.C:
size := m.sessionManager.Size()
if size == 0 && m.sessionManager.CloseIfNoSession() {
if m.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
common.Must(m.done.Close())
}
}
@@ -255,7 +256,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
return nil
}
func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds)-1]
transferType := protocol.TransferTypeStream
@@ -266,7 +267,6 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false)
defer writer.Close()
defer timer.Reset(time.Second * 16)
errors.LogInfo(ctx, "dispatching request to ", ob.Target)
if err := writeFirstPayload(s.input, writer); err != nil {
@@ -316,10 +316,12 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
}
s.input = link.Reader
s.output = link.Writer
if _, ok := link.Reader.(*pipe.Reader); ok {
go fetchInput(ctx, s, m.link.Writer, m.timer)
} else {
fetchInput(ctx, s, m.link.Writer, m.timer)
go fetchInput(ctx, s, m.link.Writer)
if _, ok := link.Reader.(*pipe.Reader); !ok {
select {
case <-ctx.Done():
case <-s.done.Wait():
}
}
return true
}

View File

@@ -3,6 +3,7 @@ package mux
import (
"context"
"io"
"time"
"github.com/xtls/xray-core/app/dispatcher"
"github.com/xtls/xray-core/common"
@@ -12,6 +13,7 @@ import (
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport"
@@ -63,8 +65,15 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
return s.dispatcher.DispatchLink(ctx, dest, link)
}
link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
_, err := NewServerWorker(ctx, s.dispatcher, link)
return err
worker, err := NewServerWorker(ctx, s.dispatcher, link)
if err != nil {
return err
}
select {
case <-ctx.Done():
case <-worker.done.Wait():
}
return nil
}
// Start implements common.Runnable.
@@ -81,6 +90,8 @@ type ServerWorker struct {
dispatcher routing.Dispatcher
link *transport.Link
sessionManager *SessionManager
done *done.Instance
timer *time.Ticker
}
func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
@@ -88,15 +99,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.
dispatcher: d,
link: link,
sessionManager: NewSessionManager(),
done: done.New(),
timer: time.NewTicker(60 * time.Second),
}
if inbound := session.InboundFromContext(ctx); inbound != nil {
inbound.CanSpliceCopy = 3
}
if _, ok := link.Reader.(*pipe.Reader); ok {
go worker.run(ctx)
} else {
worker.run(ctx)
}
go worker.run(ctx)
go worker.monitor()
return worker, nil
}
@@ -111,12 +121,40 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
s.Close(false)
}
func (w *ServerWorker) monitor() {
defer w.timer.Stop()
for {
checkSize := w.sessionManager.Size()
checkCount := w.sessionManager.Count()
select {
case <-w.done.Wait():
w.sessionManager.Close()
common.Interrupt(w.link.Writer)
common.Interrupt(w.link.Reader)
return
case <-w.timer.C:
if w.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
common.Must(w.done.Close())
}
}
}
}
func (w *ServerWorker) ActiveConnections() uint32 {
return uint32(w.sessionManager.Size())
}
func (w *ServerWorker) Closed() bool {
return w.sessionManager.Closed()
return w.done.Done()
}
func (w *ServerWorker) WaitClosed() <-chan struct{} {
return w.done.Wait()
}
func (w *ServerWorker) Close() error {
return w.done.Close()
}
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
@@ -317,11 +355,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
}
func (w *ServerWorker) run(ctx context.Context) {
reader := &buf.BufferedReader{Reader: w.link.Reader}
defer func() {
common.Must(w.done.Close())
}()
defer w.sessionManager.Close()
defer common.Interrupt(w.link.Reader)
defer common.Interrupt(w.link.Writer)
reader := &buf.BufferedReader{Reader: w.link.Reader}
for {
select {

View File

@@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/transport/pipe"
)
@@ -53,7 +54,7 @@ func (m *SessionManager) Count() int {
func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
m.Lock()
defer m.Unlock()
MaxConcurrency := int(Strategy.MaxConcurrency)
MaxConnection := uint16(Strategy.MaxConnection)
@@ -65,6 +66,7 @@ func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
s := &Session{
ID: m.count,
parent: m,
done: done.New(),
}
m.sessions[s.ID] = s
return s
@@ -115,7 +117,7 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) {
return s, found
}
func (m *SessionManager) CloseIfNoSession() bool {
func (m *SessionManager) CloseIfNoSessionAndIdle(checkSize int, checkCount int) bool {
m.Lock()
defer m.Unlock()
@@ -123,11 +125,13 @@ func (m *SessionManager) CloseIfNoSession() bool {
return true
}
if len(m.sessions) != 0 {
if len(m.sessions) != 0 || checkSize != 0 || checkCount != int(m.count) {
return false
}
m.closed = true
m.sessions = nil
return true
}
@@ -157,6 +161,7 @@ type Session struct {
ID uint16
transferType protocol.TransferType
closed bool
done *done.Instance
XUDP *XUDP
}
@@ -171,6 +176,9 @@ func (s *Session) Close(locked bool) error {
return nil
}
s.closed = true
if s.done != nil {
s.done.Close()
}
if s.XUDP == nil {
common.Interrupt(s.input)
common.Close(s.output)

View File

@@ -41,11 +41,11 @@ func TestSessionManagerClose(t *testing.T) {
m := NewSessionManager()
s := m.Allocate(&ClientStrategy{})
if m.CloseIfNoSession() {
if m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
t.Error("able to close")
}
m.Remove(false, s.ID)
if !m.CloseIfNoSession() {
if !m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
t.Error("not able to close")
}
}