Merged
Changes from 10 commits
5 changes: 5 additions & 0 deletions docs/package_reference/sparse_encoder/losses.md
@@ -15,6 +15,11 @@ Sadly, there is no "one size fits all" loss function. Which loss function is sui
.. autoclass:: sentence_transformers.sparse_encoder.losses.SpladeLoss
```

## CachedSpladeLoss
```{eval-rst}
.. autoclass:: sentence_transformers.sparse_encoder.losses.CachedSpladeLoss
```

## FlopsLoss
```{eval-rst}
.. autoclass:: sentence_transformers.sparse_encoder.losses.FlopsLoss
38 changes: 20 additions & 18 deletions docs/sparse_encoder/loss_overview.md
@@ -2,29 +2,33 @@

```{eval-rst}
.. warning::
To train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder`, you need either :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss` or :class:`~sentence_transformers.sparse_encoder.losses.CSRLoss`, depending on the architecture. These are wrapper losses that add sparsity regularization on top of a main loss function, which must be provided as a parameter. The only loss that can be used independently is :class:`~sentence_transformers.sparse_encoder.losses.SparseMSELoss`, as it performs embedding-level distillation, ensuring sparsity by directly copying the teacher's sparse embedding.
To train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder`, you need either :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss`, :class:`~sentence_transformers.sparse_encoder.losses.CachedSpladeLoss`, or :class:`~sentence_transformers.sparse_encoder.losses.CSRLoss`, depending on the architecture. These are wrapper losses that add sparsity regularization on top of a main loss function, which must be provided as a parameter. The only loss that can be used independently is :class:`~sentence_transformers.sparse_encoder.losses.SparseMSELoss`, as it performs embedding-level distillation, ensuring sparsity by directly copying the teacher's sparse embedding.

```

## Sparse-specific Loss Functions

### SPLADE Loss

The <a href="../package_reference/sparse_encoder/losses.html#spladeloss"><code>SpladeLoss</code></a> implements a specialized loss function for SPLADE (Sparse Lexical and Expansion) models. It combines a main loss function with regularization terms to control efficiency:
The <a href="../package_reference/sparse_encoder/losses.html#spladeloss"><code>SpladeLoss</code></a> implements a specialized loss function for SPLADE (Sparse Lexical and Expansion) models. It combines a main loss function with regularization terms to balance effectiveness and efficiency:

- Supports all the losses mentioned below as the main loss, with three principal loss types: <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss"><code>SparseMultipleNegativesRankingLoss</code></a>, <a href="../package_reference/sparse_encoder/losses.html#sparsemarginmseloss"><code>SparseMarginMSELoss</code></a> and <a href="../package_reference/sparse_encoder/losses.html#sparsedistillkldivloss"><code>SparseDistillKLDivLoss</code></a>.
- Uses <a href="../package_reference/sparse_encoder/losses.html#flopsloss"><code>FlopsLoss</code></a> for regularization to control sparsity by default, but supports custom regularizers.
- Balances effectiveness (via the main loss) with efficiency by regularizing both query and document representations.
- Allows using different regularizers for queries and documents via the `query_regularizer` and `document_regularizer` parameters, enabling fine-grained control over sparsity patterns for different types of inputs.
- Supports separate threshold values for queries and documents via the `query_regularizer_threshold` and `document_regularizer_threshold` parameters, allowing different sparsity strictness levels for each input type.
1. Main loss: Supports all the losses from the <a href="#loss-table">Loss Table</a> and <a href="#distillation">Distillation</a>, with <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss"><code>SparseMultipleNegativesRankingLoss</code></a>, <a href="../package_reference/sparse_encoder/losses.html#sparsemarginmseloss"><code>SparseMarginMSELoss</code></a> and <a href="../package_reference/sparse_encoder/losses.html#sparsedistillkldivloss"><code>SparseDistillKLDivLoss</code></a> commonly used.
2. Regularization loss: <a href="../package_reference/sparse_encoder/losses.html#flopsloss"><code>FlopsLoss</code></a> is used to control sparsity, but supports custom regularizers.
- `query_regularizer` and `document_regularizer` can be set to any custom regularization loss.
- `query_regularizer_threshold` and `document_regularizer_threshold` can be set to control the sparsity strictness for queries and documents separately, setting the regularization loss to zero if an embedding has fewer than the threshold number of active (non-zero) dimensions.
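The FLOPS regularizer used by default penalizes the squared mean activation of each vocabulary dimension across the batch, so dimensions that fire for many examples are pushed toward zero. A minimal numpy sketch of that computation, using toy data rather than the library implementation:

```python
import numpy as np

# Hypothetical batch of sparse embeddings: 4 sentences over a 10-term vocabulary.
embeddings = np.array([
    [0.0, 2.0, 0.0, 0.5, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    [0.0, 1.5, 0.0, 0.0, 0.0, 0.0, 0.8, 0.0, 0.3, 0.0],
    [0.0, 0.0, 0.0, 1.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9],
    [0.0, 0.7, 0.0, 0.0, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0],
])


def flops_regularizer(emb: np.ndarray) -> float:
    """FLOPS-style regularizer: sum over vocabulary dimensions of the
    squared mean (absolute) activation across the batch. Dimensions that
    are active for many examples contribute most, so minimizing this
    term encourages sparser representations."""
    return float((np.abs(emb).mean(axis=0) ** 2).sum())


print(flops_regularizer(embeddings))
```

Scaling this value (the regularizer weight in `SpladeLoss`) trades retrieval quality against the number of active dimensions, and hence against inverted-index query cost.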

#### Cached SPLADE Loss

The <a href="../package_reference/sparse_encoder/losses.html#cachedspladeloss"><code>CachedSpladeLoss</code></a> is a variant of the SPLADE loss adopting <a href="https://huggingface.co/papers/2101.06983">GradCache</a>, which allows for much larger batch sizes without additional GPU memory usage. It achieves this by computing and caching loss gradients in mini-batches.

Main losses that use in-batch negatives, primarily <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss"><code>SparseMultipleNegativesRankingLoss</code></a>, benefit heavily from larger batch sizes, as these provide more negatives and a stronger training signal.
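The reason larger batches help can be sketched directly: with an InfoNCE-style loss, each anchor is scored against every positive in the batch, so the diagonal of the similarity matrix holds the true pairs and every off-diagonal entry acts as an in-batch negative. A toy numpy illustration (random vectors, not real embeddings):

```python
import numpy as np

rng = np.random.default_rng(0)


def in_batch_similarity_matrix(batch_size: int, dim: int = 8) -> np.ndarray:
    """Similarity matrix for a batch of (anchor, positive) pairs.

    Row i holds anchor i's similarity to every positive in the batch:
    the diagonal entry is its own positive, and the batch_size - 1
    off-diagonal entries serve as in-batch negatives."""
    anchors = rng.standard_normal((batch_size, dim))
    positives = rng.standard_normal((batch_size, dim))
    return anchors @ positives.T


for batch_size in (16, 512):
    scores = in_batch_similarity_matrix(batch_size)
    negatives_per_anchor = scores.shape[1] - 1
    print(batch_size, negatives_per_anchor)  # 16 -> 15 negatives, 512 -> 511
```

Going from a batch of 16 to 512 raises the negatives per anchor from 15 to 511, which is exactly the effect GradCache makes affordable on a fixed memory budget.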

### CSR Loss

If you are using the <a href="../package_reference/sparse_encoder/models.html#sparseautoencoder"><code>SparseAutoEncoder</code></a> module, then you have to use the <a href="../package_reference/sparse_encoder/losses.html#csrloss"><code>CSRLoss</code></a> (Contrastive Sparse Representation Loss). It combines two components:

- A reconstruction loss <a href="../package_reference/sparse_encoder/losses.html#csrreconstructionloss"><code>CSRReconstructionLoss</code></a> that ensures sparse representation can faithfully reconstruct original embeddings.
- A main loss, which in the paper is a contrastive learning component using <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss">`SparseMultipleNegativesRankingLoss`</a> that ensures semantically similar sentences have similar representations. It is theoretically possible to use any of the losses mentioned below as the main loss, as with <a href="../package_reference/sparse_encoder/losses.html#spladeloss"><code>SpladeLoss</code></a>.

1. Main loss: Supports all the losses from the <a href="#loss-table">Loss Table</a> and <a href="#distillation">Distillation</a>, with <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss"><code>SparseMultipleNegativesRankingLoss</code></a> used in the CSR Paper.
2. Reconstruction loss: <a href="../package_reference/sparse_encoder/losses.html#csrreconstructionloss"><code>CSRReconstructionLoss</code></a> is used to ensure that sparse representation can faithfully reconstruct the original dense embeddings.
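The overall structure of combining the two components can be sketched as a weighted sum, where the reconstruction term is a simple distance between the original dense embedding and what the autoencoder decodes from the sparse code. This is a conceptual numpy sketch with a hypothetical `beta` weight, not the library's implementation:

```python
import numpy as np


def csr_style_loss(dense: np.ndarray, reconstructed: np.ndarray,
                   main_loss: float, beta: float = 1.0) -> float:
    """Sketch of the CSRLoss structure: a main (e.g. ranking) loss plus a
    reconstruction term keeping the sparse code faithful to the original
    dense embedding. `beta` weights the reconstruction term."""
    reconstruction = float(((dense - reconstructed) ** 2).mean())
    return main_loss + beta * reconstruction


# Toy values: a 3-dim dense embedding and its decoded reconstruction.
dense = np.array([0.2, -0.1, 0.4])
decoded = np.array([0.1, 0.0, 0.5])
total = csr_style_loss(dense, decoded, main_loss=0.8, beta=1.0)
```

A perfect reconstruction drives the second term to zero, leaving only the main loss, which is why the sparse representations remain usable as drop-in replacements for the dense ones.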

## Loss Table

@@ -34,18 +38,16 @@ Loss functions play a critical role in the performance of your fine-tuned model.
.. note::

You can often convert one training data format into another, allowing more loss functions to be viable for your scenario. For example, ``(sentence_A, sentence_B) pairs`` with ``class`` labels can be converted into ``(anchor, positive, negative) triplets`` by sampling sentences with the same or different classes.

.. note::

The loss functions in `SentenceTransformer > Loss Overview <../sentence_transformer/loss_overview.html>`_ that appear here with the ``Sparse`` prefix are identical to their dense versions. The prefix is used only to indicate which losses can be used as main losses to train a :class:`~sentence_transformers.sparse_encoder.SparseEncoder`.
```

**Legend:** Loss functions marked with `★` are commonly recommended default choices.

| Inputs | Labels | Appropriate Loss Functions |
|---------------------------------------------------|------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `(anchor, positive) pairs` | `none` | <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss">`SparseMultipleNegativesRankingLoss`</a> |
| `(sentence_A, sentence_B) pairs` | `float similarity score between 0 and 1` | <a href="../package_reference/sparse_encoder/losses.html#sparsecosentloss">`SparseCoSENTLoss`</a><br><a href="../package_reference/sparse_encoder/losses.html#sparseangleloss">`SparseAnglELoss`</a><br><a href="../package_reference/sparse_encoder/losses.html#sparsecosinesimilarityloss">`SparseCosineSimilarityLoss`</a> |
| `(anchor, positive, negative) triplets` | `none` | <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss">`SparseMultipleNegativesRankingLoss`</a><br><a href="../package_reference/sparse_encoder/losses.html#sparsetripletloss">`SparseTripletLoss`</a> |
| `(anchor, positive, negative_1, ..., negative_n)` | `none` | <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss">`SparseMultipleNegativesRankingLoss`</a> |
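The format conversion mentioned in the note above — turning class-labeled pairs into `(anchor, positive, negative)` triplets — can be sketched in a few lines. The sentences and labels here are hypothetical:

```python
import random
from collections import defaultdict

random.seed(0)

# Hypothetical class-labeled sentences.
labeled = [
    ("A dog runs in the park", "animals"),
    ("The cat sleeps on the couch", "animals"),
    ("Stocks fell sharply today", "finance"),
    ("The market rallied after the news", "finance"),
]


def pairs_to_triplets(examples):
    """Build (anchor, positive, negative) triplets: a positive shares the
    anchor's class, a negative is sampled from a different class."""
    by_label = defaultdict(list)
    for text, label in examples:
        by_label[label].append(text)
    triplets = []
    for text, label in examples:
        positives = [t for t in by_label[label] if t != text]
        negatives = [t for other, ts in by_label.items() if other != label for t in ts]
        for pos in positives:
            triplets.append((text, pos, random.choice(negatives)))
    return triplets


triplets = pairs_to_triplets(labeled)
```

The resulting triplets can then be used with the triplet-compatible losses in the table, giving more options than the original pair-plus-label format.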


## Distillation
@@ -64,7 +66,7 @@ These loss functions are specifically designed to be used when distilling the kn

In practice, not all loss functions get used equally often. The most common scenarios are:

* `(anchor, positive) pairs` without any labels: <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss"><code>SparseMultipleNegativesRankingLoss</code></a> (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. Here for our sparse retrieval tasks, this format works well with <a href="../package_reference/sparse_encoder/losses.html#spladeloss"><code>SpladeLoss</code></a> or <a href="../package_reference/sparse_encoder/losses.html#csrloss"><code>CSRLoss</code></a>, both typically using InfoNCE as their underlying loss function.
* `(anchor, positive) pairs` without any labels: <a href="../package_reference/sparse_encoder/losses.html#sparsemultiplenegativesrankingloss"><code>SparseMultipleNegativesRankingLoss</code></a> (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top-performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. Here for our sparse retrieval tasks, this format works well with <a href="../package_reference/sparse_encoder/losses.html#spladeloss"><code>SpladeLoss</code></a>, <a href="../package_reference/sparse_encoder/losses.html#cachedspladeloss"><code>CachedSpladeLoss</code></a>, or <a href="../package_reference/sparse_encoder/losses.html#csrloss"><code>CSRLoss</code></a>, all typically using InfoNCE as their underlying loss function.

* `(query, positive, negative_1, ..., negative_n)` format: This structure with multiple negatives is particularly effective with <a href="../package_reference/sparse_encoder/losses.html#spladeloss"><code>SpladeLoss</code></a> configured with <a href="../package_reference/sparse_encoder/losses.html#sparsemarginmseloss"><code>SparseMarginMSELoss</code></a>, especially in knowledge distillation scenarios where a teacher model provides similarity scores. The strongest models are trained with distillation losses like <a href="../package_reference/sparse_encoder/losses.html#sparsedistillkldivloss"><code>SparseDistillKLDivLoss</code></a> or <a href="../package_reference/sparse_encoder/losses.html#sparsemarginmseloss"><code>SparseMarginMSELoss</code></a>.
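The MarginMSE idea behind this distillation setup can be sketched compactly: the student is trained so that its score margin `sim(q, pos) - sim(q, neg)` matches the teacher's margin, rather than matching the absolute scores. A minimal numpy sketch (toy scores, not the library implementation):

```python
import numpy as np


def margin_mse(student_pos, student_neg, teacher_pos, teacher_neg) -> float:
    """MarginMSE-style distillation objective: mean squared error between
    the student's and the teacher's score margins sim(q, pos) - sim(q, neg)."""
    student_margin = np.asarray(student_pos) - np.asarray(student_neg)
    teacher_margin = np.asarray(teacher_pos) - np.asarray(teacher_neg)
    return float(((student_margin - teacher_margin) ** 2).mean())


# Toy example: student margin is 3.0, teacher margin is 5.0.
loss = margin_mse([5.0], [2.0], [6.0], [1.0])
```

Because only the margin matters, the student and teacher are free to operate on different score scales, which is convenient when distilling from a cross-encoder into a sparse bi-encoder.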

6 changes: 6 additions & 0 deletions examples/sparse_encoder/training/retrievers/README.md
Expand Up @@ -26,6 +26,12 @@ Example scripts could be:
This example also uses :class:`~sentence_transformers.sparse_encoder.losses.SpladeLoss` (similarly utilizing :class:`~sentence_transformers.sparse_encoder.losses.SparseMultipleNegativesRankingLoss`) and trains on the `NQ (natural questions) <https://huggingface.co/datasets/sentence-transformers/natural-questions>`_ dataset. It showcases an alternative configuration or approach for training SPLADE models on question-answering data for sparse retrieval.
```

- **[train_splade_nq_cached.py](train_splade_nq_cached.py)**:

```{eval-rst}
This example is similar to the last one, but uses :class:`~sentence_transformers.sparse_encoder.losses.CachedSpladeLoss` to get much larger batch sizes (e.g. 512 instead of 16) during training without increasing GPU memory usage. Because :class:`~sentence_transformers.sparse_encoder.losses.SparseMultipleNegativesRankingLoss` benefits greatly from larger batch sizes (more in-batch negatives), this results in better retrieval performance.
```

- **[train_csr_nq.py](train_csr_nq.py)**:

```{eval-rst}
37 changes: 18 additions & 19 deletions examples/sparse_encoder/training/retrievers/train_splade_nq.py
@@ -29,17 +29,16 @@

def main():
model_name = "distilbert/distilbert-base-uncased"

train_batch_size = 12
num_epochs = 1
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
global_batch_size = 16

# 1a. Load a model to finetune with 1b. (Optional) model card data
model = SparseEncoder(
model_name,
model_card_data=SparseEncoderModelCardData(
language="en",
license="apache-2.0",
model_name="splade-distilbert-base-uncased trained on Natural Questions",
model_name=f"splade-{short_model_name} trained on Natural Questions",
),
)
model.max_seq_length = 256 # Set the max sequence length to 256 for the training
@@ -67,34 +66,35 @@ def main():

# 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator
evaluator = evaluation.SparseNanoBEIREvaluator(
dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=train_batch_size
dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=global_batch_size
)
evaluator(model)

# 5. Define the training arguments
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"splade-{short_model_name}-nq"
training_args = SparseEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=num_epochs,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
learning_rate=2e-5,
num_train_epochs=1,
per_device_train_batch_size=global_batch_size,
per_device_eval_batch_size=global_batch_size,
warmup_ratio=0.1,
learning_rate=2e-6,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
load_best_model_at_end=True,
metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=1650,
eval_steps=0.2,
save_strategy="steps",
save_steps=1650,
save_steps=0.2,
save_total_limit=2,
logging_steps=200,
logging_steps=0.05,
run_name=run_name, # Will be used in W&B if `wandb` is installed
seed=42,
# Uncomment the following lines to enable loading the best model at the end of training based on evaluation performance
# load_best_model_at_end=True,
# metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
)

# 6. Create the trainer & start training
@@ -108,9 +108,8 @@ def main():
)
trainer.train()

# 7. Evaluate the final model, using the complete NanoBEIR dataset
test_evaluator = evaluation.SparseNanoBEIREvaluator(show_progress_bar=True, batch_size=train_batch_size)
test_evaluator(model)
# 7. Evaluate the final model again
evaluator(model)

# 8. Save the final model
final_output_dir = f"models/{run_name}/final"