
fix: Separate tokenize and forward kwargs in SparseEncoder.encode to prevent Router misrouting#3695

Merged
tomaarsen merged 4 commits into huggingface:main from ratatouille-plat:fix/sparse-encoder-kwargs-routing
Apr 9, 2026

Conversation

@ratatouille-plat
Contributor

Problem

When using SparseEncoder.encode() (called via encode_query() or encode_document()) with a Router-based model (e.g., inference-free SPLADE), the Router module fails to route to the correct sub-module when max_active_dims is set.

Root Cause

In SparseEncoder.encode(), max_active_dims is injected into kwargs before self.tokenize() is called:

max_active_dims = max_active_dims if max_active_dims is not None else self.max_active_dims
if max_active_dims is not None:
    kwargs["max_active_dims"] = max_active_dims

# ...
features = self.tokenize(sentences_batch, **kwargs)

This kwargs dict (now containing max_active_dims) is passed through the following call chain:

  1. SentenceTransformer.tokenize(texts, **kwargs)
  2. self[0].tokenize(texts, **kwargs) (wrapped in try/except TypeError)
  3. Router.tokenize(texts, task="query", max_active_dims=...)
  4. input_module.tokenize(texts, **kwargs) — e.g., SparseStaticEmbedding.tokenize()

SparseStaticEmbedding.tokenize(self, texts, padding=True) does not accept **kwargs, so passing max_active_dims raises a TypeError.

This TypeError propagates back up to SentenceTransformer.tokenize(), which catches it and falls back to:

except TypeError:
    return self[0].tokenize(texts)  # no task, no kwargs

This fallback calls Router.tokenize(texts) without the task argument, causing the Router to use its default_route instead of the intended "query" or "document" route.
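The swallowed-TypeError failure mode described above can be reproduced in isolation. The sketch below uses simplified stand-in functions (the names `sub_tokenize`, `router_tokenize`, and `model_tokenize` are illustrative, not the library's actual code) to show how an extra keyword argument causes the fallback path to silently drop `task`:

```python
def sub_tokenize(texts, padding=True):
    # Stands in for SparseStaticEmbedding.tokenize: no **kwargs accepted.
    return {"texts": texts, "route": "query"}

def router_tokenize(texts, task=None, **kwargs):
    # Stands in for Router.tokenize: forwards extra kwargs to the routed module.
    if task == "query":
        return sub_tokenize(texts, **kwargs)  # TypeError if kwargs has extras
    return {"texts": texts, "route": "default"}

def model_tokenize(texts, **kwargs):
    # Stands in for SentenceTransformer.tokenize with its try/except fallback.
    try:
        return router_tokenize(texts, **kwargs)
    except TypeError:
        return router_tokenize(texts)  # task is lost here -> default route

features = model_tokenize(["a query"], task="query", max_active_dims=64)
print(features["route"])  # "default", even though task="query" was requested
```

The key point is that the `except TypeError` branch cannot distinguish "module does not accept `task`" from "module choked on an unrelated kwarg", so one stray key silently discards the routing information too.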

For inference-free SPLADE models with a module structure like:

query_0_SparseStaticEmbedding
document_0_MLMTransformer
document_1_SpladePooling

This means a query may be incorrectly routed to MLMTransformer instead of SparseStaticEmbedding. Since SparseStaticEmbedding.tokenize() hardcodes add_special_tokens=False while MLMTransformer.tokenize() does not, the result is tokenization with special tokens (CLS/SEP) when they should be absent.
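A toy comparison (stand-in functions, not the library's real tokenizers) of what the two routes produce for the same query illustrates why the misroute changes the output:

```python
def static_embedding_style(words):
    # Mimics SparseStaticEmbedding.tokenize, which hardcodes
    # add_special_tokens=False: the sequence is just the word tokens.
    return list(words)

def mlm_transformer_style(words):
    # Mimics MLMTransformer.tokenize, where the underlying HF tokenizer
    # adds special tokens by default.
    return ["[CLS]", *words, "[SEP]"]

query = ["splade", "query"]
print(static_embedding_style(query))  # ['splade', 'query']
print(mlm_transformer_style(query))   # ['[CLS]', 'splade', 'query', '[SEP]']
```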

Impact

  • Query tokenization silently uses the wrong sub-module via Router misrouting
  • Special tokens are incorrectly included in the tokenized output
  • This produces incorrect sparse embeddings for queries, degrading retrieval quality
  • No error or warning is raised — the failure is completely silent

Solution

Separate kwargs into two dicts:

  • kwargs — passed to self.tokenize(), contains only routing-relevant keys (e.g., task)
  • forward_kwargs — passed to self.forward(), includes max_active_dims on top of the original kwargs

This ensures max_active_dims never leaks into the tokenization call chain: the Router routes correctly based on task, and the selected sub-module's tokenize() receives no unexpected keyword arguments.
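The kwargs separation can be sketched as follows. This is a simplified, self-contained illustration with stub `tokenize`/`forward` functions, not the actual SparseEncoder.encode implementation:

```python
def tokenize(sentences, task=None):
    # Stub: accepts only routing-relevant kwargs, like the Router path does.
    return {"sentences": sentences, "task": task}

def forward(features, task=None, max_active_dims=None):
    # Stub: the forward pass is where max_active_dims is actually consumed.
    return {**features, "max_active_dims": max_active_dims}

def encode(sentences, max_active_dims=None, **kwargs):
    # kwargs (e.g. task="query") stays routing-safe for tokenization;
    # max_active_dims is added only to a separate forward_kwargs copy.
    forward_kwargs = dict(kwargs)
    if max_active_dims is not None:
        forward_kwargs["max_active_dims"] = max_active_dims
    features = tokenize(sentences, **kwargs)  # never sees max_active_dims
    return forward(features, **forward_kwargs)

out = encode(["a query"], max_active_dims=32, task="query")
# task reaches tokenize() intact; max_active_dims reaches forward() only.
```

Copying kwargs rather than mutating it also keeps the caller's dict untouched, which matters when the same kwargs is reused across batches.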

@tomaarsen
Member

Hello!

Very well spotted, this indeed seems like an issue with a clean fix. Would you be able to add a test for this as well perhaps? I'll be refactoring a lot of this functionality soon, and I want to make sure that I correctly pull this fix into that refactor to avoid a regression.

  • Tom Aarsen

@ratatouille-plat
Contributor Author

ratatouille-plat commented Mar 27, 2026

@tomaarsen
Thanks for the review! I've added a test (test_inference_free_splade_max_active_dims_routing) that verifies the Router correctly routes to the expected sub-module when max_active_dims is set.

Regarding the CI failures — the only failing test is tests/models/test_static_embedding.py::test_from_distillation, which fails with:

ImportError: To use this method, please install the `model2vec` package: `pip install model2vec[distill]`

This appears to be unrelated to this PR (a missing dependency in the CI environment). The latest commits are also pending maintainer approval to run CI (action_required status) since this is a fork-based PR.

@ratatouille-plat
Contributor Author

@tomaarsen Hi, I was wondering about the current status of this PR. Please let me know if there's anything else needed from my side.

@tomaarsen
Member

tomaarsen commented Apr 9, 2026

Hello @ratatouille-plat, no changes on your side needed! I'd like to incorporate your changes today, but I wanted to prioritize #3554 first. I've merged that one now, resulting in a lot of merge conflicts, but I'll take over from here to get this fix merged.

Thanks a lot for finding this and setting up the PR!

  • Tom Aarsen

@tomaarsen tomaarsen merged commit 802317b into huggingface:main Apr 9, 2026
17 of 18 checks passed