fix: Separate tokenize and forward kwargs in SparseEncoder.encode to prevent Router misrouting #3695
Conversation
Hello! Very well spotted, this indeed seems like an issue with a clean fix. Would you be able to add a test for this as well perhaps? I'll be refactoring a lot of this functionality soon, and I want to make sure that I correctly pull this fix into that refactor to avoid a regression.
@tomaarsen Regarding the CI failures: the only failing test appears to be unrelated to this PR (a missing dependency in the CI environment). The latest commits are also pending maintainer approval to run CI.
@tomaarsen Hi, I was wondering about the current status of this PR. Please let me know if there's anything else needed from my side.
Hello @ratatouille-plat, no changes on your side needed! I'd like to incorporate your changes today, but I wanted to prioritize #3554 first. I've merged that one now, resulting in a lot of merge conflicts, but I'll take over from here to get this fix merged. Thanks a lot for finding this and setting up the PR!
## Problem

When using `SparseEncoder.encode()` (called via `encode_query()` or `encode_document()`) with a `Router`-based model (e.g., inference-free SPLADE), the `Router` module fails to route to the correct sub-module when `max_active_dims` is set.

## Root Cause
In `SparseEncoder.encode()`, `max_active_dims` is injected into `kwargs` before `self.tokenize()` is called. This `kwargs` dict (now containing `max_active_dims`) is passed through the following call chain:

1. `SentenceTransformer.tokenize(texts, **kwargs)`
2. `self[0].tokenize(texts, **kwargs)` (wrapped in try/except `TypeError`)
3. `Router.tokenize(texts, task="query", max_active_dims=...)`
4. `input_module.tokenize(texts, **kwargs)`, e.g. `SparseStaticEmbedding.tokenize()`

`SparseStaticEmbedding.tokenize(self, texts, padding=True)` does not accept `**kwargs`, so passing `max_active_dims` raises a `TypeError`.
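A minimal, self-contained sketch of this failure mode, using simplified stand-ins for the library's classes (the function names and signatures here are illustrative assumptions, not the actual sentence-transformers code):

```python
def sparse_static_embedding_tokenize(texts, padding=True):
    # Stand-in for SparseStaticEmbedding.tokenize: note the fixed
    # signature with no **kwargs catch-all.
    return {"input_ids": [t.split() for t in texts]}

def encode(texts, max_active_dims=None, **kwargs):
    # Pre-fix behaviour: the forward-only option is folded into the
    # same kwargs dict that tokenization later receives.
    if max_active_dims is not None:
        kwargs["max_active_dims"] = max_active_dims
    return sparse_static_embedding_tokenize(texts, **kwargs)

try:
    encode(["what is splade?"], max_active_dims=16)
except TypeError as exc:
    print(f"TypeError: {exc}")
```

Without `max_active_dims` the call succeeds; with it, the unexpected keyword argument triggers the `TypeError` that sets off the fallback described below.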
This `TypeError` propagates back up to `SentenceTransformer.tokenize()`, which catches it and retries the call with the keyword arguments dropped. This fallback calls `Router.tokenize(texts)` without the `task` argument, causing the `Router` to use its `default_route` instead of the intended `"query"` or `"document"` route. Consider an inference-free SPLADE model whose `Router` maps the query route to `SparseStaticEmbedding` and the document route to `MLMTransformer`.
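The whole interaction can be reproduced end to end with a toy router (a deliberately simplified model; the real `Router` module is more involved, and the route names here are assumptions based on the description above):

```python
def query_tokenize(texts, padding=True):
    # Stand-in for SparseStaticEmbedding.tokenize (no **kwargs).
    return ("query-route", texts)

def document_tokenize(texts, **kwargs):
    # Stand-in for MLMTransformer.tokenize (tolerates extra kwargs).
    return ("document-route", texts)

def router_tokenize(texts, task=None, **kwargs):
    # Without a task, fall back to the default route (here: document).
    routes = {"query": query_tokenize, "document": document_tokenize}
    return routes.get(task, document_tokenize)(texts, **kwargs)

def sentence_transformer_tokenize(texts, **kwargs):
    try:
        return router_tokenize(texts, **kwargs)
    except TypeError:
        # The fallback drops ALL kwargs, including task.
        return router_tokenize(texts)

# max_active_dims leaks in, query_tokenize rejects it, the fallback
# fires without task, and the query lands on the document route:
print(sentence_transformer_tokenize(["q"], task="query", max_active_dims=16)[0])
# -> document-route (should have been query-route)
```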
This means a query may be incorrectly routed to `MLMTransformer` instead of `SparseStaticEmbedding`. Since `SparseStaticEmbedding.tokenize()` hardcodes `add_special_tokens=False` while `MLMTransformer.tokenize()` does not, the result is tokenization with special tokens (CLS/SEP) when they should be absent.

## Impact
## Solution
Separate `kwargs` into two dicts:

- `kwargs`: passed to `self.tokenize()`; contains only routing-relevant keys (e.g., `task`)
- `forward_kwargs`: passed to `self.forward()`; includes `max_active_dims` on top of the original `kwargs`

This ensures `max_active_dims` never leaks into the tokenization call chain, allowing `Router` to route correctly based on `task`, and the appropriate sub-module's `tokenize()` to be called without unexpected keyword arguments.
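A hedged sketch of the fixed flow, again using simplified stand-in functions (names and signatures are illustrative assumptions, not the library's verbatim code):

```python
def tokenize(texts, task=None):
    # Stand-in for self.tokenize -> Router.tokenize: routing keys only.
    return {"input_ids": texts, "task": task}

def forward(features, max_active_dims=None, **kwargs):
    # Stand-in for self.forward: the model-level option is applied here.
    return {**features, "max_active_dims": max_active_dims}

def encode(texts, max_active_dims=None, **kwargs):
    # Post-fix behaviour: keep two separate dicts.
    # kwargs stays routing-only; forward_kwargs layers max_active_dims on top.
    forward_kwargs = {**kwargs, "max_active_dims": max_active_dims}
    features = tokenize(texts, **kwargs)        # Router sees only task etc.
    return forward(features, **forward_kwargs)  # model option applied here

out = encode(["q"], max_active_dims=16, task="query")
print(out)  # {'input_ids': ['q'], 'task': 'query', 'max_active_dims': 16}
```

With the dicts separated, the `TypeError`-driven fallback in `SentenceTransformer.tokenize()` is never triggered, so the `task` argument always reaches the router.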