Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sentence_transformers/sparse_encoder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,10 @@ def encode(

max_active_dims = max_active_dims if max_active_dims is not None else self.max_active_dims

forward_kwargs = dict(kwargs)
if max_active_dims is not None:
forward_kwargs["max_active_dims"] = max_active_dims

all_embeddings = []
length_sorted_idx = np.argsort([-self._input_length(sen) for sen in inputs])
if self._can_flatten_inputs():
Expand All @@ -514,7 +518,7 @@ def encode(
features = batch_to_device(features, device)

with torch.inference_mode():
embeddings = self.forward(features, **kwargs)["sentence_embedding"]
embeddings = self.forward(features, **forward_kwargs)["sentence_embedding"]

if max_active_dims is not None:
embeddings = select_max_active_dims(embeddings, max_active_dims=max_active_dims)
Expand Down
51 changes: 51 additions & 0 deletions tests/sparse_encoder/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,31 @@ def test_inference_free_splade(inference_free_splade_bert_tiny_model: SparseEnco
assert model[0].sub_modules["document"][0].max_seq_length == 256


def test_inference_free_splade_max_active_dims_routing(inference_free_splade_bert_tiny_model: SparseEncoder):
    """max_active_dims must be honored by the routed query/document sub-modules."""
    model = inference_free_splade_bert_tiny_model
    query_text = "What is the capital of France?"
    doc_text = "The capital of France is Paris."

    # Pair each uncapped baseline embedding with its max_active_dims=50 counterpart;
    # both encodes of a pair should route through the same sub-modules.
    pairs = [
        (model.encode_query(query_text), model.encode_query(query_text, max_active_dims=50)),
        (model.encode_document(doc_text), model.encode_document(doc_text, max_active_dims=50)),
    ]

    for baseline_emb, capped_emb in pairs:
        baseline_dims = baseline_emb.coalesce().indices()[0].tolist()
        capped_dims = capped_emb.coalesce().indices()[0].tolist()
        # Capping can only drop active dimensions, never introduce new ones,
        # so the capped non-zero index set is a subset of the baseline's.
        assert set(capped_dims) <= set(baseline_dims)
        assert capped_emb._nnz() <= 50


def test_encode_advanced_parameters(splade_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch):
"""Test that additional parameters are correctly passed to encode"""
model = splade_bert_tiny_model
Expand Down Expand Up @@ -263,6 +288,32 @@ def spy_encode(*args, **kwargs):
assert kwargs["custom_param"] == "value"


def test_csr_max_active_dims_passed_to_forward(csr_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch):
    """encode() must forward max_active_dims into SparseAutoEncoder.forward()."""
    model = csr_bert_tiny_model
    sparse_head = model[-1]
    assert isinstance(sparse_head, SparseAutoEncoder)
    assert sparse_head.k == 16

    # Wrap the head's forward() so every kwargs dict it receives is recorded.
    recorded_kwargs: list[dict] = []
    real_forward = sparse_head.forward

    def recording_forward(*args, **kwargs):
        recorded_kwargs.append(kwargs)
        return real_forward(*args, **kwargs)

    monkeypatch.setattr(sparse_head, "forward", recording_forward)

    # An explicitly supplied max_active_dims is passed through unchanged.
    model.encode("Hello world", max_active_dims=5)
    assert len(recorded_kwargs) == 1
    assert recorded_kwargs[0]["max_active_dims"] == 5

    # When the caller omits it, the model-level default takes its place.
    recorded_kwargs.clear()
    model.encode("Hello world")
    assert len(recorded_kwargs) == 1
    assert recorded_kwargs[0]["max_active_dims"] == model.max_active_dims


def test_max_active_dims_set_init(splade_bert_tiny_model: SparseEncoder, csr_bert_tiny_model: SparseEncoder, tmp_path):
splade_bert_tiny_model.save_pretrained(str(tmp_path / "splade_bert_tiny"))
csr_bert_tiny_model.save_pretrained(str(tmp_path / "csr_bert_tiny"))
Expand Down
Loading