
[feat] Introduce CachedSpladeLoss for memory-efficient SPLADE training#3670

Merged
tomaarsen merged 14 commits into huggingface:main from yjoonjang:feat/add_cached_sparse
Feb 24, 2026

Conversation

@yjoonjang (Contributor)

Hi, @tomaarsen !

Overview

This PR introduces CachedSpladeLoss, a gradient-cached version of SpladeLoss that enables training SPLADE models with larger batch sizes without additional GPU memory. This resolves the long-standing architectural conflict between gradient caching and SpladeLoss's embedding-sharing design.

Problem

As discussed in #3451, there is a fundamental architectural conflict in applying gradient caching to SparseMultipleNegativesRankingLoss:

  • Cached losses need raw inputs to compute embeddings in mini-batches
  • SpladeLoss computes embeddings at its own level and shares them with both the base loss and the FLOPS regularizer

Simply wrapping SparseMultipleNegativesRankingLoss with the caching pattern from CachedMultipleNegativesRankingLoss doesn't work — SpladeLoss needs to control the embedding computation so it can route the same embeddings to both the base loss and the regularizers.

Solution

Caching should happen at the SpladeLoss (wrapper) level, not at the base loss level.

CachedSpladeLoss inherits from SpladeLoss and overrides forward() with the GradCache 3-step pattern:

  1. Embed without gradients: process all sentences in mini-batches without computation graphs, saving random states for reproducibility
  2. Compute loss and cache gradients: calculate the combined loss (base loss + document/query regularizers) on the full concatenated embeddings, backward to embeddings, and cache the gradients
  3. Re-embed with gradients: re-process mini-batches with computation graphs enabled, replaying saved random states, and chain cached gradients via backward hook

Since both the base loss and regularizers still receive pre-computed embeddings via compute_loss_from_embeddings(), no changes to existing base losses or regularizers are needed.
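The three steps above can be sketched in isolation. This is a minimal, self-contained illustration of the GradCache pattern, not the actual CachedSpladeLoss code: the linear `encoder` and the `total_loss` function are toy stand-ins for the SPLADE model and the combined base loss + regularizers.

```python
import torch

torch.manual_seed(0)
encoder = torch.nn.Linear(8, 4)   # toy stand-in for the SPLADE encoder
inputs = torch.randn(16, 8)       # full batch of 16 "sentences"
mini_batch_size = 4

def total_loss(embeddings):
    # Stand-in for base loss + document/query regularizers on the full batch.
    return embeddings.pow(2).mean()

# Step 1: embed all mini-batches without building computation graphs.
with torch.no_grad():
    embs = torch.cat([
        encoder(inputs[i:i + mini_batch_size])
        for i in range(0, len(inputs), mini_batch_size)
    ])

# Step 2: compute the loss on the full concatenated embeddings, backward to
# the embeddings only, and cache d(loss)/d(embeddings).
embs = embs.detach().requires_grad_()
loss = total_loss(embs)
loss.backward()
cached_grads = embs.grad

# Step 3: re-embed each mini-batch with gradients enabled and chain the cached
# gradients through a surrogate objective, so the encoder parameters receive
# the exact full-batch gradients while only one mini-batch graph is alive.
for i in range(0, len(inputs), mini_batch_size):
    mini_embs = encoder(inputs[i:i + mini_batch_size])
    surrogate = (mini_embs * cached_grads[i:i + mini_batch_size]).sum()
    surrogate.backward()
```

Because the encoder here is deterministic, the accumulated parameter gradients equal those of a single full-batch backward pass; the real implementation additionally replays saved random states so that dropout behaves identically across the two forward passes.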

Results

All results are fully reproducible on Google Colab (free T4 GPU). See the demo notebook.

VRAM Comparison

GPU: Tesla T4 (15 GB), Model: distilbert/distilbert-base-uncased, seq_len=256, fp16

|              | SpladeLoss | CachedSpladeLoss      |
|--------------|------------|-----------------------|
| Batch Size   | 256        | 1024 (mini_batch=256) |
| Peak VRAM    | 5,948 MB   | 3,776 MB              |
| VRAM Savings | baseline   | -36.5%                |
  • SpladeLoss at batch_size=1024 → OOM
  • CachedSpladeLoss at batch_size=1024 → 3,776 MB (works easily)
  • Result: 4x larger effective batch size with 36.5% less VRAM on a free Colab T4

Training Quality

AllNLI 200k triplets, 1 epoch, distilbert-base-uncased, NanoBEIR evaluation:

| Metric           | SpladeLoss (bs=64) | CachedSpladeLoss (bs=1024) |
|------------------|--------------------|----------------------------|
| NanoBEIR NDCG@10 | 0.2733             | 0.2566                     |
| NanoBEIR MRR@10  | 0.3233             | 0.3049                     |
| NanoBEIR MAP@100 | 0.2172             | 0.2059                     |
| Peak VRAM        | 8,986 MB           | 5,259 MB (-41.5%)          |

CachedSpladeLoss achieves comparable retrieval quality (NDCG@10 diff = 0.017) while using a 16x larger effective batch size and 41.5% less VRAM. The small difference comes from the different number of in-batch negatives (64 vs 1024), not from any gradient approximation — the gradient caching is mathematically exact.

Usage

from sentence_transformers.sparse_encoder import SparseEncoder, SparseEncoderTrainer, losses
from sentence_transformers import SparseEncoderTrainingArguments

model = SparseEncoder("distilbert/distilbert-base-uncased")

loss = losses.CachedSpladeLoss(
    model=model,
    loss=losses.SparseMultipleNegativesRankingLoss(model),
    document_regularizer_weight=3e-5,
    query_regularizer_weight=5e-5,
    mini_batch_size=32,  # Controls actual GPU memory usage
)

trainer = SparseEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
    args=SparseEncoderTrainingArguments(
        output_dir="models/cached-splade",
        per_device_train_batch_size=1024,  # Large effective batch size!
        fp16=True,
    ),
)
trainer.train()

Key Parameters

| Parameter                   | Description                                   | Recommendation                               |
|-----------------------------|-----------------------------------------------|----------------------------------------------|
| mini_batch_size             | Actual GPU memory usage per forward pass      | Set as high as the GPU allows (e.g., 32–256) |
| per_device_train_batch_size | Effective batch size for contrastive learning | Set much larger (e.g., 256, 512, 1024)       |
| show_progress_bar           | Show mini-batch progress during training      | False (default)                              |

Implementation Details

  1. New file: sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py

    • Inherits from SpladeLoss — automatically compatible with SpladeRegularizerWeightSchedulerCallback and all existing SpladeLoss infrastructure (isinstance checks pass)
    • Reuses RandContext and _backward_hook from CachedMultipleNegativesRankingLoss
    • Returns dict {"base_loss": ..., "document_regularizer_loss": ..., "query_regularizer_loss": ...} for loss component logging
    • Gradient flow: result["base_loss"] = cached_loss - detached_regularizer_sum, ensuring trainer's sum(values) equals the exact cached loss
  2. Modified file: sentence_transformers/sparse_encoder/losses/__init__.py

    • Added CachedSpladeLoss import (placed after SpladeLoss to avoid circular imports)
    • Added to __all__
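The gradient-flow trick from item 1 can be illustrated with a toy example. The tensors below are made-up stand-ins (not the actual loss values): the point is that because the trainer sums the returned dict values, reporting the base loss as the cached total minus the detached regularizer sum keeps `sum(values)` exactly equal to the cached loss, while gradients flow only through the cached total.

```python
import torch

emb = torch.randn(4, 3, requires_grad=True)
base = emb.pow(2).mean()            # stand-in for the base (ranking) loss
doc_reg = emb.abs().sum() * 3e-5    # stand-in FLOPS-style regularizers
query_reg = emb.abs().sum() * 5e-5
cached_total = base + doc_reg + query_reg  # the loss actually backpropagated

result = {
    # Subtract the *detached* regularizer sum so the dict values sum back
    # to cached_total without double-counting regularizer gradients.
    "base_loss": cached_total - (doc_reg + query_reg).detach(),
    "document_regularizer_loss": doc_reg.detach(),
    "query_regularizer_loss": query_reg.detach(),
}
total = sum(result.values())  # == cached_total, gradients intact
```

Only the `"base_loss"` entry carries a gradient; the regularizer entries exist purely for component logging.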


@tomaarsen (Member)

tomaarsen commented Feb 20, 2026

Hello!

I'm running some tests, and it seems to work well!

Training Script
"""
This example trains a SparseEncoder for the Natural Questions (NQ) dataset.
The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval.
It loads a subset of the Natural Questions dataset, splits it into training and evaluation subsets,
and trains the model as a retriever. After training, the model is evaluated and saved locally,
with an optional step to push the trained model to the Hugging Face Hub.

Usage:
python train_splade_nq.py
"""

import logging
import traceback

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder import evaluation, losses
from sentence_transformers.training_args import BatchSamplers

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "distilbert/distilbert-base-uncased"
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    global_batch_size = 16

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = SparseEncoder(
        model_name,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"splade-{short_model_name} trained on Natural Questions",
        ),
    )
    model.max_seq_length = 256  # Set the max sequence length to 256 for the training
    logging.info("Model max length: %s", model.max_seq_length)

    # 2. Load the NQ dataset: https://huggingface.co/datasets/sentence-transformers/natural-questions
    logging.info("Read the Natural Questions training dataset")
    full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Define our training loss.
    query_regularizer_weight = 5e-5
    document_regularizer_weight = 3e-5

    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(model=model),
        query_regularizer_weight=query_regularizer_weight,  # Weight for query loss
        document_regularizer_weight=document_regularizer_weight,  # Weight for document loss
    )

    # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator
    evaluator = evaluation.SparseNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=global_batch_size
    )
    evaluator(model)

    # 5. Define the training arguments
    run_name = f"splade-{short_model_name}-nq-16bs-2e-6"
    training_args = SparseEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=global_batch_size,
        per_device_eval_batch_size=global_batch_size,
        warmup_ratio=0.1,
        learning_rate=2e-6,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        # load_best_model_at_end=True,
        # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=0.2,
        save_strategy="steps",
        save_steps=0.2,
        save_total_limit=2,
        logging_steps=0.05,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create the trainer & start training
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model again
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
Cached Training Script
"""
This example trains a SparseEncoder for the Natural Questions (NQ) dataset.
The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval.
It loads a subset of the Natural Questions dataset, splits it into training and evaluation subsets,
and trains the model as a retriever. After training, the model is evaluated and saved locally,
with an optional step to push the trained model to the Hugging Face Hub.

Usage:
python train_splade_nq.py
"""

import logging
import traceback

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder import evaluation, losses
from sentence_transformers.training_args import BatchSamplers

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "distilbert/distilbert-base-uncased"
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    global_batch_size = 512
    mini_batch_size = 32

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = SparseEncoder(
        model_name,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"splade-{short_model_name} trained on Natural Questions",
        ),
    )
    model.max_seq_length = 256  # Set the max sequence length to 256 for the training
    logging.info("Model max length: %s", model.max_seq_length)

    # 2. Load the NQ dataset: https://huggingface.co/datasets/sentence-transformers/natural-questions
    logging.info("Read the Natural Questions training dataset")
    full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Define our training loss.
    query_regularizer_weight = 5e-5
    document_regularizer_weight = 3e-5

    loss = losses.CachedSpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(model=model),
        mini_batch_size=mini_batch_size,
        query_regularizer_weight=query_regularizer_weight,  # Weight for query loss
        document_regularizer_weight=document_regularizer_weight,  # Weight for document loss
    )

    # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator
    evaluator = evaluation.SparseNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=mini_batch_size
    )
    evaluator(model)

    # 5. Define the training arguments
    run_name = f"splade-{short_model_name}-nq-512bs-1e-5-32mbs"
    training_args = SparseEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=global_batch_size,
        per_device_eval_batch_size=global_batch_size,
        warmup_ratio=0.1,
        learning_rate=1e-5,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        # load_best_model_at_end=True,
        # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=0.2,
        save_strategy="steps",
        save_steps=0.2,
        save_total_limit=2,
        logging_steps=0.05,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create the trainer & start training
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model again
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()

(Granted: the larger batch size also means a lower relative contribution of the regularization losses, which means more active dimensions and higher maximum possible performance).

I'll look at the code more in detail soon.

  • Tom Aarsen

@tomaarsen (Member)

I pulled some CI fixes from main, added a training script using the CachedSpladeLoss, and added some more documentation. This is already looking solid, I think.

If you'd like, could you try copying train_splade_msmarco_margin_mse.py into a train_splade_msmarco_margin_mse_cached.py, as I've done for train_splade_nq_cached.py? It would be a good test as well.

  • Tom Aarsen

@tomaarsen (Member)

Hmm, it looks like there's some failure due to the CachedSpladeLoss import (order?):

=========================== short test summary info ============================
FAILED tests/sparse_encoder/test_model_card.py::test_model_card_base[splade_bert_tiny_model-0-expected_substrings0] - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_model_card.py::test_model_card_base[splade_bert_tiny_model-1-expected_substrings1] - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_model_card.py::test_model_card_base[splade_bert_tiny_model-2-expected_substrings2] - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_model_card.py::test_model_card_base[splade_bert_tiny_model-10-expected_substrings3] - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_model_card.py::test_model_card_base[splade_bert_tiny_model-50-expected_substrings4] - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_model_card.py::test_model_card_base[csr_bert_tiny_model-0-expected_substrings5] - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_model_card.py::test_model_card_base[inference_free_splade_bert_tiny_model-0-expected_substrings6] - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_trainer.py::test_model_card_reuse - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_trainer.py::test_trainer[False] - TypeError: 'module' object is not callable
FAILED tests/sparse_encoder/test_trainer.py::test_trainer[True] - TypeError: 'module' object is not callable

The FlopsLoss is now imported as a module instead of a class. I'd really like to change all of the filenames in #3554 to prevent this from being possible, but for now I think we have to change the import order around to be safe (otherwise we'll accidentally break backwards compatibility for people training with FlopsLoss).
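The failure mode can be reproduced in isolation. This is an illustrative sketch (not the actual package code): it simulates how a later binding of the submodule `FlopsLoss` can shadow the earlier binding of the class of the same name in a package namespace, producing exactly this `TypeError`.

```python
import types

class FlopsLoss:  # stands in for the class exported by FlopsLoss.py
    pass

# In __init__.py, the class is bound first, then a subsequent import of the
# FlopsLoss *module* re-binds the same name, shadowing the class.
exported = {"FlopsLoss": FlopsLoss}
exported["FlopsLoss"] = types.ModuleType("FlopsLoss")

try:
    exported["FlopsLoss"]()  # callers expect the class, get the module
    raised = None
except TypeError as exc:
    raised = str(exc)  # "'module' object is not callable"
```

This is why the import order in `__init__.py` matters until the filenames no longer collide with the class names.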

  • Tom Aarsen

@yjoonjang (Contributor, Author)

Thanks for testing and for the great results!
The cached version with bs=512 outperforming the baseline (0.4581 vs 0.4219) is really encouraging.

Regarding the import issue — yes, I originally had the CachedSpladeLoss import placed after SpladeLoss in __init__.py to avoid this exact circular import problem.
The alphabetical reordering from the CI fixes broke that. I fixed the import order.

And I created train_splade_msmarco_margin_mse_cached.py based on the MarginMSE script.
I couldn't run the experiment myself, since I don't have access to any GPUs right now.

@tomaarsen (Member)

Thanks! I'll push something to satisfy the CI while allowing for the new import order. I'll run tests with the train_splade_msmarco_margin_mse_cached.py, although I did just realise that the SparseMarginMSELoss doesn't benefit from larger batches, so using CachedSpladeLoss doesn't really make sense for it. Sorry about that! I'll remove the file once I've verified that it at least does work!

  • Tom Aarsen

@tomaarsen (Member)

I can confirm that it works:

| Epoch  | Step | Training Loss | Validation Loss | NanoMSMARCO_dot_ndcg@10 | NanoNFCorpus_dot_ndcg@10 | NanoNQ_dot_ndcg@10 | NanoBEIR_mean_dot_ndcg@10 |
|--------|------|---------------|-----------------|-------------------------|--------------------------|--------------------|---------------------------|
| -1     | -1   | -             | -               | 0.0823                  | 0.0441                   | 0.0645             | 0.0636                    |
| 0.0511 | 9    | 944862.5556   | -               | -                       | -                        | -                  | -                         |
| 0.1023 | 18   | 70073.7569    | -               | -                       | -                        | -                  | -                         |
| 0.1534 | 27   | 775.0326      | -               | -                       | -                        | -                  | -                         |
| 0.2045 | 36   | 166.0733      | 113.0185        | 0.0684                  | 0.0822                   | 0.1578             | 0.1028                    |
| 0.2557 | 45   | 105.6319      | -               | -                       | -                        | -                  | -                         |
| 0.3068 | 54   | 75.4794       | -               | -                       | -                        | -                  | -                         |
| 0.3580 | 63   | 57.1591       | -               | -                       | -                        | -                  | -                         |
| 0.4091 | 72   | 46.1204       | 39.5340         | 0.4624                  | 0.2468                   | 0.5516             | 0.4203                    |
| 0.4602 | 81   | 40.1958       | -               | -                       | -                        | -                  | -                         |
| 0.5114 | 90   | 37.5867       | -               | -                       | -                        | -                  | -                         |
| 0.5625 | 99   | 35.1976       | -               | -                       | -                        | -                  | -                         |
| 0.6136 | 108  | 32.7673       | 30.2409         | 0.5413                  | 0.2771                   | 0.5781             | 0.4655                    |

I added an `isort: skip` comment to satisfy the CI, and removed the cached training file again. Apologies again.

I think this is nearing completion, I'll let Copilot review and have another look myself too.

  • Tom Aarsen

Copilot AI (Contributor) left a comment

Pull request overview

This PR introduces CachedSpladeLoss, a memory-efficient variant of SpladeLoss that enables training SPLADE models with significantly larger batch sizes using the GradCache technique. This addresses the architectural conflict between gradient caching and SpladeLoss's embedding-sharing design, allowing for 4-16x larger effective batch sizes with 36-41% less GPU memory usage.

Changes:

  • Implements CachedSpladeLoss that performs gradient caching at the wrapper level (SpladeLoss) rather than the base loss level
  • Adds comprehensive documentation and examples for the new loss function
  • Updates existing training examples with improved hyperparameters

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

| File | Description |
|------|-------------|
| sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py | New loss class implementing the GradCache pattern for SPLADE training |
| sentence_transformers/sparse_encoder/losses/__init__.py | Exports the new CachedSpladeLoss class |
| examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py | Example script demonstrating CachedSpladeLoss with large batch sizes |
| examples/sparse_encoder/training/retrievers/train_splade_nq.py | Updated hyperparameters and variable naming consistency |
| examples/sparse_encoder/training/retrievers/README.md | Documentation for the new cached training example |
| docs/sparse_encoder/loss_overview.md | Documentation updates explaining CachedSpladeLoss and its benefits |
| docs/package_reference/sparse_encoder/losses.md | API reference documentation for CachedSpladeLoss |


@yjoonjang (Contributor, Author)

Thank you for all your experiments and the review!
I checked all the commits and it looks great.

@tomaarsen tomaarsen enabled auto-merge (squash) February 24, 2026 16:27
@tomaarsen (Member)

Should be ready once the tests pass then! Thanks for opening this, it was something that the project had been lacking for a while, and it's nice to see it added. Great work 👏

  • Tom Aarsen

@yjoonjang (Contributor, Author)

Thank you so much for the review and your kind words!
I'm really glad to contribute to sparse model training. 🚀

Also, I know you have a lot on your plate, but if you have a spare moment later, could you take a quick look at another PR I opened (#3665)? It's about EmbedDistill, and I really like this training method. But no rush at all!

Thanks again for your time and help.

@tomaarsen (Member)

I also saw that one; there have just been a lot of PRs to go through, and EmbedDistill was the only PR whose paper I haven't yet read. I'll definitely get to it. And I'll revisit the GIST self-guide as well. They should be quite promising!

  • Tom Aarsen

@tomaarsen tomaarsen merged commit 73d700b into huggingface:main Feb 24, 2026
17 checks passed
