diff --git a/docs/package_reference/sparse_encoder/losses.md b/docs/package_reference/sparse_encoder/losses.md
index 95b0a3961..d7048c54f 100644
--- a/docs/package_reference/sparse_encoder/losses.md
+++ b/docs/package_reference/sparse_encoder/losses.md
@@ -15,6 +15,11 @@ Sadly, there is no "one size fits all" loss function. Which loss function is sui
.. autoclass:: sentence_transformers.sparse_encoder.losses.SpladeLoss
```
+## CachedSpladeLoss
+```{eval-rst}
+.. autoclass:: sentence_transformers.sparse_encoder.losses.CachedSpladeLoss
+```
+
## FlopsLoss
```{eval-rst}
.. autoclass:: sentence_transformers.sparse_encoder.losses.FlopsLoss
diff --git a/docs/sparse_encoder/loss_overview.md b/docs/sparse_encoder/loss_overview.md
index d345a6b17..e7978ad3b 100644
--- a/docs/sparse_encoder/loss_overview.md
+++ b/docs/sparse_encoder/loss_overview.md
@@ -2,7 +2,7 @@
```{eval-rst}
.. warning::
- To train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder`, you need either :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss` or :class:`~sentence_transformers.sparse_encoder.losses.CSRLoss`, depending on the architecture. These are wrapper losses that add sparsity regularization on top of a main loss function, which must be provided as a parameter. The only loss that can be used independently is :class:`~sentence_transformers.sparse_encoder.losses.SparseMSELoss`, as it performs embedding-level distillation, ensuring sparsity by directly copying the teacher's sparse embedding.
+ To train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder`, you need either :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss`, :class:`~sentence_transformers.sparse_encoder.losses.CachedSpladeLoss`, or :class:`~sentence_transformers.sparse_encoder.losses.CSRLoss`, depending on the architecture. These are wrapper losses that add sparsity regularization on top of a main loss function, which must be provided as a parameter. The only loss that can be used independently is :class:`~sentence_transformers.sparse_encoder.losses.SparseMSELoss`, as it performs embedding-level distillation, ensuring sparsity by directly copying the teacher's sparse embedding.
```
@@ -10,21 +10,25 @@
### SPLADE Loss
-The SpladeLoss implements a specialized loss function for SPLADE (Sparse Lexical and Expansion) models. It combines a main loss function with regularization terms to control efficiency:
+The SpladeLoss implements a specialized loss function for SPLADE (Sparse Lexical and Expansion) models. It combines a main loss function with regularization terms to balance effectiveness and efficiency:
-- Supports all the losses mention below as main loss but three principal loss types: SparseMultipleNegativesRankingLoss, SparseMarginMSELoss and SparseDistillKLDivLoss.
-- Uses FlopsLoss for regularization to control sparsity by default, but supports custom regularizers.
-- Balances effectiveness (via the main loss) with efficiency by regularizing both query and document representations.
-- Allows using different regularizers for queries and documents via the `query_regularizer` and `document_regularizer` parameters, enabling fine-grained control over sparsity patterns for different types of inputs.
-- Supports separate threshold values for queries and documents via the `query_regularizer_threshold` and `document_regularizer_threshold` parameters, allowing different sparsity strictness levels for each input type.
+1. Main loss: Supports all the losses from the Loss Table and Distillation, with SparseMultipleNegativesRankingLoss, SparseMarginMSELoss and SparseDistillKLDivLoss being the most commonly used.
+2. Regularization loss: FlopsLoss is used by default to control sparsity, but custom regularizers are supported.
+ - `query_regularizer` and `document_regularizer` can be set to any custom regularization loss.
+ - `query_regularizer_threshold` and `document_regularizer_threshold` can be set to control the sparsity strictness for queries and documents separately, setting the regularization loss to zero if an embedding has fewer than the threshold number of active (non-zero) dimensions.
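As a rough sketch of the default FlopsLoss regularizer (a simplified pure-Python version for illustration; the real FlopsLoss operates on batched torch tensors, and the exact threshold masking may differ):

```python
def flops_regularizer(embeddings, threshold=None):
    """FLOPS regularizer sketch: sum over dimensions of the squared mean absolute activation.

    embeddings: list of equal-length sparse vectors (lists of floats).
    threshold: if set, embeddings with fewer than `threshold` active (non-zero)
        dimensions are excluded, i.e. sufficiently sparse embeddings are not penalized.
    """
    if threshold is not None:
        embeddings = [e for e in embeddings if sum(v != 0.0 for v in e) >= threshold]
        if not embeddings:
            return 0.0
    n, dim = len(embeddings), len(embeddings[0])
    # Penalizing the squared per-dimension mean encourages whole vocabulary
    # dimensions to go to zero across the batch, which is what makes retrieval cheap.
    return sum((sum(abs(e[j]) for e in embeddings) / n) ** 2 for j in range(dim))
```

Dimensions that are rarely activated across the batch contribute almost nothing, so the model is pushed toward reusing a small set of active dimensions.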
+
+#### Cached SPLADE Loss
+
+The CachedSpladeLoss is a variant of SpladeLoss that adopts GradCache, which allows for much larger batch sizes without additional GPU memory usage. It achieves this by splitting the forward pass into mini-batches and caching the loss gradients with respect to the embeddings.
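The GradCache trick can be illustrated with a toy sketch (assumption: a single scalar "encoder" parameter stands in for the model; CachedSpladeLoss applies the same three steps per mini-batch of sentences):

```python
import torch

# Toy "encoder": reps = w * x. CachedSpladeLoss does this with a real SparseEncoder.
w = torch.tensor(2.0, requires_grad=True)  # encoder parameter
xs = torch.tensor([1.0, 2.0, 3.0, 4.0])    # full batch of inputs

# Step 1: embed the full batch WITHOUT gradients (no computation graph is kept)
with torch.no_grad():
    reps = w * xs
reps = reps.detach().requires_grad_()

# Step 2: compute the loss on all embeddings at once, backward only up to the
# embeddings, and cache d(loss)/d(reps)
loss = (reps**2).sum()
loss.backward()
cached_grads = reps.grad

# Step 3: re-embed in mini-batches WITH gradients, chaining in the cached grads
mini_batch_size = 2
for i in range(0, len(xs), mini_batch_size):
    mb_reps = w * xs[i : i + mini_batch_size]
    mb_reps.backward(gradient=cached_grads[i : i + mini_batch_size])

# w.grad now equals the exact full-batch gradient d/dw sum((w*x)^2) = 2*w*sum(x^2),
# yet only one mini-batch's computation graph was ever held in memory at a time.
```

Peak memory is bounded by the mini-batch size, while the loss still sees the full batch of embeddings, which is exactly what in-batch-negatives losses need.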
+
+Main losses that use in-batch negatives, primarily SparseMultipleNegativesRankingLoss, benefit heavily from larger batch sizes, as they provide more negatives per anchor and thus a stronger training signal.
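A back-of-the-envelope count makes the benefit concrete (assumption: `(anchor, positive)` pairs where every other in-batch positive serves as a negative, as in SparseMultipleNegativesRankingLoss):

```python
def in_batch_negatives(batch_size: int) -> int:
    # Each anchor is scored against every positive in the batch; all positives
    # except its own act as negatives.
    return batch_size - 1

# Growing the batch from 16 to 512 raises the negatives per anchor from 15 to 511.
```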
### CSR Loss
If you are using the SparseAutoEncoder module, then you have to use the CSRLoss (Contrastive Sparse Representation Loss). It combines two components:
-- A reconstruction loss CSRReconstructionLoss that ensures sparse representation can faithfully reconstruct original embeddings.
-- A main loss, which in the paper is a contrastive learning component using `SparseMultipleNegativesRankingLoss` that ensures semanticallysimilar sentences have similar representations. But it's theoretically possible to use all the losses mention below as main loss like for SpladeLoss .
-
+1. Main loss: Supports all the losses from the Loss Table and Distillation, with SparseMultipleNegativesRankingLoss being the one used in the CSR paper.
+2. Reconstruction loss: CSRReconstructionLoss is used to ensure that sparse representation can faithfully reconstruct the original dense embeddings.
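As a simplified sketch of the reconstruction component (pure Python for illustration; the actual CSRReconstructionLoss operates on torch tensors and includes additional terms), it penalizes the distance between the original dense embedding and its reconstruction from the sparse code:

```python
def reconstruction_loss(original, reconstructed):
    """Mean squared error between the original dense embedding and its
    reconstruction decoded from the sparse representation."""
    return sum((o - r) ** 2 for o, r in zip(original, reconstructed)) / len(original)
```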
## Loss Table
@@ -34,18 +38,16 @@ Loss functions play a critical role in the performance of your fine-tuned model.
.. note::
You can often convert one training data format into another, allowing more loss functions to be viable for your scenario. For example, ``(sentence_A, sentence_B) pairs`` with ``class`` labels can be converted into ``(anchor, positive, negative) triplets`` by sampling sentences with the same or different classes.
-
- .. note::
-
- The loss functions in `SentenceTransformer > Loss Overview <../sentence_transformer/loss_overview.html>`_ that appear here with the ``Sparse`` prefix are identical to their dense versions. The prefix is used only to indicate which losses can be used as main losses to train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder`
```
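The class-label conversion mentioned in the note above can be sketched as follows (a hypothetical helper, not part of the library):

```python
import random

def to_triplets(samples, seed=0):
    """Convert (sentence, class_label) tuples into (anchor, positive, negative)
    triplets by sampling a same-class positive and a different-class negative."""
    rng = random.Random(seed)
    by_class = {}
    for sentence, label in samples:
        by_class.setdefault(label, []).append(sentence)
    triplets = []
    for sentence, label in samples:
        positives = [s for s in by_class[label] if s != sentence]
        negatives = [s for lab, sents in by_class.items() if lab != label for s in sents]
        if positives and negatives:  # skip sentences without a usable positive/negative
            triplets.append((sentence, rng.choice(positives), rng.choice(negatives)))
    return triplets
```

The resulting triplets can then be used with losses from the `(anchor, positive, negative) triplets` row of the table below.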
+**Legend:** Loss functions marked with `★` are commonly recommended default choices.
+
| Inputs | Labels | Appropriate Loss Functions |
|---------------------------------------------------|------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| `(anchor, positive) pairs` | `none` | `SparseMultipleNegativesRankingLoss` |
+| `(anchor, positive) pairs` | `none` | `SparseMultipleNegativesRankingLoss` ★ |
| `(sentence_A, sentence_B) pairs`                  | `float similarity score between 0 and 1` | `SparseCoSENTLoss`<br>`SparseAnglELoss`<br>`SparseCosineSimilarityLoss` |
-| `(anchor, positive, negative) triplets`          | `none`                                   | `SparseMultipleNegativesRankingLoss`<br>`SparseTripletLoss` |
-| `(anchor, positive, negative_1, ..., negative_n)` | `none` | `SparseMultipleNegativesRankingLoss` |
+| `(anchor, positive, negative) triplets`          | `none`                                   | `SparseMultipleNegativesRankingLoss` ★<br>`SparseTripletLoss` |
+| `(anchor, positive, negative_1, ..., negative_n)` | `none` | `SparseMultipleNegativesRankingLoss` ★ |
## Distillation
@@ -64,7 +66,7 @@ These loss functions are specifically designed to be used when distilling the kn
In practice, not all loss functions get used equally often. The most common scenarios are:
-* `(anchor, positive) pairs` without any labels: SparseMultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. Here for our sparse retrieval tasks, this format works well with SpladeLoss or CSRLoss, both typically using InfoNCE as their underlying loss function.
+* `(anchor, positive) pairs` without any labels: SparseMultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. Here for our sparse retrieval tasks, this format works well with SpladeLoss, CachedSpladeLoss, or CSRLoss, all typically using InfoNCE as their underlying loss function.
* `(query, positive, negative_1, ..., negative_n)` format: This structure with multiple negatives is particularly effective with SpladeLoss configured with SparseMarginMSELoss, especially in knowledge distillation scenarios where a teacher model provides similarity scores. The strongest models are trained with distillation losses like SparseDistillKLDivLoss or SparseMarginMSELoss.
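The MarginMSE idea behind SparseMarginMSELoss can be sketched as follows (simplified to a single negative per query; the actual loss works on batched scores):

```python
def margin_mse(student_pos, student_neg, teacher_pos, teacher_neg):
    """Train the student to match the teacher's score MARGIN between the positive
    and negative document, rather than the teacher's absolute scores."""
    student_margin = student_pos - student_neg
    teacher_margin = teacher_pos - teacher_neg
    return (student_margin - teacher_margin) ** 2
```

Matching margins instead of raw scores makes the distillation target invariant to the teacher's score scale and offset, which is why it works well even when teacher and student scoring functions differ.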
diff --git a/examples/sparse_encoder/training/retrievers/README.md b/examples/sparse_encoder/training/retrievers/README.md
index edd591680..3dc66a3ca 100644
--- a/examples/sparse_encoder/training/retrievers/README.md
+++ b/examples/sparse_encoder/training/retrievers/README.md
@@ -26,6 +26,12 @@ Example scripts could be:
  This example also uses :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss` (similarly utilizing :class:`~sentence_transformers.sparse_encoder.losses.SparseMultipleNegativesRankingLoss`) and trains on the `NQ (natural questions) <https://huggingface.co/datasets/sentence-transformers/natural-questions>`_ dataset. It showcases an alternative configuration or approach for training SPLADE models on question-answering data for sparse retrieval.
```
+- **[train_splade_nq_cached.py](train_splade_nq_cached.py)**:
+
+ ```{eval-rst}
+ This example is similar to the last one, but uses :class:`~sentence_transformers.sparse_encoder.losses.CachedSpladeLoss` to get much larger batch sizes (e.g. 512 instead of 16) during training without increasing GPU memory usage. Because :class:`~sentence_transformers.sparse_encoder.losses.SparseMultipleNegativesRankingLoss` benefits greatly from larger batch sizes (more in-batch negatives), this results in better retrieval performance.
+ ```
+
- **[train_csr_nq.py](train_csr_nq.py)**:
```{eval-rst}
diff --git a/examples/sparse_encoder/training/retrievers/train_splade_nq.py b/examples/sparse_encoder/training/retrievers/train_splade_nq.py
index 3e754fcb8..2ea0458e8 100644
--- a/examples/sparse_encoder/training/retrievers/train_splade_nq.py
+++ b/examples/sparse_encoder/training/retrievers/train_splade_nq.py
@@ -29,9 +29,8 @@
def main():
model_name = "distilbert/distilbert-base-uncased"
-
- train_batch_size = 12
- num_epochs = 1
+ short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+ global_batch_size = 16
# 1a. Load a model to finetune with 1b. (Optional) model card data
model = SparseEncoder(
@@ -39,7 +38,7 @@ def main():
model_card_data=SparseEncoderModelCardData(
language="en",
license="apache-2.0",
- model_name="splade-distilbert-base-uncased trained on Natural Questions",
+ model_name=f"splade-{short_model_name} trained on Natural Questions",
),
)
model.max_seq_length = 256 # Set the max sequence length to 256 for the training
@@ -67,34 +66,35 @@ def main():
# 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator
evaluator = evaluation.SparseNanoBEIREvaluator(
- dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=train_batch_size
+ dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=global_batch_size
)
+ evaluator(model)
# 5. Define the training arguments
- short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"splade-{short_model_name}-nq"
training_args = SparseEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
- num_train_epochs=num_epochs,
- per_device_train_batch_size=train_batch_size,
- per_device_eval_batch_size=train_batch_size,
- learning_rate=2e-5,
+ num_train_epochs=1,
+ per_device_train_batch_size=global_batch_size,
+ per_device_eval_batch_size=global_batch_size,
+ warmup_ratio=0.1,
+ learning_rate=2e-6,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
- load_best_model_at_end=True,
- metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
# Optional tracking/debugging parameters:
eval_strategy="steps",
- eval_steps=1650,
+ eval_steps=0.2,
save_strategy="steps",
- save_steps=1650,
+ save_steps=0.2,
save_total_limit=2,
- logging_steps=200,
+ logging_steps=0.05,
run_name=run_name, # Will be used in W&B if `wandb` is installed
- seed=42,
+ # Uncomment the following lines to enable loading the best model at the end of training based on evaluation performance
+ # load_best_model_at_end=True,
+ # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
)
# 6. Create the trainer & start training
@@ -108,9 +108,8 @@ def main():
)
trainer.train()
- # 7. Evaluate the final model, using the complete NanoBEIR dataset
- test_evaluator = evaluation.SparseNanoBEIREvaluator(show_progress_bar=True, batch_size=train_batch_size)
- test_evaluator(model)
+ # 7. Evaluate the final model again
+ evaluator(model)
# 8. Save the final model
final_output_dir = f"models/{run_name}/final"
diff --git a/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py b/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py
new file mode 100644
index 000000000..3f44f1d03
--- /dev/null
+++ b/examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py
@@ -0,0 +1,133 @@
+"""
+This example trains a SparseEncoder for the Natural Questions (NQ) dataset.
+The training script fine-tunes a SparseEncoder using CachedSpladeLoss for retrieval.
+It loads a subset of the Natural Questions dataset, splits it into training and evaluation subsets,
+and trains the model as a retriever. After training, the model is evaluated and saved locally,
+with an optional step to push the trained model to the Hugging Face Hub.
+
+Usage:
+python train_splade_nq_cached.py
+"""
+
+import logging
+import traceback
+
+from datasets import load_dataset
+
+from sentence_transformers import (
+ SparseEncoder,
+ SparseEncoderModelCardData,
+ SparseEncoderTrainer,
+ SparseEncoderTrainingArguments,
+)
+from sentence_transformers.sparse_encoder import evaluation, losses
+from sentence_transformers.training_args import BatchSamplers
+
+# Set the log level to INFO to get more information
+logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
+
+
+def main():
+ model_name = "distilbert/distilbert-base-uncased"
+ short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+ global_batch_size = 512
+ mini_batch_size = 32
+
+ # 1a. Load a model to finetune with 1b. (Optional) model card data
+ model = SparseEncoder(
+ model_name,
+ model_card_data=SparseEncoderModelCardData(
+ language="en",
+ license="apache-2.0",
+ model_name=f"splade-{short_model_name} trained on Natural Questions",
+ ),
+ )
+ model.max_seq_length = 256 # Set the max sequence length to 256 for the training
+ logging.info("Model max length: %s", model.max_seq_length)
+
+ # 2. Load the NQ dataset: https://huggingface.co/datasets/sentence-transformers/natural-questions
+ logging.info("Read the Natural Questions training dataset")
+ full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
+ dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
+ train_dataset = dataset_dict["train"]
+ eval_dataset = dataset_dict["test"]
+ logging.info(train_dataset)
+ logging.info(eval_dataset)
+
+ # 3. Define our training loss.
+ query_regularizer_weight = 5e-5
+ document_regularizer_weight = 3e-5
+
+ loss = losses.CachedSpladeLoss(
+ model=model,
+ loss=losses.SparseMultipleNegativesRankingLoss(model=model),
+ mini_batch_size=mini_batch_size,
+ query_regularizer_weight=query_regularizer_weight, # Weight for query loss
+ document_regularizer_weight=document_regularizer_weight, # Weight for document loss
+ )
+
+ # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator
+ evaluator = evaluation.SparseNanoBEIREvaluator(
+ dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=mini_batch_size
+ )
+ evaluator(model)
+
+ # 5. Define the training arguments
+ run_name = f"splade-{short_model_name}-nq-512bs"
+ training_args = SparseEncoderTrainingArguments(
+ # Required parameter:
+ output_dir=f"models/{run_name}",
+ # Optional training parameters:
+ num_train_epochs=1,
+ per_device_train_batch_size=global_batch_size,
+ per_device_eval_batch_size=global_batch_size,
+ warmup_ratio=0.1,
+ learning_rate=1e-5, # NOTE: The learning rate is much higher to account for the higher batch size
+ fp16=False, # Set to False if you get an error that your GPU can't run on FP16
+ bf16=True, # Set to True if you have a GPU that supports BF16
+ batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+ # Optional tracking/debugging parameters:
+ eval_strategy="steps",
+ eval_steps=0.2,
+ save_strategy="steps",
+ save_steps=0.2,
+ save_total_limit=2,
+ logging_steps=0.05,
+ run_name=run_name, # Will be used in W&B if `wandb` is installed
+ # Uncomment the following lines to enable loading the best model at the end of training based on evaluation performance
+ # load_best_model_at_end=True,
+ # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
+ )
+
+ # 6. Create the trainer & start training
+ trainer = SparseEncoderTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ loss=loss,
+ evaluator=evaluator,
+ )
+ trainer.train()
+
+ # 7. Evaluate the final model again
+ evaluator(model)
+
+ # 8. Save the final model
+ final_output_dir = f"models/{run_name}/final"
+ model.save_pretrained(final_output_dir)
+
+ # 9. (Optional) save the model to the Hugging Face Hub!
+ # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+ try:
+ model.push_to_hub(run_name)
+ except Exception:
+ logging.error(
+ f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+ f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` "
+ f"and saving it using `model.push_to_hub('{run_name}')`."
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py
new file mode 100644
index 000000000..bf98ce924
--- /dev/null
+++ b/sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py
@@ -0,0 +1,323 @@
+from __future__ import annotations
+
+import logging
+from collections.abc import Iterable, Iterator
+from contextlib import nullcontext
+from functools import partial
+from typing import Any
+
+import torch
+import tqdm
+from torch import Tensor, nn
+
+from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext, _backward_hook
+from sentence_transformers.sparse_encoder.losses.SpladeLoss import SpladeLoss
+from sentence_transformers.sparse_encoder.SparseEncoder import SparseEncoder
+
+logger = logging.getLogger(__name__)
+
+
+class CachedSpladeLoss(SpladeLoss):
+ def __init__(
+ self,
+ model: SparseEncoder,
+ loss: nn.Module,
+ document_regularizer_weight: float,
+ query_regularizer_weight: float | None = None,
+ document_regularizer: nn.Module | None = None,
+ query_regularizer: nn.Module | None = None,
+ document_regularizer_threshold: int | None = None,
+ query_regularizer_threshold: int | None = None,
+ use_document_regularizer_only: bool = False,
+ mini_batch_size: int = 32,
+ show_progress_bar: bool = False,
+ ):
+ """
+ Cached version of :class:`SpladeLoss` that uses the GradCache technique to allow for much larger
+ effective batch sizes without additional GPU memory usage.
+
+ By performing the GradCache mini-batch embedding at the SpladeLoss level, both the base loss and
+ regularizers still receive pre-computed embeddings via ``compute_loss_from_embeddings()`` — no
+ changes to base losses or regularizers are needed.
+
+ In detail, the GradCache technique works as follows:
+
+ (1) A quick embedding step without gradients/computation graphs to get all embeddings in mini-batches;
+ (2) Calculate the combined loss (base + regularizers), backward up to the embeddings and cache the
+ gradients w.r.t. the embeddings;
+ (3) A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the
+ backward chain.
+
+ Args:
+ model: SparseEncoder model
+            loss: The principal loss function to use; it can be any of the SparseEncoder losses except the CSR-related
+                losses and FlopsLoss. Must have a ``compute_loss_from_embeddings`` method.
+ document_regularizer_weight: Weight for the document regularization term. This term encourages sparsity
+ in the document embeddings. In some papers, this parameter is referred to as "lambda_d" (document)
+ or "lambda_c" (corpus).
+ query_regularizer_weight: Weight for the query regularization term. This term encourages sparsity in
+ the query embeddings. If None, no query regularization will be applied. In some papers, this
+ parameter is referred to as "lambda_q" (query).
+ document_regularizer: Optional regularizer to use specifically for document regularization instead of the
+ default FlopsLoss.
+ query_regularizer: Optional regularizer to use specifically for query regularization instead of the
+ default FlopsLoss.
+ document_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the
+ document embeddings to be considered in the FlopsLoss.
+ query_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the
+ query embeddings to be considered in the FlopsLoss.
+ use_document_regularizer_only: If True, all input embeddings are treated as documents and regularized
+ together with document_regularizer_weight.
+            mini_batch_size: Mini-batch size for the forward pass, this denotes how much memory is actually used
+                during training and evaluation. The smaller the mini-batch size, the less GPU memory is required,
+                but the slower the training will be. It's recommended to set it as high as your GPU
+                memory allows. The default value is 32.
+ show_progress_bar: If True, a progress bar for the mini-batches is shown during training.
+
+ References:
+ - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup:
+ https://huggingface.co/papers/2101.06983
+ - From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective:
+ https://huggingface.co/papers/2205.04733
+
+ Requirements:
+ 1. Input requirements depend on the chosen loss
+            2. Should be used with a large ``per_device_train_batch_size`` and a low ``mini_batch_size`` for superior
+                performance, at the cost of slower training than :class:`SpladeLoss`.
+
+ Example:
+ ::
+
+ from datasets import Dataset
+
+ from sentence_transformers.sparse_encoder import SparseEncoder, SparseEncoderTrainer, losses
+
+ model = SparseEncoder("distilbert/distilbert-base-uncased")
+ train_dataset = Dataset.from_dict(
+ {
+ "anchor": ["It's nice weather outside today.", "He drove to work."],
+ "positive": ["It's so sunny.", "He took the car to the office."],
+ }
+ )
+ loss = losses.CachedSpladeLoss(
+ model=model,
+ loss=losses.SparseMultipleNegativesRankingLoss(model),
+ document_regularizer_weight=3e-5,
+ query_regularizer_weight=5e-5,
+ mini_batch_size=32,
+ )
+
+ trainer = SparseEncoderTrainer(model=model, train_dataset=train_dataset, loss=loss)
+ trainer.train()
+ """
+ super().__init__(
+ model=model,
+ loss=loss,
+ document_regularizer_weight=document_regularizer_weight,
+ query_regularizer_weight=query_regularizer_weight,
+ document_regularizer=document_regularizer,
+ query_regularizer=query_regularizer,
+ document_regularizer_threshold=document_regularizer_threshold,
+ query_regularizer_threshold=query_regularizer_threshold,
+ use_document_regularizer_only=use_document_regularizer_only,
+ )
+ self.mini_batch_size = mini_batch_size
+ self.show_progress_bar = show_progress_bar
+ self.cache: list[list[Tensor]] | None = None
+ self.random_states: list[list[RandContext]] | None = None
+
+ def embed_minibatch(
+ self,
+ sentence_feature: dict[str, Tensor],
+ begin: int,
+ end: int,
+ with_grad: bool,
+ copy_random_state: bool,
+ random_state: RandContext | None = None,
+ ) -> tuple[Tensor, RandContext | None]:
+ """Embed a mini-batch of sentences."""
+ grad_context = nullcontext if with_grad else torch.no_grad
+ random_state_context = nullcontext() if random_state is None else random_state
+ sentence_feature_minibatch = {
+ key: value[begin:end] if isinstance(value, torch.Tensor) else value
+ for key, value in sentence_feature.items()
+ }
+ with random_state_context:
+ with grad_context():
+ random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None
+ reps = self.model(sentence_feature_minibatch)["sentence_embedding"]
+ return reps, random_state
+
+ def embed_minibatch_iter(
+ self,
+ sentence_feature: dict[str, Tensor],
+ with_grad: bool,
+ copy_random_state: bool,
+ random_states: list[RandContext] | None = None,
+ ) -> Iterator[tuple[Tensor, RandContext | None]]:
+ """Iterate over mini-batches of sentences for embedding."""
+ input_ids: Tensor = sentence_feature["input_ids"]
+ batch_size = input_ids.shape[0]
+ for i, begin in enumerate(
+ tqdm.trange(
+ 0,
+ batch_size,
+ self.mini_batch_size,
+ desc="Embed mini-batches",
+ disable=not self.show_progress_bar,
+ )
+ ):
+ end = begin + self.mini_batch_size
+ reps, random_state = self.embed_minibatch(
+ sentence_feature=sentence_feature,
+ begin=begin,
+ end=end,
+ with_grad=with_grad,
+ copy_random_state=copy_random_state,
+ random_state=None if random_states is None else random_states[i],
+ )
+ yield reps, random_state
+
+ def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], labels: Tensor | None) -> Tensor:
+ """Calculate the combined loss (base + regularizers) and cache gradients w.r.t. the embeddings."""
+ loss = self._compute_total_loss(reps, labels, with_backward=True)
+ loss = loss.detach().requires_grad_()
+ self.cache = [[r.grad for r in rs] for rs in reps]
+ return loss
+
+ def _compute_total_loss(
+ self, reps: list[list[Tensor]], labels: Tensor | None, with_backward: bool = False
+ ) -> Tensor:
+ """Compute total loss from base loss + regularizers on mini-batch reps."""
+ embeddings = [torch.cat(r) for r in reps]
+
+ # Base loss
+ base_loss = self.loss.compute_loss_from_embeddings(embeddings, labels)
+ if isinstance(base_loss, dict):
+ total_loss = sum(base_loss.values())
+ else:
+ total_loss = base_loss
+ self._base_loss_value = total_loss.detach().item()
+
+ # Document regularizer
+ if self.use_document_regularizer_only:
+ document_emb = torch.cat(embeddings)
+ else:
+ document_emb = torch.cat(embeddings[1:])
+ doc_reg_loss = self.document_regularizer.compute_loss_from_embeddings(document_emb)
+ weighted_doc_reg = doc_reg_loss * self.document_regularizer_weight
+ self._doc_reg_value = weighted_doc_reg.detach().item()
+ total_loss = total_loss + weighted_doc_reg
+
+ # Query regularizer
+ if self.query_regularizer_weight is not None:
+ query_reg_loss = self.query_regularizer.compute_loss_from_embeddings(embeddings[0])
+ weighted_query_reg = query_reg_loss * self.query_regularizer_weight
+ self._query_reg_value = weighted_query_reg.detach().item()
+ total_loss = total_loss + weighted_query_reg
+ else:
+ self._query_reg_value = None
+
+ if with_backward:
+ total_loss.backward()
+ total_loss = total_loss.detach()
+
+ return total_loss
+
+ def forward(
+ self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor | None = None
+ ) -> dict[str, Tensor] | Tensor:
+ sentence_features = list(sentence_features)
+
+ # Step (1): Embed all mini-batches without gradients to get all embeddings
+ reps = []
+ self.random_states = []
+ for sentence_feature in sentence_features:
+ reps_mbs = []
+ random_state_mbs = []
+ for reps_mb, random_state in self.embed_minibatch_iter(
+ sentence_feature=sentence_feature,
+ with_grad=False,
+ copy_random_state=True,
+ ):
+ reps_mbs.append(reps_mb.detach().requires_grad_())
+ random_state_mbs.append(random_state)
+ reps.append(reps_mbs)
+ self.random_states.append(random_state_mbs)
+
+ if torch.is_grad_enabled():
+ # Step (2): Calculate the combined loss, backward to embeddings and cache gradients
+ loss = self.calculate_loss_and_cache_gradients(reps, labels)
+
+ # Step (3): Register backward hook to chain cached gradients back through the model
+ loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self))
+
+ # Build dict for loss component logging while preserving gradient flow through `loss`.
+ # The trainer sums all dict values for backward, so we put the gradient-carrying tensor
+ # in "base_loss" and use detached tensors for regularizer entries.
+ device = loss.device
+ result = {}
+ non_base_sum = torch.tensor(0.0, device=device)
+
+ doc_reg_detached = torch.tensor(self._doc_reg_value, device=device)
+ result["document_regularizer_loss"] = doc_reg_detached
+ non_base_sum = non_base_sum + doc_reg_detached
+
+ if self._query_reg_value is not None:
+ query_reg_detached = torch.tensor(self._query_reg_value, device=device)
+ result["query_regularizer_loss"] = query_reg_detached
+ non_base_sum = non_base_sum + query_reg_detached
+
+ # base_loss = loss - regularizers, so trainer's sum(values) = loss (exact gradient flow)
+ result["base_loss"] = loss - non_base_sum
+
+ return result
+ else:
+ # Eval mode: no caching needed, compute losses directly
+ embeddings = [torch.cat(r) for r in reps]
+ losses = {}
+
+ base_loss = self.loss.compute_loss_from_embeddings(embeddings, labels)
+ if isinstance(base_loss, dict):
+ losses.update(base_loss)
+ else:
+ losses["base_loss"] = base_loss
+
+ if self.use_document_regularizer_only:
+ document_emb = torch.cat(embeddings)
+ else:
+ document_emb = torch.cat(embeddings[1:])
+ doc_reg_loss = self.document_regularizer.compute_loss_from_embeddings(document_emb)
+ losses["document_regularizer_loss"] = doc_reg_loss * self.document_regularizer_weight
+
+ if self.query_regularizer_weight is not None:
+ query_reg_loss = self.query_regularizer.compute_loss_from_embeddings(embeddings[0])
+ losses["query_regularizer_loss"] = query_reg_loss * self.query_regularizer_weight
+
+ return losses
+
+ def get_config_dict(self) -> dict[str, Any]:
+ config = super().get_config_dict()
+ config["mini_batch_size"] = self.mini_batch_size
+ return config
+
+ @property
+ def citation(self) -> str:
+ return """
+@misc{gao2021scaling,
+ title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
+ author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan},
+ year={2021},
+ eprint={2101.06983},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG}
+}
+@misc{formal2022distillationhardnegativesampling,
+ title={From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective},
+ author={Thibault Formal and Carlos Lassance and Benjamin Piwowarski and St\\'ephane Clinchant},
+ year={2022},
+ eprint={2205.04733},
+ archivePrefix={arXiv},
+ primaryClass={cs.IR},
+}
+"""
diff --git a/sentence_transformers/sparse_encoder/losses/SpladeLoss.py b/sentence_transformers/sparse_encoder/losses/SpladeLoss.py
index 10751e585..fd0356009 100644
--- a/sentence_transformers/sparse_encoder/losses/SpladeLoss.py
+++ b/sentence_transformers/sparse_encoder/losses/SpladeLoss.py
@@ -36,19 +36,19 @@ def __init__(
Args:
model: SparseEncoder model
loss: The principal loss function to use; it can be any of the SparseEncoder losses except the CSR-related losses and FlopsLoss.
- document_regularizer_weight: Weight for the corpus regularization term. This term encourages sparsity in the document embeddings.
+ document_regularizer_weight: Weight for the document regularization term. This term encourages sparsity in the document embeddings.
Will be applied to the positive documents and to all negative ones, if any are provided. In some papers, this parameter is
referred to as "lambda_d" (document) or "lambda_c" (corpus).
query_regularizer_weight: Weight for the query regularization term. This term encourages sparsity in the query embeddings.
If None, no query regularization will be applied; this is fine in an inference-free setup or
when use_document_regularizer_only=True. Otherwise, query_regularizer_weight should be > 0.
In some papers, this parameter is referred to as "lambda_q" (query).
- document_regularizer: Optional regularizer to use specifically for corpus regularization instead of the default FlopsLoss.
+ document_regularizer: Optional regularizer to use specifically for document regularization instead of the default FlopsLoss.
This allows for different regularization strategies for documents vs queries.
query_regularizer: Optional regularizer to use specifically for query regularization instead of the default FlopsLoss.
This allows for different regularization strategies for queries vs documents.
- document_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the corpus embeddings to be considered in the FlopsLoss.
- If specified, only corpus embeddings with more than this number of non-zero (active) elements will be considered.
+ document_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the document embeddings to be considered in the FlopsLoss.
+ If specified, only document embeddings with more than this number of non-zero (active) elements will be considered.
Only used when document_regularizer is None (for the default FlopsLoss).
query_regularizer_threshold: Optional threshold for the number of non-zero (active) elements in the query embeddings to be considered in the FlopsLoss.
If specified, only query embeddings with more than this number of non-zero (active) elements will be considered.
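The threshold parameters described above gate which embeddings contribute to the FlopsLoss regularizer: only embeddings with more than the given number of non-zero (active) dimensions are considered. A rough sketch of that filtering idea with plain Python lists (illustrative only; the helper name is hypothetical and the actual masking lives inside FlopsLoss):

```python
# Illustrative sketch of threshold-based filtering: each row is a sparse
# embedding, and only rows whose count of non-zero entries exceeds the
# threshold are kept for the regularization term.

def rows_above_threshold(embeddings, threshold):
    """Keep embeddings whose number of non-zero entries exceeds threshold."""
    return [row for row in embeddings if sum(1 for v in row if v != 0.0) > threshold]

batch = [
    [0.0, 1.2, 0.0, 0.4],  # 2 active dims
    [0.0, 0.0, 0.9, 0.0],  # 1 active dim
    [0.7, 0.3, 0.1, 0.0],  # 3 active dims
]
print(rows_above_threshold(batch, threshold=2))  # keeps only the 3-active row
```

With `threshold=2`, only already-dense embeddings are pushed toward sparsity, which leaves embeddings that are sparse enough untouched by the regularizer.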
@@ -149,10 +149,10 @@ def forward(
if self.use_document_regularizer_only:
# If use_document_regularizer_only is True, we consider all the input to be of the same type and so under the same regularization
- corpus_loss = self.document_regularizer.compute_loss_from_embeddings(torch.cat(embeddings))
+ document_loss = self.document_regularizer.compute_loss_from_embeddings(torch.cat(embeddings))
else:
- corpus_loss = self.document_regularizer.compute_loss_from_embeddings(torch.cat(embeddings[1:]))
- losses["document_regularizer_loss"] = corpus_loss * self.document_regularizer_weight
+ document_loss = self.document_regularizer.compute_loss_from_embeddings(torch.cat(embeddings[1:]))
+ losses["document_regularizer_loss"] = document_loss * self.document_regularizer_weight
# Add query regularization if enabled
if self.query_regularizer_weight is not None:
diff --git a/sentence_transformers/sparse_encoder/losses/__init__.py b/sentence_transformers/sparse_encoder/losses/__init__.py
index d1ad8666c..174b67055 100644
--- a/sentence_transformers/sparse_encoder/losses/__init__.py
+++ b/sentence_transformers/sparse_encoder/losses/__init__.py
@@ -12,7 +12,10 @@
from .SparseTripletLoss import SparseTripletLoss
from .SpladeLoss import SpladeLoss
+from .CachedSpladeLoss import CachedSpladeLoss # isort: skip # Avoid circular import with SpladeLoss -> FlopsLoss
+
__all__ = [
+ "CachedSpladeLoss",
"CSRLoss",
"CSRReconstructionLoss",
"SparseMultipleNegativesRankingLoss",