Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 113 additions & 6 deletions sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import deque
from collections.abc import Iterable, Iterator
from contextlib import nullcontext
from functools import partial
Expand Down Expand Up @@ -75,6 +76,9 @@ def __init__(
mini_batch_size: int = 32,
gather_across_devices: bool = False,
show_progress_bar: bool = False,
use_cont_accum: bool = False,
cache_size: int | None = None,
prev_cache: bool = False,
) -> None:
"""
Boosted version of MultipleNegativesRankingLoss (https://huggingface.co/papers/1705.00652) by GradCache (https://huggingface.co/papers/2101.06983).
Expand Down Expand Up @@ -110,10 +114,17 @@ def __init__(
Recommended when training on multiple GPUs, as it allows for larger batch sizes, but it may slow down
training due to communication overhead, and can potentially lead to out-of-memory errors.
show_progress_bar: If True, a progress bar for the mini-batches is shown during training. The default is False.
use_cont_accum: If True, enable Contrastive Accumulation (ContAccum) by caching embeddings from previous steps
and reusing them as in-batch negatives.
cache_size: Maximum number of cached embeddings per column. Set this to the effective batch size if you want
ContAccum to mimic gradient accumulation behavior.
prev_cache: If True, keep cached embeddings across optimizer steps. Otherwise, caches are cleared after each
optimizer step (requires Trainer integration or manual calls to ``on_optimizer_step``).

References:
- Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://huggingface.co/papers/1705.00652
- Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://huggingface.co/papers/2101.06983
- A Gradient Accumulation Method for Dense Retriever under Memory Constraint: https://arxiv.org/abs/2406.12356

Requirements:
1. (anchor, positive) pairs or (anchor, positive, negative) triplets
Expand Down Expand Up @@ -176,6 +187,14 @@ def __init__(
self.cross_entropy_loss = nn.CrossEntropyLoss()
self.cache: list[list[Tensor]] | None = None
self.random_states: list[list[RandContext]] | None = None
self.use_cont_accum = use_cont_accum
self.cache_size = 0 if cache_size is None else cache_size
self.prev_cache = prev_cache
self._query_cache: deque[Tensor] | None = None
self._candidate_caches: list[deque[Tensor]] | None = None

if self.use_cont_accum and self.cache_size <= 0:
raise ValueError("cache_size must be a positive integer when use_cont_accum is True.")

def embed_minibatch(
self,
Expand Down Expand Up @@ -229,20 +248,97 @@ def embed_minibatch_iter(
)
yield reps, random_state # reps: (mbsz, hdim)

def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
def _init_cont_caches(self, num_candidate_columns: int) -> None:
if not self.use_cont_accum or self._candidate_caches is not None:
return
self._candidate_caches = [deque(maxlen=self.cache_size) for _ in range(num_candidate_columns)]
self._query_cache = deque(maxlen=self.cache_size)

def _cached_tensor(self, cache: deque[Tensor] | None, device: torch.device) -> Tensor | None:
if cache is None or len(cache) == 0:
return None
return torch.cat(list(cache), dim=0).to(device)

def _collect_cont_cache(self, reps: list[list[Tensor]]) -> tuple[Tensor | None, list[Tensor | None] | None]:
if not self.use_cont_accum or not self.training:
return None, None
self._init_cont_caches(num_candidate_columns=len(reps) - 1)

anchors = torch.cat(reps[0])
cached_anchors = self._cached_tensor(self._query_cache, anchors.device)

cached_candidates: list[Tensor | None] = []
assert self._candidate_caches is not None
for reps_col, cache in zip(reps[1:], self._candidate_caches):
candidates = torch.cat(reps_col)
cached_candidates.append(self._cached_tensor(cache, candidates.device))

return cached_anchors, cached_candidates

def _enqueue_cont_cache(self, reps: list[list[Tensor]]) -> None:
if not self.use_cont_accum or not self.training:
return
if self._candidate_caches is None:
self._init_cont_caches(num_candidate_columns=len(reps) - 1)
assert self._candidate_caches is not None

anchors = torch.cat(reps[0])
self._query_cache.append(anchors.detach().clone())

for reps_col, cache in zip(reps[1:], self._candidate_caches):
candidates = torch.cat(reps_col)
cache.append(candidates.detach().clone())

def reset_cont_cache(self) -> None:
    """Empty every ContAccum cache (query and all candidate columns) without discarding the deques."""
    query_cache = self._query_cache
    if query_cache is not None:
        query_cache.clear()
    for candidate_cache in self._candidate_caches or []:
        candidate_cache.clear()

def on_optimizer_step(self) -> None:
    """Hook called after each optimizer step; drops cached embeddings unless ``prev_cache`` keeps them."""
    should_reset = self.use_cont_accum and not self.prev_cache
    if should_reset:
        self.reset_cont_cache()

def calculate_loss_and_cache_gradients(
    self,
    reps: list[list[Tensor]],
    cached_anchors: Tensor | None = None,
    cached_candidates: list[Tensor | None] | None = None,
) -> Tensor:
    """Calculate the cross-entropy loss and cache the gradients wrt. the embeddings.

    Args:
        reps: Per-column lists of mini-batch embedding tensors (column 0 = anchors).
        cached_anchors: Optional ContAccum anchor embeddings from previous steps.
        cached_candidates: Optional per-column ContAccum candidate embeddings.

    Returns:
        The loss, detached from the graph but with ``requires_grad`` re-enabled so the
        trainer's backward pass still triggers the registered backward hook.
    """
    # NOTE(review): the diff residue contained both the old single-line call and this
    # expanded one; only a single calculate_loss call must remain.
    loss = self.calculate_loss(
        reps,
        cached_anchors=cached_anchors,
        cached_candidates=cached_candidates,
        with_backward=True,
    )
    loss = loss.detach().requires_grad_()

    # Harvest the per-mini-batch gradients, e.g. 3 * bsz/mbsz * (mbsz, hdim).
    self.cache = [[r.grad for r in rs] for rs in reps]

    return loss

def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
def calculate_loss(
self,
reps: list[list[Tensor]],
cached_anchors: Tensor | None = None,
cached_candidates: list[Tensor | None] | None = None,
with_backward: bool = False,
) -> Tensor:
"""Calculate the cross-entropy loss. No need to cache the gradients."""
anchors = torch.cat(reps[0]) # (batch_size, embedding_dim)
if cached_anchors is not None:
anchors = torch.cat([anchors, cached_anchors], dim=0)

candidates = [torch.cat(r) for r in reps[1:]] # (1 + num_neg) tensors of shape (batch_size, embedding_dim)
batch_size = len(anchors)
if cached_candidates is not None:
for idx, cached in enumerate(cached_candidates):
if cached is None:
continue
candidates[idx] = torch.cat([candidates[idx], cached], dim=0)

batch_size = anchors.size(0)
offset = 0

if self.gather_across_devices:
Expand Down Expand Up @@ -300,15 +396,23 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
reps.append(reps_mbs)
self.random_states.append(random_state_mbs)

cached_anchors, cached_candidates = self._collect_cont_cache(reps)

if torch.is_grad_enabled():
# Step (2): Calculate the loss, backward up to the embeddings and cache the gradients wrt. to the embeddings
loss = self.calculate_loss_and_cache_gradients(reps)
loss = self.calculate_loss_and_cache_gradients(
reps,
cached_anchors=cached_anchors,
cached_candidates=cached_candidates,
)

# Step (3): A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain
loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self))
else:
# If grad is not enabled (e.g. in evaluation), then we don't have to worry about the gradients or backward hook
loss = self.calculate_loss(reps)
loss = self.calculate_loss(reps, cached_anchors=cached_anchors, cached_candidates=cached_candidates)

self._enqueue_cont_cache(reps)

return loss

Expand All @@ -318,6 +422,9 @@ def get_config_dict(self) -> dict[str, Any]:
"similarity_fct": self.similarity_fct.__name__,
"mini_batch_size": self.mini_batch_size,
"gather_across_devices": self.gather_across_devices,
"use_cont_accum": self.use_cont_accum,
"cache_size": self.cache_size,
"prev_cache": self.prev_cache,
}

@property
Expand Down
16 changes: 16 additions & 0 deletions sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@
TrackioCallback = None


class _LossCacheResetCallback(TrainerCallback):
    """Trainer callback that notifies the loss module(s) after every optimizer step.

    Used so losses with per-step caches (e.g. ContAccum) can clear them at step boundaries.
    """

    def __init__(self, trainer: SentenceTransformerTrainer) -> None:
        self.trainer = trainer

    def on_step_end(self, args, state, control, **kwargs):
        # Delegate to the trainer, which knows how to reach single losses and dicts of losses.
        self.trainer._loss_on_optimizer_step()


class SentenceTransformerTrainer(Trainer):
"""
SentenceTransformerTrainer is a simple but feature-complete training and eval loop for PyTorch
Expand Down Expand Up @@ -312,6 +320,8 @@ def __init__(
else:
self.loss = self.prepare_loss(loss, model)

self.add_callback(_LossCacheResetCallback(self))

# If evaluator is a list, we wrap it in a SequentialEvaluator
if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
evaluator = SequentialEvaluator(evaluator)
Expand Down Expand Up @@ -443,6 +453,12 @@ def compute_loss(
return loss, {}
return loss

def _loss_on_optimizer_step(self) -> None:
losses = self.loss.values() if isinstance(self.loss, dict) else [self.loss]
for loss_fn in losses:
if hasattr(loss_fn, "on_optimizer_step"):
loss_fn.on_optimizer_step()

def track_loss_components(self, loss: dict[str, torch.Tensor]) -> None:
training_type = "train" if self.model.training else "eval"
for key, value in loss.items():
Expand Down
8 changes: 7 additions & 1 deletion tests/cross_encoder/test_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
import json
import logging
import re
Expand Down Expand Up @@ -267,6 +268,11 @@ def test_push_to_hub(
) -> None:
model = reranker_bert_tiny_model

def build_commit_info(**kwargs):
    # Newer huggingface_hub versions gave CommitInfo a positional `_endpoint` parameter;
    # supply the endpoint explicitly when the installed version expects it.
    accepts_endpoint = "_endpoint" in inspect.signature(CommitInfo).parameters
    if accepts_endpoint:
        return CommitInfo("https://huggingface.co", **kwargs)
    return CommitInfo(**kwargs)

def mock_create_repo(self, repo_id, **kwargs):
return RepoUrl(f"https://huggingface.co/{repo_id}")

Expand All @@ -279,7 +285,7 @@ def mock_upload_folder(self, **kwargs):
revision = "123456"
else:
revision = "678901"
return CommitInfo(
return build_commit_info(
commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/{revision}",
commit_message="commit_message",
commit_description="commit_description",
Expand Down
10 changes: 8 additions & 2 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import copy
import inspect
import json
import logging
import os
Expand Down Expand Up @@ -112,6 +113,11 @@ def test_torch_dtype(torch_dtype) -> None:


def test_push_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
def build_commit_info(**kwargs):
Copy link
Copy Markdown
Member

@tomaarsen tomaarsen Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the inconvenience here, this was caused by this: huggingface/huggingface_hub#3737

It's resolved and released now, so we're good either way.

(I'm going through the paper etc. now)

if "_endpoint" in inspect.signature(CommitInfo).parameters:
return CommitInfo("https://huggingface.co", **kwargs)
return CommitInfo(**kwargs)

def mock_create_repo(self, repo_id, **kwargs):
return RepoUrl(f"https://huggingface.co/{repo_id}")

Expand All @@ -121,14 +127,14 @@ def mock_upload_folder(self, **kwargs):
nonlocal mock_upload_folder_kwargs
mock_upload_folder_kwargs = kwargs
if kwargs.get("revision") is None:
return CommitInfo(
return build_commit_info(
commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/123456",
commit_message="commit_message",
commit_description="commit_description",
oid="oid",
)
else:
return CommitInfo(
return build_commit_info(
commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/678901",
commit_message="commit_message",
commit_description="commit_description",
Expand Down