diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 2ae64902..03722e8a 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -196,7 +196,7 @@ func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *tran return inboundLink, outboundLink } -func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link) *transport.Link { +func WrapLink(ctx context.Context, policyManager policy.Manager, statsManager stats.Manager, link *transport.Link) *transport.Link { sessionInbound := session.InboundFromContext(ctx) var user *protocol.MemoryUser if sessionInbound != nil { @@ -206,16 +206,16 @@ func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link) link.Reader = &buf.TimeoutWrapperReader{Reader: link.Reader} if user != nil && len(user.Email) > 0 { - p := d.policy.ForLevel(user.Level) + p := policyManager.ForLevel(user.Level) if p.Stats.UserUplink { name := "user>>>" + user.Email + ">>>traffic>>>uplink" - if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { + if c, _ := stats.GetOrRegisterCounter(statsManager, name); c != nil { link.Reader.(*buf.TimeoutWrapperReader).Counter = c } } if p.Stats.UserDownlink { name := "user>>>" + user.Email + ">>>traffic>>>downlink" - if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { + if c, _ := stats.GetOrRegisterCounter(statsManager, name); c != nil { link.Writer = &SizeStatWriter{ Counter: c, Writer: link.Writer, @@ -224,7 +224,7 @@ func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link) } if p.Stats.UserOnline { name := "user>>>" + user.Email + ">>>online" - if om, _ := stats.GetOrRegisterOnlineMap(d.stats, name); om != nil { + if om, _ := stats.GetOrRegisterOnlineMap(statsManager, name); om != nil { sessionInbounds := session.InboundFromContext(ctx) userIP := sessionInbounds.Source.Address.String() om.AddIP(userIP) @@ -357,7 +357,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De content = new(session.Content) ctx = session.ContextWithContent(ctx, content) } - outbound = d.WrapLink(ctx, outbound) + outbound = WrapLink(ctx, d.policy, d.stats, outbound) sniffingRequest := content.SniffingRequest if !sniffingRequest.Enabled { d.routedDispatch(ctx, outbound, destination) @@ -449,6 +449,7 @@ func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, netw } return contentResult, contentErr } + func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { outbounds := session.OutboundsFromContext(ctx) ob := outbounds[len(outbounds)-1] diff --git a/app/reverse/bridge.go b/app/reverse/bridge.go index 324fea59..f6dfec48 100644 --- a/app/reverse/bridge.go +++ b/app/reverse/bridge.go @@ -229,10 +229,6 @@ func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, l } return w.Dispatcher.DispatchLink(ctx, dest, link) } - - if d, ok := w.Dispatcher.(routing.WrapLinkDispatcher); ok { - link = d.WrapLink(ctx, link) - } w.handleInternalConn(link) return nil diff --git a/common/mux/server.go b/common/mux/server.go index 1c090185..d1cdac11 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -63,9 +63,6 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t if dest.Address != muxCoolAddress { return s.dispatcher.DispatchLink(ctx, dest, link) } - if d, ok := s.dispatcher.(routing.WrapLinkDispatcher); ok { - link = d.WrapLink(ctx, link) - } worker, err := NewServerWorker(ctx, s.dispatcher, link) if err != nil { return err diff --git a/features/routing/dispatcher.go b/features/routing/dispatcher.go index c8354446..53d3bf90 100644 --- a/features/routing/dispatcher.go +++ b/features/routing/dispatcher.go @@ -26,9 +26,3 @@ type Dispatcher interface { func DispatcherType() interface{} { return (*Dispatcher)(nil) } - -// Just for type assertion -type WrapLinkDispatcher interface { - Dispatcher - WrapLink(ctx context.Context, link *transport.Link) *transport.Link -} diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index eeb1a25f..d12495b4 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -12,6 +12,7 @@ import ( "time" "unsafe" + "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/app/reverse" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" @@ -31,6 +32,7 @@ import ( "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/features/stats" "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy/vless" "github.com/xtls/xray-core/proxy/vless/encoding" @@ -72,10 +74,11 @@ func init() { type Handler struct { inboundHandlerManager feature_inbound.Manager policyManager policy.Manager + stats stats.Manager validator vless.Validator decryption *encryption.ServerInstance outboundHandlerManager outbound.Manager - wrapLink func(ctx context.Context, link *transport.Link) *transport.Link + defaultDispatcher routing.Dispatcher ctx context.Context fallbacks map[string]map[string]map[string]*Fallback // or nil // regexps map[string]*regexp.Regexp // or nil @@ -84,16 +87,13 @@ type Handler struct { // New creates a new VLess inbound handler. func New(ctx context.Context, config *Config, dc dns.Client, validator vless.Validator) (*Handler, error) { v := core.MustFromContext(ctx) - var wrapLinkFunc func(ctx context.Context, link *transport.Link) *transport.Link - if dispatcher, ok := v.GetFeature(routing.DispatcherType()).(routing.WrapLinkDispatcher); ok { - wrapLinkFunc = dispatcher.WrapLink - } handler := &Handler{ inboundHandlerManager: v.GetFeature(feature_inbound.ManagerType()).(feature_inbound.Manager), policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), + stats: v.GetFeature(stats.ManagerType()).(stats.Manager), validator: validator, outboundHandlerManager: v.GetFeature(outbound.ManagerType()).(outbound.Manager), - wrapLink: wrapLinkFunc, + defaultDispatcher: v.GetFeature(routing.DispatcherType()).(routing.Dispatcher), ctx: ctx, } @@ -264,7 +264,7 @@ func (*Handler) Network() []net.Network { } // Process implements proxy.Inbound.Process(). -func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { +func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatch routing.Dispatcher) error { iConn := stat.TryUnwrapStatsConn(connection) if h.decryption != nil { @@ -623,13 +623,10 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if err != nil { return err } - if h.wrapLink == nil { - return errors.New("VLESS reverse must have a dispatcher that implemented routing.WrapLinkDispatcher") - } - return r.NewMux(ctx, h.wrapLink(ctx, &transport.Link{Reader: clientReader, Writer: clientWriter})) + return r.NewMux(ctx, dispatcher.WrapLink(ctx, h.policyManager, h.stats, &transport.Link{Reader: clientReader, Writer: clientWriter})) } - if err := dispatcher.DispatchLink(ctx, request.Destination(), &transport.Link{ + if err := dispatch.DispatchLink(ctx, request.Destination(), &transport.Link{ Reader: clientReader, Writer: clientWriter}, ); err != nil {