From ca74129a3ed431de2506e2a85d4353a134338942 Mon Sep 17 00:00:00 2001 From: yjoonjang Date: Fri, 20 Feb 2026 17:34:42 +0900 Subject: [PATCH 01/12] Add CachedSpladeLoss --- .../sparse_encoder/losses/CachedSpladeLoss.py | 333 ++++++++++++++++++ .../sparse_encoder/losses/__init__.py | 2 + 2 files changed, 335 insertions(+) create mode 100644 sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py diff --git a/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py new file mode 100644 index 000000000..7899ae366 --- /dev/null +++ b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +import logging +from collections.abc import Iterable, Iterator +from contextlib import nullcontext +from functools import partial +from typing import Any, Literal + +import torch +import tqdm +from torch import Tensor, nn + +from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext, _backward_hook +from sentence_transformers.sparse_encoder.losses.FlopsLoss import FlopsLoss +from sentence_transformers.sparse_encoder.losses.SpladeLoss import SpladeLoss +from sentence_transformers.sparse_encoder.SparseEncoder import SparseEncoder + +logger = logging.getLogger(__name__) + + +class CachedSpladeLoss(SpladeLoss): + def __init__( + self, + model: SparseEncoder, + loss: nn.Module, + document_regularizer_weight: float, + query_regularizer_weight: float | None = None, + document_regularizer: nn.Module | None = None, + query_regularizer: nn.Module | None = None, + document_regularizer_threshold: int | None = None, + query_regularizer_threshold: int | None = None, + use_document_regularizer_only: bool = False, + mini_batch_size: int = 32, + show_progress_bar: bool = False, + ): + """ + Cached version of :class:`SpladeLoss` that uses the GradCache technique to allow for much larger + effective batch sizes without additional GPU memory usage. + + The key insight is that caching happens at the SpladeLoss (wrapper) level, not at the base loss level. + This resolves the fundamental conflict between: + + - **Cached losses** needing raw inputs to compute embeddings in mini-batches. + - **SpladeLoss** needing to compute embeddings once and share them with both the base loss and + the FLOPS regularizer. + + By performing the GradCache mini-batch embedding at the SpladeLoss level, both the base loss and + regularizers still receive pre-computed embeddings via ``compute_loss_from_embeddings()`` — no + changes to base losses or regularizers are needed. + + In detail, the GradCache technique works as follows: + + (1) A quick embedding step without gradients/computation graphs to get all embeddings in mini-batches; + (2) Calculate the combined loss (base + regularizers), backward up to the embeddings and cache the + gradients w.r.t. the embeddings; + (3) A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the + backward chain. + + Args: + model: SparseEncoder model + loss: The principal loss function to use can be any of the SparseEncoder losses except CSR related + losses and flops loss. Must have a ``compute_loss_from_embeddings`` method. + document_regularizer_weight: Weight for the corpus regularization term. This term encourages sparsity + in the document embeddings. In some papers, this parameter is referred to as "lambda_d" (document) + or "lambda_c" (corpus). + query_regularizer_weight: Weight for the query regularization term. This term encourages sparsity in + the query embeddings. If None, no query regularization will be applied. In some papers, this + parameter is referred to as "lambda_q" (query). + document_regularizer: Optional regularizer to use specifically for corpus regularization instead of the + default FlopsLoss. + query_regularizer: Optional regularizer to use specifically for query regularization instead of the + default FlopsLoss. + document_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the + corpus embeddings to be considered in the FlopsLoss. + query_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the + query embeddings to be considered in the FlopsLoss. + use_document_regularizer_only: If True, all input embeddings are treated as documents and regularized + together with document_regularizer_weight. + mini_batch_size: Mini-batch size for the forward pass, this denotes how much memory is actually used + during training and evaluation. The larger the mini-batch size, the more memory efficient the + training is, but the slower the training will be. It's recommended to set it as high as your GPU + memory allows. The default value is 32. + show_progress_bar: If True, a progress bar for the mini-batches is shown during training. + + References: + - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: + https://huggingface.co/papers/2101.06983 + - From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective: + https://huggingface.co/papers/2205.04733 + + Requirements: + 1. Input requirements depend on the chosen loss + 2. Should be used with large ``per_device_train_batch_size`` and low ``mini_batch_size`` for superior + performance, but slower training time than :class:`SpladeLoss`. + + Example: + :: + + from datasets import Dataset + + from sentence_transformers.sparse_encoder import SparseEncoder, SparseEncoderTrainer, losses + + model = SparseEncoder("distilbert/distilbert-base-uncased") + train_dataset = Dataset.from_dict( + { + "anchor": ["It's nice weather outside today.", "He drove to work."], + "positive": ["It's so sunny.", "He took the car to the office."], + } + ) + loss = losses.CachedSpladeLoss( + model=model, + loss=losses.SparseMultipleNegativesRankingLoss(model), + document_regularizer_weight=3e-5, + query_regularizer_weight=5e-5, + mini_batch_size=32, + ) + + trainer = SparseEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss) + trainer.train() + """ + super().__init__( + model=model, + loss=loss, + document_regularizer_weight=document_regularizer_weight, + query_regularizer_weight=query_regularizer_weight, + document_regularizer=document_regularizer, + query_regularizer=query_regularizer, + document_regularizer_threshold=document_regularizer_threshold, + query_regularizer_threshold=query_regularizer_threshold, + use_document_regularizer_only=use_document_regularizer_only, + ) + self.mini_batch_size = mini_batch_size + self.show_progress_bar = show_progress_bar + self.cache: list[list[Tensor]] | None = None + self.random_states: list[list[RandContext]] | None = None + + def embed_minibatch( + self, + sentence_feature: dict[str, Tensor], + begin: int, + end: int, + with_grad: bool, + copy_random_state: bool, + random_state: RandContext | None = None, + ) -> tuple[Tensor, RandContext | None]: + """Embed a mini-batch of sentences.""" + grad_context = nullcontext if with_grad else torch.no_grad + random_state_context = nullcontext() if random_state is None else random_state + sentence_feature_minibatch = { + key: value[begin:end] if isinstance(value, torch.Tensor) else value + for key, value in sentence_feature.items() + } + with random_state_context: + with grad_context(): + random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None + reps = self.model(sentence_feature_minibatch)["sentence_embedding"] + return reps, random_state + + def embed_minibatch_iter( + self, + sentence_feature: dict[str, Tensor], + with_grad: bool, + copy_random_state: bool, + random_states: list[RandContext] | None = None, + ) -> Iterator[tuple[Tensor, RandContext | None]]: + """Iterate over mini-batches of sentences for embedding.""" + input_ids: Tensor = sentence_feature["input_ids"] + batch_size = input_ids.shape[0] + for i, begin in enumerate( + tqdm.trange( + 0, + batch_size, + self.mini_batch_size, + desc="Embed mini-batches", + disable=not self.show_progress_bar, + ) + ): + end = begin + self.mini_batch_size + reps, random_state = self.embed_minibatch( + sentence_feature=sentence_feature, + begin=begin, + end=end, + with_grad=with_grad, + copy_random_state=copy_random_state, + random_state=None if random_states is None else random_states[i], + ) + yield reps, random_state + + def calculate_loss_and_cache_gradients( + self, reps: list[list[Tensor]], labels: Tensor | None + ) -> Tensor: + """Calculate the combined loss (base + regularizers) and cache gradients w.r.t. the embeddings.""" + loss = self._compute_total_loss(reps, labels, with_backward=True) + loss = loss.detach().requires_grad_() + self.cache = [[r.grad for r in rs] for rs in reps] + return loss + + def _compute_total_loss( + self, reps: list[list[Tensor]], labels: Tensor | None, with_backward: bool = False + ) -> Tensor: + """Compute total loss from base loss + regularizers on mini-batch reps.""" + embeddings = [torch.cat(r) for r in reps] + + # Base loss + base_loss = self.loss.compute_loss_from_embeddings(embeddings, labels) + if isinstance(base_loss, dict): + total_loss = sum(base_loss.values()) + else: + total_loss = base_loss + self._base_loss_value = total_loss.detach().item() + + # Document regularizer + if self.use_document_regularizer_only: + corpus_emb = torch.cat(embeddings) + else: + corpus_emb = torch.cat(embeddings[1:]) + doc_reg_loss = self.document_regularizer.compute_loss_from_embeddings(corpus_emb) + weighted_doc_reg = doc_reg_loss * self.document_regularizer_weight + self._doc_reg_value = weighted_doc_reg.detach().item() + total_loss = total_loss + weighted_doc_reg + + # Query regularizer + if self.query_regularizer_weight is not None: + query_reg_loss = self.query_regularizer.compute_loss_from_embeddings(embeddings[0]) + weighted_query_reg = query_reg_loss * self.query_regularizer_weight + self._query_reg_value = weighted_query_reg.detach().item() + total_loss = total_loss + weighted_query_reg + else: + self._query_reg_value = None + + if with_backward: + total_loss.backward() + total_loss = total_loss.detach() + + return total_loss + + def forward( + self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor | None = None + ) -> dict[str, Tensor] | Tensor: + sentence_features = list(sentence_features) + + # Step (1): Embed all mini-batches without gradients to get all embeddings + reps = [] + self.random_states = [] + for sentence_feature in sentence_features: + reps_mbs = [] + random_state_mbs = [] + for reps_mb, random_state in self.embed_minibatch_iter( + sentence_feature=sentence_feature, + with_grad=False, + copy_random_state=True, + ): + reps_mbs.append(reps_mb.detach().requires_grad_()) + random_state_mbs.append(random_state) + reps.append(reps_mbs) + self.random_states.append(random_state_mbs) + + if torch.is_grad_enabled(): + # Step (2): Calculate the combined loss, backward to embeddings and cache gradients + loss = self.calculate_loss_and_cache_gradients(reps, labels) + + # Step (3): Register backward hook to chain cached gradients back through the model + loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self)) + + # Build dict for loss component logging while preserving gradient flow through `loss`. + # The trainer sums all dict values for backward, so we put the gradient-carrying tensor + # in "base_loss" and use detached tensors for regularizer entries. + device = loss.device + result = {} + non_base_sum = torch.tensor(0.0, device=device) + + doc_reg_detached = torch.tensor(self._doc_reg_value, device=device) + result["document_regularizer_loss"] = doc_reg_detached + non_base_sum = non_base_sum + doc_reg_detached + + if self._query_reg_value is not None: + query_reg_detached = torch.tensor(self._query_reg_value, device=device) + result["query_regularizer_loss"] = query_reg_detached + non_base_sum = non_base_sum + query_reg_detached + + # base_loss = loss - regularizers, so trainer's sum(values) = loss (exact gradient flow) + result["base_loss"] = loss - non_base_sum + + return result + else: + # Eval mode: no caching needed, compute losses directly + embeddings = [torch.cat(r) for r in reps] + losses = {} + + base_loss = self.loss.compute_loss_from_embeddings(embeddings, labels) + if isinstance(base_loss, dict): + losses.update(base_loss) + else: + losses["base_loss"] = base_loss + + if self.use_document_regularizer_only: + corpus_emb = torch.cat(embeddings) + else: + corpus_emb = torch.cat(embeddings[1:]) + doc_reg_loss = self.document_regularizer.compute_loss_from_embeddings(corpus_emb) + losses["document_regularizer_loss"] = doc_reg_loss * self.document_regularizer_weight + + if self.query_regularizer_weight is not None: + query_reg_loss = self.query_regularizer.compute_loss_from_embeddings(embeddings[0]) + losses["query_regularizer_loss"] = query_reg_loss * self.query_regularizer_weight + + return losses + + def get_config_dict(self) -> dict[str, Any]: + config = super().get_config_dict() + config["mini_batch_size"] = self.mini_batch_size + return config + + @property + def citation(self) -> str: + return """ +@misc{gao2021scaling, + title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup}, + author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan}, + year={2021}, + eprint={2101.06983}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +@misc{formal2022distillationhardnegativesampling, + title={From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective}, + author={Thibault Formal and Carlos Lassance and Benjamin Piwowarski and St\\'ephane Clinchant}, + year={2022}, + eprint={2205.04733}, + archivePrefix={arXiv}, + primaryClass={cs.IR}, +} +""" diff --git a/sentence_transformers/sparse_encoder/losses/__init__.py b/sentence_transformers/sparse_encoder/losses/__init__.py index d1ad8666c..64b635c82 100644 --- a/sentence_transformers/sparse_encoder/losses/__init__.py +++ b/sentence_transformers/sparse_encoder/losses/__init__.py @@ -11,8 +11,10 @@ from .SparseMultipleNegativesRankingLoss import SparseMultipleNegativesRankingLoss from .SparseTripletLoss import SparseTripletLoss from .SpladeLoss import SpladeLoss +from .CachedSpladeLoss import CachedSpladeLoss __all__ = [ + "CachedSpladeLoss", "CSRLoss", "CSRReconstructionLoss", "SparseMultipleNegativesRankingLoss", From fca2d0079aa1a033d0305e1d56799829945cd736 Mon Sep 17 00:00:00 2001 From: yjoonjang Date: Fri, 20 Feb 2026 23:38:40 +0900 Subject: [PATCH 02/12] Apply ruff formatting --- .../sparse_encoder/losses/CachedSpladeLoss.py | 7 ++----- sentence_transformers/sparse_encoder/losses/__init__.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py index 7899ae366..c97e7c323 100644 --- a/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py +++ b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py @@ -4,14 +4,13 @@ from collections.abc import Iterable, Iterator from contextlib import nullcontext from functools import partial -from typing import Any, Literal +from typing import Any import torch import tqdm from torch import Tensor, nn from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext, _backward_hook -from sentence_transformers.sparse_encoder.losses.FlopsLoss import FlopsLoss from sentence_transformers.sparse_encoder.losses.SpladeLoss import SpladeLoss from sentence_transformers.sparse_encoder.SparseEncoder import SparseEncoder @@ -186,9 +185,7 @@ def embed_minibatch_iter( ) yield reps, random_state - def calculate_loss_and_cache_gradients( - self, reps: list[list[Tensor]], labels: Tensor | None - ) -> Tensor: + def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], labels: Tensor | None) -> Tensor: """Calculate the combined loss (base + regularizers) and cache gradients w.r.t. the embeddings.""" loss = self._compute_total_loss(reps, labels, with_backward=True) loss = loss.detach().requires_grad_() diff --git a/sentence_transformers/sparse_encoder/losses/__init__.py b/sentence_transformers/sparse_encoder/losses/__init__.py index 64b635c82..2f3aa317c 100644 --- a/sentence_transformers/sparse_encoder/losses/__init__.py +++ b/sentence_transformers/sparse_encoder/losses/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .CachedSpladeLoss import CachedSpladeLoss from .CSRLoss import CSRLoss, CSRReconstructionLoss from .FlopsLoss import FlopsLoss from .SparseAnglELoss import SparseAnglELoss @@ -11,7 +12,6 @@ from .SparseMultipleNegativesRankingLoss import SparseMultipleNegativesRankingLoss from .SparseTripletLoss import SparseTripletLoss from .SpladeLoss import SpladeLoss -from .CachedSpladeLoss import CachedSpladeLoss __all__ = [ "CachedSpladeLoss", From 8ae14291c1729d82de2401e6e51ba32fd5970fcb Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 23 Feb 2026 10:58:03 +0100 Subject: [PATCH 03/12] Add a cached SPLADE training script --- .../training/retrievers/train_splade_nq.py | 37 +++-- .../retrievers/train_splade_nq_cached.py | 133 ++++++++++++++++++ 2 files changed, 151 insertions(+), 19 deletions(-) create mode 100644 examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py diff --git a/examples/sparse_encoder/training/retrievers/train_splade_nq.py b/examples/sparse_encoder/training/retrievers/train_splade_nq.py index 3e754fcb8..2ea0458e8 100644 --- a/examples/sparse_encoder/training/retrievers/train_splade_nq.py +++ b/examples/sparse_encoder/training/retrievers/train_splade_nq.py @@ -29,9 +29,8 @@ def main(): model_name = "distilbert/distilbert-base-uncased" - - train_batch_size = 12 - num_epochs = 1 + short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1] + global_batch_size = 16 # 1a. Load a model to finetune with 1b. (Optional) model card data model = SparseEncoder( @@ -39,7 +38,7 @@ def main(): model_card_data=SparseEncoderModelCardData( language="en", license="apache-2.0", - model_name="splade-distilbert-base-uncased trained on Natural Questions", + model_name=f"splade-{short_model_name} trained on Natural Questions", ), ) model.max_seq_length = 256 # Set the max sequence length to 256 for the training @@ -67,34 +66,35 @@ def main(): # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator evaluator = evaluation.SparseNanoBEIREvaluator( - dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=train_batch_size + dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=global_batch_size ) + evaluator(model) # 5. Define the training arguments - short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1] run_name = f"splade-{short_model_name}-nq" training_args = SparseEncoderTrainingArguments( # Required parameter: output_dir=f"models/{run_name}", # Optional training parameters: - num_train_epochs=num_epochs, - per_device_train_batch_size=train_batch_size, - per_device_eval_batch_size=train_batch_size, - learning_rate=2e-5, + num_train_epochs=1, + per_device_train_batch_size=global_batch_size, + per_device_eval_batch_size=global_batch_size, + warmup_ratio=0.1, + learning_rate=2e-6, fp16=False, # Set to False if you get an error that your GPU can't run on FP16 bf16=True, # Set to True if you have a GPU that supports BF16 batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch - load_best_model_at_end=True, - metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10", # Optional tracking/debugging parameters: eval_strategy="steps", - eval_steps=1650, + eval_steps=0.2, save_strategy="steps", - save_steps=1650, + save_steps=0.2, save_total_limit=2, - logging_steps=200, + logging_steps=0.05, run_name=run_name, # Will be used in W&B if `wandb` is installed - seed=42, + # Uncomment the following lines to enable loading the best model at the end of training based on evaluation performance + # load_best_model_at_end=True, + # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10", ) # 6. Create the trainer & start training @@ -108,9 +108,8 @@ def main(): ) trainer.train() - # 7. Evaluate the final model, using the complete NanoBEIR dataset - test_evaluator = evaluation.SparseNanoBEIREvaluator(show_progress_bar=True, batch_size=train_batch_size) - test_evaluator(model) + # 7. Evaluate the final model again + evaluator(model) # 8. Save the final model final_output_dir = f"models/{run_name}/final" diff --git a/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py b/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py new file mode 100644 index 000000000..5b87f0c9e --- /dev/null +++ b/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py @@ -0,0 +1,133 @@ +""" +This example trains a SparseEncoder for the Natural Questions (NQ) dataset. +The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval. +It loads a subset of the Natural Questions dataset, splits it into training and evaluation subsets, +and trains the model as a retriever. After training, the model is evaluated and saved locally, +with an optional step to push the trained model to the Hugging Face Hub. + +Usage: +python train_splade_nq.py +""" + +import logging +import traceback + +from datasets import load_dataset + +from sentence_transformers import ( + SparseEncoder, + SparseEncoderModelCardData, + SparseEncoderTrainer, + SparseEncoderTrainingArguments, +) +from sentence_transformers.sparse_encoder import evaluation, losses +from sentence_transformers.training_args import BatchSamplers + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) + + +def main(): + model_name = "distilbert/distilbert-base-uncased" + short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1] + global_batch_size = 512 + mini_batch_size = 32 + + # 1a. Load a model to finetune with 1b. (Optional) model card data + model = SparseEncoder( + model_name, + model_card_data=SparseEncoderModelCardData( + language="en", + license="apache-2.0", + model_name=f"splade-{short_model_name} trained on Natural Questions", + ), + ) + model.max_seq_length = 256 # Set the max sequence length to 256 for the training + logging.info("Model max length: %s", model.max_seq_length) + + # 2. Load the NQ dataset: https://huggingface.co/datasets/sentence-transformers/natural-questions + logging.info("Read the Natural Questions training dataset") + full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000)) + dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12) + train_dataset = dataset_dict["train"] + eval_dataset = dataset_dict["test"] + logging.info(train_dataset) + logging.info(eval_dataset) + + # 3. Define our training loss. + query_regularizer_weight = 5e-5 + document_regularizer_weight = 3e-5 + + loss = losses.CachedSpladeLoss( + model=model, + loss=losses.SparseMultipleNegativesRankingLoss(model=model), + mini_batch_size=mini_batch_size, + query_regularizer_weight=query_regularizer_weight, # Weight for query loss + document_regularizer_weight=document_regularizer_weight, # Weight for document loss + ) + + # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator + evaluator = evaluation.SparseNanoBEIREvaluator( + dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=mini_batch_size + ) + evaluator(model) + + # 5. Define the training arguments + run_name = f"splade-{short_model_name}-nq-512bs" + training_args = SparseEncoderTrainingArguments( + # Required parameter: + output_dir=f"models/{run_name}", + # Optional training parameters: + num_train_epochs=1, + per_device_train_batch_size=global_batch_size, + per_device_eval_batch_size=global_batch_size, + warmup_ratio=0.1, + learning_rate=1e-5, + fp16=False, # Set to False if you get an error that your GPU can't run on FP16 + bf16=True, # Set to True if you have a GPU that supports BF16 + batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=0.2, + save_strategy="steps", + save_steps=0.2, + save_total_limit=2, + logging_steps=0.05, + run_name=run_name, # Will be used in W&B if `wandb` is installed + # Uncomment the following lines to enable loading the best model at the end of training based on evaluation performance + # load_best_model_at_end=True, + # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10", + ) + + # 6. Create the trainer & start training + trainer = SparseEncoderTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=loss, + evaluator=evaluator, + ) + trainer.train() + + # 7. Evaluate the final model again + evaluator(model) + + # 8. Save the final model + final_output_dir = f"models/{run_name}/final" + model.save_pretrained(final_output_dir) + + # 9. (Optional) save the model to the Hugging Face Hub! + # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first + try: + model.push_to_hub(run_name) + except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{run_name}')`." + ) + + +if __name__ == "__main__": + main() From cf5bf483fb92c65861c168c68d6a4a92518eeb63 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 23 Feb 2026 11:35:55 +0100 Subject: [PATCH 04/12] Update docs to mention CachedSpladeLoss --- .../sparse_encoder/losses.md | 5 +++ docs/sparse_encoder/loss_overview.md | 38 ++++++++++--------- .../sparse_encoder/losses/CachedSpladeLoss.py | 7 ---- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/docs/package_reference/sparse_encoder/losses.md b/docs/package_reference/sparse_encoder/losses.md index 95b0a3961..d7048c54f 100644 --- a/docs/package_reference/sparse_encoder/losses.md +++ b/docs/package_reference/sparse_encoder/losses.md @@ -15,6 +15,11 @@ Sadly, there is no "one size fits all" loss function. Which loss function is sui .. autoclass:: sentence_transformers.sparse_encoder.losses.SpladeLoss ``` +## CachedSpladeLoss +```{eval-rst} +.. autoclass:: sentence_transformers.sparse_encoder.losses.CachedSpladeLoss +``` + ## FlopsLoss ```{eval-rst} .. autoclass:: sentence_transformers.sparse_encoder.losses.FlopsLoss diff --git a/docs/sparse_encoder/loss_overview.md b/docs/sparse_encoder/loss_overview.md index d345a6b17..776a54b65 100644 --- a/docs/sparse_encoder/loss_overview.md +++ b/docs/sparse_encoder/loss_overview.md @@ -2,7 +2,7 @@ ```{eval-rst} .. warning:: - To train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder`, you need either :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss` or :class:`~sentence_transformers.sparse_encoder.losses.CSRLoss`, depending on the architecture. These are wrapper losses that add sparsity regularization on top of a main loss function, which must be provided as a parameter. The only loss that can be used independently is :class:`~sentence_transformers.sparse_encoder.losses.SparseMSELoss`, as it performs embedding-level distillation, ensuring sparsity by directly copying the teacher's sparse embedding. + To train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder`, you need either :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss`, :class:`~sentence_transformers.sparse_encoder.losses.CachedSpladeLoss`, or :class:`~sentence_transformers.sparse_encoder.losses.CSRLoss`, depending on the architecture. These are wrapper losses that add sparsity regularization on top of a main loss function, which must be provided as a parameter. The only loss that can be used independently is :class:`~sentence_transformers.sparse_encoder.losses.SparseMSELoss`, as it performs embedding-level distillation, ensuring sparsity by directly copying the teacher's sparse embedding. ``` @@ -10,21 +10,25 @@ ### SPLADE Loss -The SpladeLoss implements a specialized loss function for SPLADE (Sparse Lexical and Expansion) models. It combines a main loss function with regularization terms to control efficiency: +The SpladeLoss implements a specialized loss function for SPLADE (Sparse Lexical and Expansion) models. It combines a main loss function with regularization terms to balance effectiveness and efficiency: -- Supports all the losses mention below as main loss but three principal loss types: SparseMultipleNegativesRankingLoss, SparseMarginMSELoss and SparseDistillKLDivLoss. -- Uses FlopsLoss for regularization to control sparsity by default, but supports custom regularizers. -- Balances effectiveness (via the main loss) with efficiency by regularizing both query and document representations. -- Allows using different regularizers for queries and documents via the `query_regularizer` and `document_regularizer` parameters, enabling fine-grained control over sparsity patterns for different types of inputs. -- Supports separate threshold values for queries and documents via the `query_regularizer_threshold` and `document_regularizer_threshold` parameters, allowing different sparsity strictness levels for each input type. +1. Main loss: Supports all the losses from the Loss Table and Distillation, with SparseMultipleNegativesRankingLoss, SparseMarginMSELoss and SparseDistillKLDivLoss commonly used. +2. Regularization loss: FlopsLoss is used to control sparsity, but supports custom regularizers. + - `query_regularizer` and `document_regularizer` can be set to any custom regularization loss. + - `query_regularizer_threshold` and `document_regularizer_threshold` can be set to control the sparsity strictness for queries and documents separately, setting the regularization loss to zero if an embedding has less than the threshold number of active (non-zero) dimensions. + +#### Cached SPLADE Loss + +The CachedSpladeLoss is a variant of the SPLADE loss adopting GradCache, which allows for much larger batch sizes without additional GPU memory usage. It achieves this by computing and caching loss gradients in mini-batches. + +Main losses that use in-batch negatives, primarily SparseMultipleNegativesRankingLoss, benefit heavily from larger batch sizes, as it results in more negatives and a stronger training signal. ### CSR Loss If you are using the SparseAutoEncoder module, then you have to use the CSRLoss (Contrastive Sparse Representation Loss). It combines two components: -- A reconstruction loss CSRReconstructionLoss that ensures sparse representation can faithfully reconstruct original embeddings. -- A main loss, which in the paper is a contrastive learning component using `SparseMultipleNegativesRankingLoss` that ensures semanticallysimilar sentences have similar representations. But it's theoretically possible to use all the losses mention below as main loss like for SpladeLoss . - +1. Main loss: Supports all the losses from the Loss Table and Distillation, with SparseMultipleNegativesRankingLoss used in the CSR Paper. +2. Reconstruction loss: CSRReconstructionLoss is used to ensure that sparse representation can faithfully reconstruct the original dense embeddings. ## Loss Table @@ -34,18 +38,16 @@ Loss functions play a critical role in the performance of your fine-tuned model. .. note:: You can often convert one training data format into another, allowing more loss functions to be viable for your scenario. For example, ``(sentence_A, sentence_B) pairs`` with ``class`` labels can be converted into ``(anchor, positive, negative) triplets`` by sampling sentences with the same or different classes. - - .. note:: - - The loss functions in `SentenceTransformer > Loss Overview <../sentence_transformer/loss_overview.html>`_ that appear here with the ``Sparse`` prefix are identical to their dense versions. The prefix is used only to indicate which losses can be used as main losses to train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder` ``` +**Legend:** Loss functions marked with `★` are commonly recommended default choices. + | Inputs | Labels | Appropriate Loss Functions | |---------------------------------------------------|------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `(anchor, positive) pairs` | `none` | `SparseMultipleNegativesRankingLoss` | +| `(anchor, positive) pairs` | `none` | `SparseMultipleNegativesRankingLoss` ★ | | `(sentence_A, sentence_B) pairs` | `float similarity score between 0 and 1` | `SparseCoSENTLoss`
`SparseAnglELoss`
`SparseCosineSimilarityLoss` | -| `(anchor, positive, negative) triplets` | `none` | `SparseMultipleNegativesRankingLoss`
`SparseTripletLoss` | -| `(anchor, positive, negative_1, ..., negative_n)` | `none` | `SparseMultipleNegativesRankingLoss` | +| `(anchor, positive, negative) triplets` | `none` | `SparseMultipleNegativesRankingLoss`
`SparseTripletLoss` | +| `(anchor, positive, negative_1, ..., negative_n)` | `none` | `SparseMultipleNegativesRankingLoss` ★ | ## Distillation @@ -64,7 +66,7 @@ These loss functions are specifically designed to be used when distilling the kn In practice, not all loss functions get used equally often. The most common scenarios are: -* `(anchor, positive) pairs` without any labels: SparseMultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. Here for our sparse retrieval tasks, this format works well with SpladeLoss or CSRLoss, both typically using InfoNCE as their underlying loss function. +* `(anchor, positive) pairs` without any labels: SparseMultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. Here for our sparse retrieval tasks, this format works well with SpladeLoss, CachedSpladeLoss, or CSRLoss, both typically using InfoNCE as their underlying loss function. * `(query, positive, negative_1, ..., negative_n)` format: This structure with multiple negatives is particularly effective with SpladeLoss configured with SparseMarginMSELoss, especially in knowledge distillation scenarios where a teacher model provides similarity scores. The strongest models are trained with distillation losses like SparseDistillKLDivLoss or SparseMarginMSELoss. diff --git a/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py index c97e7c323..ccd5787be 100644 --- a/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py +++ b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py @@ -36,13 +36,6 @@ def __init__( Cached version of :class:`SpladeLoss` that uses the GradCache technique to allow for much larger effective batch sizes without additional GPU memory usage. - The key insight is that caching happens at the SpladeLoss (wrapper) level, not at the base loss level. - This resolves the fundamental conflict between: - - - **Cached losses** needing raw inputs to compute embeddings in mini-batches. - - **SpladeLoss** needing to compute embeddings once and share them with both the base loss and - the FLOPS regularizer. - By performing the GradCache mini-batch embedding at the SpladeLoss level, both the base loss and regularizers still receive pre-computed embeddings via ``compute_loss_from_embeddings()`` — no changes to base losses or regularizers are needed. From 623bcaabb47bd776a3ff28cf362b371b026788f0 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 23 Feb 2026 11:42:06 +0100 Subject: [PATCH 05/12] Update the information retrieval docs to mention train_splade_nq_cached.py --- examples/sparse_encoder/training/retrievers/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/sparse_encoder/training/retrievers/README.md b/examples/sparse_encoder/training/retrievers/README.md index edd591680..3dc66a3ca 100644 --- a/examples/sparse_encoder/training/retrievers/README.md +++ b/examples/sparse_encoder/training/retrievers/README.md @@ -26,6 +26,12 @@ Example scripts could be: This example also uses :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss` (similarly utilizing :class:`~sentence_transformers.sparse_encoder.losses.SparseMultipleNegativesRankingLoss`) and trains on the `NQ (natural questions) `_ dataset. It showcases an alternative configuration or approach for training SPLADE models on question-answering data for sparse retrieval. ``` +- **[train_splade_nq_cached.py](train_splade_nq_cached.py)**: + + ```{eval-rst} + This example is similar to the last one, but uses :class:`~sentence_transformers.sparse_encoder.losses.CachedSpladeLoss` to get much larger batch sizes (e.g. 512 instead of 16) during training without increasing GPU memory usage. Because :class:`~sentence_transformers.sparse_encoder.losses.SparseMultipleNegativesRankingLoss` benefits greatly from larger batch sizes (more in-batch negatives), this results in better retrieval performance. + ``` + - **[train_csr_nq.py](train_csr_nq.py)**: ```{eval-rst} From 2c7d9da599f4404b1e93d5f66c276d12c5303e54 Mon Sep 17 00:00:00 2001 From: yjoonjang Date: Mon, 23 Feb 2026 22:14:13 +0900 Subject: [PATCH 06/12] Update CachedSpladeLoss import after SpladeLoss to avoid circular import --- sentence_transformers/sparse_encoder/losses/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sentence_transformers/sparse_encoder/losses/__init__.py b/sentence_transformers/sparse_encoder/losses/__init__.py index 2f3aa317c..39d321cea 100644 --- a/sentence_transformers/sparse_encoder/losses/__init__.py +++ b/sentence_transformers/sparse_encoder/losses/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -from .CachedSpladeLoss import CachedSpladeLoss from .CSRLoss import CSRLoss, CSRReconstructionLoss from .FlopsLoss import FlopsLoss from .SparseAnglELoss import SparseAnglELoss @@ -12,6 +11,7 @@ from .SparseMultipleNegativesRankingLoss import SparseMultipleNegativesRankingLoss from .SparseTripletLoss import SparseTripletLoss from .SpladeLoss import SpladeLoss +from .CachedSpladeLoss import CachedSpladeLoss # Must be after SpladeLoss to avoid circular import __all__ = [ "CachedSpladeLoss", From 8dd57eab939304d18195b82eaa72f0ee2cf68050 Mon Sep 17 00:00:00 2001 From: yjoonjang Date: Mon, 23 Feb 2026 22:20:56 +0900 Subject: [PATCH 07/12] Add CachedSpladeLoss training script for MS MARCO margin MSE distillation --- .../train_splade_msmarco_margin_mse_cached.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 examples/sparse_encoder/training/distillation/train_splade_msmarco_margin_mse_cached.py diff --git a/examples/sparse_encoder/training/distillation/train_splade_msmarco_margin_mse_cached.py b/examples/sparse_encoder/training/distillation/train_splade_msmarco_margin_mse_cached.py new file mode 100644 index 000000000..372675d47 --- /dev/null +++ b/examples/sparse_encoder/training/distillation/train_splade_msmarco_margin_mse_cached.py @@ -0,0 +1,148 @@ +""" +This scripts demonstrates how to train a Sparse Encoder model for Information Retrieval +using CachedSpladeLoss, which enables much larger batch sizes without additional GPU memory. + +As dataset, we use MSMARCO version with hard negatives from the bert-ensemble-margin-mse dataset. + +As loss function, we use MarginMSELoss in the CachedSpladeLoss. + +Usage: +python train_splade_msmarco_margin_mse_cached.py +""" + +import logging +import traceback + +from datasets import load_dataset + +from sentence_transformers import ( + SparseEncoder, + SparseEncoderModelCardData, + SparseEncoderTrainer, + SparseEncoderTrainingArguments, +) +from sentence_transformers.sparse_encoder import evaluation, losses + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) + + +def main(): + model_name = "distilbert/distilbert-base-uncased" + short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1] + global_batch_size = 512 + mini_batch_size = 32 + + # 1a. Load a model to finetune with 1b. (Optional) model card data + model = SparseEncoder( + model_name, + model_card_data=SparseEncoderModelCardData( + language="en", + license="apache-2.0", + model_name=f"splade-{short_model_name} trained on MS MARCO hard negatives with distillation", + ), + ) + model.max_seq_length = 256 # Set the max sequence length to 256 for the training + logging.info("Model max length: %s", model.max_seq_length) + + # 2. Load the MS MARCO dataset: https://huggingface.co/datasets/sentence-transformers/msmarco + dataset_size = 100_000 # We only use the first 100k samples for training + logging.info("The dataset has not been fully stored as texts on disk yet. We will do this now.") + corpus = load_dataset("sentence-transformers/msmarco", "corpus", split="train") + corpus = dict(zip(corpus["passage_id"], corpus["passage"])) + queries = load_dataset("sentence-transformers/msmarco", "queries", split="train") + queries = dict(zip(queries["query_id"], queries["query"])) + dataset = load_dataset("sentence-transformers/msmarco", "bert-ensemble-margin-mse", split="train") + dataset = dataset.select(range(dataset_size)) + + def id_to_text_map(batch): + return { + "query": [queries[qid] for qid in batch["query_id"]], + "positive": [corpus[pid] for pid in batch["positive_id"]], + "negative": [corpus[pid] for pid in batch["negative_id"]], + "score": batch["score"], + } + + dataset = dataset.map(id_to_text_map, batched=True, remove_columns=["query_id", "positive_id", "negative_id"]) + dataset = dataset.train_test_split(test_size=10_000) + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + logging.info(train_dataset) + + # 3. Define our training loss. + query_regularizer_weight = 5e-5 + document_regularizer_weight = 3e-5 + + loss = losses.CachedSpladeLoss( + model=model, + loss=losses.SparseMarginMSELoss(model=model), + mini_batch_size=mini_batch_size, + query_regularizer_weight=query_regularizer_weight, + document_regularizer_weight=document_regularizer_weight, + ) + + # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator + evaluator = evaluation.SparseNanoBEIREvaluator( + dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=mini_batch_size + ) + evaluator(model) + + # 5. Define the training arguments + run_name = f"splade-{short_model_name}-msmarco-hard-negatives-{global_batch_size}bs" + training_args = SparseEncoderTrainingArguments( + # Required parameter: + output_dir=f"models/{run_name}", + # Optional training parameters: + num_train_epochs=1, + per_device_train_batch_size=global_batch_size, + per_device_eval_batch_size=global_batch_size, + warmup_ratio=0.1, + learning_rate=2e-5, + fp16=False, # Set to False if you get an error that your GPU can't run on FP16 + bf16=True, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=0.2, + save_strategy="steps", + save_steps=0.2, + save_total_limit=2, + logging_steps=0.05, + run_name=run_name, # Will be used in W&B if `wandb` is installed + seed=42, + # Uncomment the following lines to enable loading the best model at the end of training based on evaluation performance + # load_best_model_at_end=True, + # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10", + ) + + # 6. Create the trainer & start training + trainer = SparseEncoderTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=loss, + evaluator=evaluator, + ) + trainer.train() + + # 7. Evaluate the final model again + evaluator(model) + + # 8. Save the final model + final_output_dir = f"models/{run_name}/final" + model.save_pretrained(final_output_dir) + + # 9. (Optional) save the model to the Hugging Face Hub! + # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first + try: + model.push_to_hub(run_name) + except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{run_name}')`." + ) + + +if __name__ == "__main__": + main() From 00c6c0cae62531586753c1c3f4dd1425ce9608df Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 23 Feb 2026 15:30:02 +0100 Subject: [PATCH 08/12] Remove cached train file; cached not helpful for non-in-batch-negatives losses --- .../train_splade_msmarco_margin_mse_cached.py | 148 ------------------ 1 file changed, 148 deletions(-) delete mode 100644 examples/sparse_encoder/training/distillation/train_splade_msmarco_margin_mse_cached.py diff --git a/examples/sparse_encoder/training/distillation/train_splade_msmarco_margin_mse_cached.py b/examples/sparse_encoder/training/distillation/train_splade_msmarco_margin_mse_cached.py deleted file mode 100644 index 372675d47..000000000 --- a/examples/sparse_encoder/training/distillation/train_splade_msmarco_margin_mse_cached.py +++ /dev/null @@ -1,148 +0,0 @@ -""" -This scripts demonstrates how to train a Sparse Encoder model for Information Retrieval -using CachedSpladeLoss, which enables much larger batch sizes without additional GPU memory. - -As dataset, we use MSMARCO version with hard negatives from the bert-ensemble-margin-mse dataset. - -As loss function, we use MarginMSELoss in the CachedSpladeLoss. - -Usage: -python train_splade_msmarco_margin_mse_cached.py -""" - -import logging -import traceback - -from datasets import load_dataset - -from sentence_transformers import ( - SparseEncoder, - SparseEncoderModelCardData, - SparseEncoderTrainer, - SparseEncoderTrainingArguments, -) -from sentence_transformers.sparse_encoder import evaluation, losses - -# Set the log level to INFO to get more information -logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) - - -def main(): - model_name = "distilbert/distilbert-base-uncased" - short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1] - global_batch_size = 512 - mini_batch_size = 32 - - # 1a. Load a model to finetune with 1b. (Optional) model card data - model = SparseEncoder( - model_name, - model_card_data=SparseEncoderModelCardData( - language="en", - license="apache-2.0", - model_name=f"splade-{short_model_name} trained on MS MARCO hard negatives with distillation", - ), - ) - model.max_seq_length = 256 # Set the max sequence length to 256 for the training - logging.info("Model max length: %s", model.max_seq_length) - - # 2. Load the MS MARCO dataset: https://huggingface.co/datasets/sentence-transformers/msmarco - dataset_size = 100_000 # We only use the first 100k samples for training - logging.info("The dataset has not been fully stored as texts on disk yet. We will do this now.") - corpus = load_dataset("sentence-transformers/msmarco", "corpus", split="train") - corpus = dict(zip(corpus["passage_id"], corpus["passage"])) - queries = load_dataset("sentence-transformers/msmarco", "queries", split="train") - queries = dict(zip(queries["query_id"], queries["query"])) - dataset = load_dataset("sentence-transformers/msmarco", "bert-ensemble-margin-mse", split="train") - dataset = dataset.select(range(dataset_size)) - - def id_to_text_map(batch): - return { - "query": [queries[qid] for qid in batch["query_id"]], - "positive": [corpus[pid] for pid in batch["positive_id"]], - "negative": [corpus[pid] for pid in batch["negative_id"]], - "score": batch["score"], - } - - dataset = dataset.map(id_to_text_map, batched=True, remove_columns=["query_id", "positive_id", "negative_id"]) - dataset = dataset.train_test_split(test_size=10_000) - train_dataset = dataset["train"] - eval_dataset = dataset["test"] - logging.info(train_dataset) - - # 3. Define our training loss. - query_regularizer_weight = 5e-5 - document_regularizer_weight = 3e-5 - - loss = losses.CachedSpladeLoss( - model=model, - loss=losses.SparseMarginMSELoss(model=model), - mini_batch_size=mini_batch_size, - query_regularizer_weight=query_regularizer_weight, - document_regularizer_weight=document_regularizer_weight, - ) - - # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator - evaluator = evaluation.SparseNanoBEIREvaluator( - dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=mini_batch_size - ) - evaluator(model) - - # 5. Define the training arguments - run_name = f"splade-{short_model_name}-msmarco-hard-negatives-{global_batch_size}bs" - training_args = SparseEncoderTrainingArguments( - # Required parameter: - output_dir=f"models/{run_name}", - # Optional training parameters: - num_train_epochs=1, - per_device_train_batch_size=global_batch_size, - per_device_eval_batch_size=global_batch_size, - warmup_ratio=0.1, - learning_rate=2e-5, - fp16=False, # Set to False if you get an error that your GPU can't run on FP16 - bf16=True, # Set to True if you have a GPU that supports BF16 - # Optional tracking/debugging parameters: - eval_strategy="steps", - eval_steps=0.2, - save_strategy="steps", - save_steps=0.2, - save_total_limit=2, - logging_steps=0.05, - run_name=run_name, # Will be used in W&B if `wandb` is installed - seed=42, - # Uncomment the following lines to enable loading the best model at the end of training based on evaluation performance - # load_best_model_at_end=True, - # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10", - ) - - # 6. Create the trainer & start training - trainer = SparseEncoderTrainer( - model=model, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - loss=loss, - evaluator=evaluator, - ) - trainer.train() - - # 7. Evaluate the final model again - evaluator(model) - - # 8. Save the final model - final_output_dir = f"models/{run_name}/final" - model.save_pretrained(final_output_dir) - - # 9. (Optional) save the model to the Hugging Face Hub! - # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first - try: - model.push_to_hub(run_name) - except Exception: - logging.error( - f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` " - f"and saving it using `model.push_to_hub('{run_name}')`." - ) - - -if __name__ == "__main__": - main() From 5c8cb900a50b2ae1ff6325ece3387cf55a725f24 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 23 Feb 2026 15:30:44 +0100 Subject: [PATCH 09/12] Add isort: skip to CachedSpladeLoss to fix import order issue --- sentence_transformers/sparse_encoder/losses/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sentence_transformers/sparse_encoder/losses/__init__.py b/sentence_transformers/sparse_encoder/losses/__init__.py index 39d321cea..174b67055 100644 --- a/sentence_transformers/sparse_encoder/losses/__init__.py +++ b/sentence_transformers/sparse_encoder/losses/__init__.py @@ -11,7 +11,8 @@ from .SparseMultipleNegativesRankingLoss import SparseMultipleNegativesRankingLoss from .SparseTripletLoss import SparseTripletLoss from .SpladeLoss import SpladeLoss -from .CachedSpladeLoss import CachedSpladeLoss # Must be after SpladeLoss to avoid circular import + +from .CachedSpladeLoss import CachedSpladeLoss # isort: skip # Avoid circular import with SpladeLoss -> FlopsLoss __all__ = [ "CachedSpladeLoss", From 9504cf9d4fab938fdb52eedd419175b97ae3e39c Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Mon, 23 Feb 2026 15:37:35 +0100 Subject: [PATCH 10/12] both -> all Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/sparse_encoder/loss_overview.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sparse_encoder/loss_overview.md b/docs/sparse_encoder/loss_overview.md index 776a54b65..e7978ad3b 100644 --- a/docs/sparse_encoder/loss_overview.md +++ b/docs/sparse_encoder/loss_overview.md @@ -66,7 +66,7 @@ These loss functions are specifically designed to be used when distilling the kn In practice, not all loss functions get used equally often. The most common scenarios are: -* `(anchor, positive) pairs` without any labels: SparseMultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. Here for our sparse retrieval tasks, this format works well with SpladeLoss, CachedSpladeLoss, or CSRLoss, both typically using InfoNCE as their underlying loss function. +* `(anchor, positive) pairs` without any labels: SparseMultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. Here for our sparse retrieval tasks, this format works well with SpladeLoss, CachedSpladeLoss, or CSRLoss, all typically using InfoNCE as their underlying loss function. * `(query, positive, negative_1, ..., negative_n)` format: This structure with multiple negatives is particularly effective with SpladeLoss configured with SparseMarginMSELoss, especially in knowledge distillation scenarios where a teacher model provides similarity scores. The strongest models are trained with distillation losses like SparseDistillKLDivLoss or SparseMarginMSELoss. From fec126796c7156650f013b81077739c546fdff3c Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 23 Feb 2026 15:42:16 +0100 Subject: [PATCH 11/12] Replace corpus with document in both SpladeLoss classes --- .../sparse_encoder/losses/CachedSpladeLoss.py | 18 +++++++++--------- .../sparse_encoder/losses/SpladeLoss.py | 14 +++++++------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py index ccd5787be..bf98ce924 100644 --- a/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py +++ b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py @@ -52,18 +52,18 @@ def __init__( model: SparseEncoder model loss: The principal loss function to use can be any of the SparseEncoder losses except CSR related losses and flops loss. Must have a ``compute_loss_from_embeddings`` method. - document_regularizer_weight: Weight for the corpus regularization term. This term encourages sparsity + document_regularizer_weight: Weight for the document regularization term. This term encourages sparsity in the document embeddings. In some papers, this parameter is referred to as "lambda_d" (document) or "lambda_c" (corpus). query_regularizer_weight: Weight for the query regularization term. This term encourages sparsity in the query embeddings. If None, no query regularization will be applied. In some papers, this parameter is referred to as "lambda_q" (query). - document_regularizer: Optional regularizer to use specifically for corpus regularization instead of the + document_regularizer: Optional regularizer to use specifically for document regularization instead of the default FlopsLoss. query_regularizer: Optional regularizer to use specifically for query regularization instead of the default FlopsLoss. document_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the - corpus embeddings to be considered in the FlopsLoss. + document embeddings to be considered in the FlopsLoss. query_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the query embeddings to be considered in the FlopsLoss. use_document_regularizer_only: If True, all input embeddings are treated as documents and regularized @@ -201,10 +201,10 @@ def _compute_total_loss( # Document regularizer if self.use_document_regularizer_only: - corpus_emb = torch.cat(embeddings) + document_emb = torch.cat(embeddings) else: - corpus_emb = torch.cat(embeddings[1:]) - doc_reg_loss = self.document_regularizer.compute_loss_from_embeddings(corpus_emb) + document_emb = torch.cat(embeddings[1:]) + doc_reg_loss = self.document_regularizer.compute_loss_from_embeddings(document_emb) weighted_doc_reg = doc_reg_loss * self.document_regularizer_weight self._doc_reg_value = weighted_doc_reg.detach().item() total_loss = total_loss + weighted_doc_reg @@ -284,10 +284,10 @@ def forward( losses["base_loss"] = base_loss if self.use_document_regularizer_only: - corpus_emb = torch.cat(embeddings) + document_emb = torch.cat(embeddings) else: - corpus_emb = torch.cat(embeddings[1:]) - doc_reg_loss = self.document_regularizer.compute_loss_from_embeddings(corpus_emb) + document_emb = torch.cat(embeddings[1:]) + doc_reg_loss = self.document_regularizer.compute_loss_from_embeddings(document_emb) losses["document_regularizer_loss"] = doc_reg_loss * self.document_regularizer_weight if self.query_regularizer_weight is not None: diff --git a/sentence_transformers/sparse_encoder/losses/SpladeLoss.py b/sentence_transformers/sparse_encoder/losses/SpladeLoss.py index 10751e585..fd0356009 100644 --- a/sentence_transformers/sparse_encoder/losses/SpladeLoss.py +++ b/sentence_transformers/sparse_encoder/losses/SpladeLoss.py @@ -36,19 +36,19 @@ def __init__( Args: model: SparseEncoder model loss: The principal loss function to use can be any of the SparseEncoder losses except CSR related losses and flops loss. - document_regularizer_weight: Weight for the corpus regularization term. This term encourages sparsity in the document embeddings. + document_regularizer_weight: Weight for the document regularization term. This term encourages sparsity in the document embeddings. Will be applied to positive documents and all negatives one if some are provided. In some papers, this parameter is referred to as "lambda_d" (document) or "lambda_c" (corpus). query_regularizer_weight: Weight for the query regularization term. This term encourages sparsity in the query embeddings. If None, no query regularization will be applied, it's not a problem if you are in an inference-free setup or if you are having use_document_regularizer_only=True. Else you should have a query_regularizer_weight > 0. In some papers, this parameter is referred to as "lambda_q" (query). - document_regularizer: Optional regularizer to use specifically for corpus regularization instead of the default FlopsLoss. + document_regularizer: Optional regularizer to use specifically for document regularization instead of the default FlopsLoss. This allows for different regularization strategies for documents vs queries. query_regularizer: Optional regularizer to use specifically for query regularization instead of the default FlopsLoss. This allows for different regularization strategies for queries vs documents. - document_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the corpus embeddings to be considered in the FlopsLoss. - If specified, only corpus embeddings with more than this number of non-zero (active) elements will be considered. + document_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the document embeddings to be considered in the FlopsLoss. + If specified, only document embeddings with more than this number of non-zero (active) elements will be considered. Only used when document_regularizer is None (for the default FlopsLoss). query_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the query embeddings to be considered in the FlopsLoss. If specified, only query embeddings with more than this number of non-zero (active) elements will be considered. @@ -149,10 +149,10 @@ def forward( if self.use_document_regularizer_only: # If use_document_regularizer_only is True, we consider all the input to be of the same type and so under the same regularization - corpus_loss = self.document_regularizer.compute_loss_from_embeddings(torch.cat(embeddings)) + document_loss = self.document_regularizer.compute_loss_from_embeddings(torch.cat(embeddings)) else: - corpus_loss = self.document_regularizer.compute_loss_from_embeddings(torch.cat(embeddings[1:])) - losses["document_regularizer_loss"] = corpus_loss * self.document_regularizer_weight + document_loss = self.document_regularizer.compute_loss_from_embeddings(torch.cat(embeddings[1:])) + losses["document_regularizer_loss"] = document_loss * self.document_regularizer_weight # Add query regularization if enabled if self.query_regularizer_weight is not None: From 35ccace15a941c732ad19c93fce62d9047fd8f07 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Mon, 23 Feb 2026 15:42:24 +0100 Subject: [PATCH 12/12] Document the higher LR for the Cached training --- .../training/retrievers/train_splade_nq_cached.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py b/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py index 5b87f0c9e..3f44f1d03 100644 --- a/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py +++ b/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py @@ -82,7 +82,7 @@ def main(): per_device_train_batch_size=global_batch_size, per_device_eval_batch_size=global_batch_size, warmup_ratio=0.1, - learning_rate=1e-5, + learning_rate=1e-5, # NOTE: The learning rate is much higher to account for the higher batch size fp16=False, # Set to False if you get an error that your GPU can't run on FP16 bf16=True, # Set to True if you have a GPU that supports BF16 batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch