# [feat] Introduce CachedSpladeLoss for memory-efficient SPLADE training #3670

tomaarsen merged 14 commits into huggingface:main

## Conversation

---
Hello! I'm running some tests, and it seems to work well!

The `SpladeLoss` training script:

```python
"""
This example trains a SparseEncoder for the Natural Questions (NQ) dataset.
The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval.
It loads a subset of the Natural Questions dataset, splits it into training and evaluation subsets,
and trains the model as a retriever. After training, the model is evaluated and saved locally,
with an optional step to push the trained model to the Hugging Face Hub.

Usage:
python train_splade_nq.py
"""

import logging
import traceback

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder import evaluation, losses
from sentence_transformers.training_args import BatchSamplers

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "distilbert/distilbert-base-uncased"
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]

    global_batch_size = 16

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = SparseEncoder(
        model_name,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"splade-{short_model_name} trained on Natural Questions",
        ),
    )
    model.max_seq_length = 256  # Set the max sequence length to 256 for the training
    logging.info("Model max length: %s", model.max_seq_length)

    # 2. Load the NQ dataset: https://huggingface.co/datasets/sentence-transformers/natural-questions
    logging.info("Read the Natural Questions training dataset")
    full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Define our training loss
    query_regularizer_weight = 5e-5
    document_regularizer_weight = 3e-5
    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(model=model),
        query_regularizer_weight=query_regularizer_weight,  # Weight for query loss
        document_regularizer_weight=document_regularizer_weight,  # Weight for document loss
    )

    # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator
    evaluator = evaluation.SparseNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=global_batch_size
    )
    evaluator(model)

    # 5. Define the training arguments
    run_name = f"splade-{short_model_name}-nq-16bs-2e-6"
    training_args = SparseEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=global_batch_size,
        per_device_eval_batch_size=global_batch_size,
        warmup_ratio=0.1,
        learning_rate=2e-6,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        # load_best_model_at_end=True,
        # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=0.2,
        save_strategy="steps",
        save_steps=0.2,
        save_total_limit=2,
        logging_steps=0.05,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create the trainer & start training
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model again
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
```
The `CachedSpladeLoss` training script:

```python
"""
This example trains a SparseEncoder for the Natural Questions (NQ) dataset.
The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval.
It loads a subset of the Natural Questions dataset, splits it into training and evaluation subsets,
and trains the model as a retriever. After training, the model is evaluated and saved locally,
with an optional step to push the trained model to the Hugging Face Hub.

Usage:
python train_splade_nq.py
"""

import logging
import traceback

from datasets import load_dataset

from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder import evaluation, losses
from sentence_transformers.training_args import BatchSamplers

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "distilbert/distilbert-base-uncased"
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]

    global_batch_size = 512
    mini_batch_size = 32

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = SparseEncoder(
        model_name,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"splade-{short_model_name} trained on Natural Questions",
        ),
    )
    model.max_seq_length = 256  # Set the max sequence length to 256 for the training
    logging.info("Model max length: %s", model.max_seq_length)

    # 2. Load the NQ dataset: https://huggingface.co/datasets/sentence-transformers/natural-questions
    logging.info("Read the Natural Questions training dataset")
    full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Define our training loss
    query_regularizer_weight = 5e-5
    document_regularizer_weight = 3e-5
    loss = losses.CachedSpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(model=model),
        mini_batch_size=mini_batch_size,
        query_regularizer_weight=query_regularizer_weight,  # Weight for query loss
        document_regularizer_weight=document_regularizer_weight,  # Weight for document loss
    )

    # 4. Define evaluator. We use the SparseNanoBEIREvaluator, which is a light-weight evaluator
    evaluator = evaluation.SparseNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"], show_progress_bar=True, batch_size=mini_batch_size
    )
    evaluator(model)

    # 5. Define the training arguments
    run_name = f"splade-{short_model_name}-nq-512bs-1e-5-32mbs"
    training_args = SparseEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=global_batch_size,
        per_device_eval_batch_size=global_batch_size,
        warmup_ratio=0.1,
        learning_rate=1e-5,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        # load_best_model_at_end=True,
        # metric_for_best_model="eval_NanoBEIR_mean_dot_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=0.2,
        save_strategy="steps",
        save_steps=0.2,
        save_total_limit=2,
        logging_steps=0.05,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create the trainer & start training
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model again
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
```
(Granted: the larger batch size also means a lower relative contribution of the regularization losses, which means more active dimensions and higher maximum possible performance). I'll look at the code more in detail soon.
---
I pulled some fixes re. the CI from If you'd like, could you perhaps try and copy
---
Hmm, it looks like there's some failure due to the `CachedSpladeLoss` import (order?):
---
Thanks for testing and for the great results! Regarding the import issue: yes, I originally had the `CachedSpladeLoss` import placed after `SpladeLoss` in
And I created
---
Thanks! I'll push something to satisfy the CI while allowing for the new import order. I'll run tests with the
---
I can confirm that it works:
I added a
I think this is nearing completion; I'll let Copilot review and have another look myself too.
---
Pull request overview
This PR introduces CachedSpladeLoss, a memory-efficient variant of SpladeLoss that enables training SPLADE models with significantly larger batch sizes using the GradCache technique. This addresses the architectural conflict between gradient caching and SpladeLoss's embedding-sharing design, allowing for 4-16x larger effective batch sizes with 36-41% less GPU memory usage.
Changes:

- Implements `CachedSpladeLoss`, which performs gradient caching at the wrapper level (`SpladeLoss`) rather than the base loss level
- Adds comprehensive documentation and examples for the new loss function
- Updates existing training examples with improved hyperparameters
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py` | New loss class implementing the GradCache pattern for SPLADE training |
| `sentence_transformers/sparse_encoder/losses/__init__.py` | Exports the new `CachedSpladeLoss` class |
| `examples/sparse_encoder/training/retrievers/train_splade_nq_cached.py` | Example script demonstrating `CachedSpladeLoss` with large batch sizes |
| `examples/sparse_encoder/training/retrievers/train_splade_nq.py` | Updated hyperparameters and variable naming consistency |
| `examples/sparse_encoder/training/retrievers/README.md` | Documentation for the new cached training example |
| `docs/sparse_encoder/loss_overview.md` | Documentation updates explaining `CachedSpladeLoss` and its benefits |
| `docs/package_reference/sparse_encoder/losses.md` | API reference documentation for `CachedSpladeLoss` |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
Thank you for all your experiments and review!
---
Should be ready once the tests pass then! Thanks for opening this, it was something that the project had been lacking for a while, and it's nice to see it added. Great work 👏
---
Thank you so much for the review and your kind words! Also, I know you have a lot on your plate, but if you have a spare moment later, could you take a quick look at another PR I opened (#3665)? It's about EmbedDistill, and I really like this training method. But no rush at all! Thanks again for your time and help.
---
I also saw that one; there's just been a lot of PRs to go through, and EmbedDistill was the only PR whose paper I haven't yet read. I'll definitely get to it. And I'll revisit the GIST self-guide as well. They should be quite promising!
---
Hi, @tomaarsen!

### Overview

This PR introduces `CachedSpladeLoss`, a gradient-cached version of `SpladeLoss` that enables training SPLADE models with larger batch sizes without additional GPU memory. This resolves the long-standing architectural conflict between gradient caching and `SpladeLoss`'s embedding-sharing design.

### Problem
As discussed in #3451, a fundamental architectural conflict existed in implementing gradient caching for `SparseMultipleNegativesRankingLoss`: simply wrapping `SparseMultipleNegativesRankingLoss` with the caching pattern from `CachedMultipleNegativesRankingLoss` doesn't work, because `SpladeLoss` needs to control the embedding computation so it can route the same embeddings to both the base loss and the regularizers.

### Solution
Caching should happen at the `SpladeLoss` (wrapper) level, not at the base loss level. `CachedSpladeLoss` inherits from `SpladeLoss` and overrides `forward()` with the GradCache 3-step pattern. Since both the base loss and the regularizers still receive pre-computed embeddings via `compute_loss_from_embeddings()`, no changes to existing base losses or regularizers are needed.
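The 3-step GradCache pattern can be sketched as follows. This is a simplified, hypothetical illustration, not the actual implementation: the toy linear encoder and squared-mean loss stand in for the real SPLADE encoder, base loss, and regularizers, and the RNG-state replay (`RandContext`) is omitted because the toy model has no dropout.

```python
import torch

torch.manual_seed(0)
encoder = torch.nn.Linear(8, 4)  # stand-in for the SPLADE encoder
batch = torch.randn(16, 8)       # one large batch of inputs
mini_batch_size = 4

# Step 1: embed all mini-batches WITHOUT building the autograd graph,
# so peak memory stays bounded by the mini-batch size.
with torch.no_grad():
    embeddings = torch.cat(
        [encoder(batch[i : i + mini_batch_size]) for i in range(0, len(batch), mini_batch_size)]
    )

# Step 2: compute the full-batch loss on the embeddings alone, and cache
# the gradient of the loss with respect to each embedding.
embeddings = embeddings.detach().requires_grad_()
loss = embeddings.square().mean()  # stand-in for the real loss + regularizers
loss.backward()
cached_grads = embeddings.grad

# Step 3: re-run the forward pass per mini-batch WITH gradients, and
# backpropagate the cached embedding gradients into the encoder weights.
for i in range(0, len(batch), mini_batch_size):
    mini_emb = encoder(batch[i : i + mini_batch_size])
    mini_emb.backward(gradient=cached_grads[i : i + mini_batch_size])

# encoder.weight.grad now holds the exact full-batch gradient.
```

Because the forward pass is deterministic here, the accumulated gradient is identical to what a single full-batch forward/backward would produce; gradient caching is an exact rearrangement, not an approximation.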
### Results

All results are fully reproducible on Google Colab (free T4 GPU). See the demo notebook.
### VRAM Comparison

GPU: Tesla T4 (15 GB). Model: `distilbert/distilbert-base-uncased`, seq_len=256, fp16.

### Training Quality
AllNLI 200k triplets, 1 epoch, `distilbert-base-uncased`, NanoBEIR evaluation:

`CachedSpladeLoss` achieves comparable retrieval quality (NDCG@10 diff = 0.017) while using a 16x larger effective batch size and 41.5% less VRAM. The small difference comes from the different number of in-batch negatives (64 vs. 1024), not from any gradient approximation; the gradient caching is mathematically exact.
### Usage
### Key Parameters

- `mini_batch_size`
- `per_device_train_batch_size`
- `show_progress_bar` (default: `False`)
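To make the relationship between these parameters concrete (hypothetical numbers): `per_device_train_batch_size` sets the effective batch, and thus the number of in-batch negatives, while `mini_batch_size` only controls how that batch is chunked for the forward passes.

```python
import math

per_device_train_batch_size = 512  # effective batch: 512 in-batch candidates
mini_batch_size = 32               # memory bound: only 32 sequences on the autograd graph at once

# Number of forward-pass chunks per training step; the gradients are still
# exact for the full 512-sample batch.
num_chunks = math.ceil(per_device_train_batch_size / mini_batch_size)
print(num_chunks)  # → 16
```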
### Implementation Details

New file: `sentence_transformers/sparse_encoder/losses/CachedSpladeLoss.py`

- Inherits from `SpladeLoss`, so it is automatically compatible with `SpladeRegularizerWeightSchedulerCallback` and all existing `SpladeLoss` infrastructure (`isinstance` checks pass)
- Reuses `RandContext` and `_backward_hook` from `CachedMultipleNegativesRankingLoss`
- Returns `{"base_loss": ..., "document_regularizer_loss": ..., "query_regularizer_loss": ...}` for loss component logging
- Sets `result["base_loss"] = cached_loss - detached_regularizer_sum`, ensuring the trainer's `sum(values)` equals the exact cached loss
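A toy numeric check of that logging detail, with hypothetical values standing in for the real cached loss components: storing `base_loss` as the cached total minus the detached regularizer sum guarantees that summing the logged components reproduces the exact cached loss.

```python
import torch

# Hypothetical loss values, not taken from a real training run.
cached_loss = torch.tensor(1.25)  # exact total loss from the cached backward
query_regularizer_loss = torch.tensor(0.10)
document_regularizer_loss = torch.tensor(0.15)

# Store base_loss as the cached total minus the detached regularizer sum...
detached_regularizer_sum = (query_regularizer_loss + document_regularizer_loss).detach()
result = {
    "base_loss": cached_loss - detached_regularizer_sum,
    "query_regularizer_loss": query_regularizer_loss,
    "document_regularizer_loss": document_regularizer_loss,
}

# ...so that summing the logged components recovers the cached loss exactly.
assert torch.isclose(sum(result.values()), cached_loss)
```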
Modified file: `sentence_transformers/sparse_encoder/losses/__init__.py`

- Adds the `CachedSpladeLoss` import (placed after `SpladeLoss` to avoid circular imports)
- Adds it to `__all__`

### References