diff --git a/sentence_transformers/sparse_encoder/model.py b/sentence_transformers/sparse_encoder/model.py
index 3e0da8949..781e47c5b 100644
--- a/sentence_transformers/sparse_encoder/model.py
+++ b/sentence_transformers/sparse_encoder/model.py
@@ -502,6 +502,10 @@ def encode(
 
         max_active_dims = max_active_dims if max_active_dims is not None else self.max_active_dims
 
+        forward_kwargs = dict(kwargs)
+        if max_active_dims is not None:
+            forward_kwargs["max_active_dims"] = max_active_dims
+
         all_embeddings = []
         length_sorted_idx = np.argsort([-self._input_length(sen) for sen in inputs])
         if self._can_flatten_inputs():
@@ -514,7 +518,7 @@ def encode(
             features = batch_to_device(features, device)
 
             with torch.inference_mode():
-                embeddings = self.forward(features, **kwargs)["sentence_embedding"]
+                embeddings = self.forward(features, **forward_kwargs)["sentence_embedding"]
 
                 if max_active_dims is not None:
                     embeddings = select_max_active_dims(embeddings, max_active_dims=max_active_dims)
diff --git a/tests/sparse_encoder/test_model.py b/tests/sparse_encoder/test_model.py
index f4876a177..fa9a72785 100644
--- a/tests/sparse_encoder/test_model.py
+++ b/tests/sparse_encoder/test_model.py
@@ -232,6 +232,31 @@ def test_inference_free_splade(inference_free_splade_bert_tiny_model: SparseEnco
     assert model[0].sub_modules["document"][0].max_seq_length == 256
 
 
+def test_inference_free_splade_max_active_dims_routing(inference_free_splade_bert_tiny_model: SparseEncoder):
+    model = inference_free_splade_bert_tiny_model
+    query = "What is the capital of France?"
+    document = "The capital of France is Paris."
+
+    # Encode without max_active_dims — baseline
+    query_emb = model.encode_query(query)
+    doc_emb = model.encode_document(document)
+
+    # Encode with max_active_dims — should route to the same sub-modules
+    query_emb_mad = model.encode_query(query, max_active_dims=50)
+    doc_emb_mad = model.encode_document(document, max_active_dims=50)
+
+    # The non-zero indices of the max_active_dims result should be a subset of the baseline
+    query_baseline_indices = query_emb.coalesce().indices()[0]
+    query_mad_indices = query_emb_mad.coalesce().indices()[0]
+    assert set(query_mad_indices.tolist()).issubset(set(query_baseline_indices.tolist()))
+    assert query_emb_mad._nnz() <= 50
+
+    doc_baseline_indices = doc_emb.coalesce().indices()[0]
+    doc_mad_indices = doc_emb_mad.coalesce().indices()[0]
+    assert set(doc_mad_indices.tolist()).issubset(set(doc_baseline_indices.tolist()))
+    assert doc_emb_mad._nnz() <= 50
+
+
 def test_encode_advanced_parameters(splade_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch):
     """Test that additional parameters are correctly passed to encode"""
     model = splade_bert_tiny_model
@@ -263,6 +288,32 @@ def spy_encode(*args, **kwargs):
     assert kwargs["custom_param"] == "value"
 
 
+def test_csr_max_active_dims_passed_to_forward(csr_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch):
+    model = csr_bert_tiny_model
+    assert isinstance(model[-1], SparseAutoEncoder)
+    assert model[-1].k == 16
+
+    # Verify that max_active_dims is passed to SparseAutoEncoder.forward()
+    forward_calls = []
+    original_forward = model[-1].forward
+
+    def spy_forward(*args, **kwargs):
+        forward_calls.append(kwargs)
+        return original_forward(*args, **kwargs)
+
+    monkeypatch.setattr(model[-1], "forward", spy_forward)
+
+    model.encode("Hello world", max_active_dims=5)
+    assert len(forward_calls) == 1
+    assert forward_calls[0]["max_active_dims"] == 5
+
+    # Without max_active_dims, the model's default max_active_dims is used
+    forward_calls.clear()
+    model.encode("Hello world")
+    assert len(forward_calls) == 1
+    assert forward_calls[0]["max_active_dims"] == model.max_active_dims
+
+
 def test_max_active_dims_set_init(splade_bert_tiny_model: SparseEncoder, csr_bert_tiny_model: SparseEncoder, tmp_path):
     splade_bert_tiny_model.save_pretrained(str(tmp_path / "splade_bert_tiny"))
     csr_bert_tiny_model.save_pretrained(str(tmp_path / "csr_bert_tiny"))