diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
index 3245def41..8919cfeb7 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections import deque
 from collections.abc import Iterable, Iterator
 from contextlib import nullcontext
 from functools import partial
@@ -75,6 +76,9 @@ def __init__(
         mini_batch_size: int = 32,
         gather_across_devices: bool = False,
         show_progress_bar: bool = False,
+        use_cont_accum: bool = False,
+        cache_size: int | None = None,
+        prev_cache: bool = False,
     ) -> None:
         """
         Boosted version of MultipleNegativesRankingLoss (https://huggingface.co/papers/1705.00652) by GradCache (https://huggingface.co/papers/2101.06983).
@@ -110,10 +114,17 @@ def __init__(
                 Recommended when training on multiple GPUs, as it allows for larger batch sizes, but it may slow down
                 training due to communication overhead, and can potentially lead to out-of-memory errors.
             show_progress_bar: If True, a progress bar for the mini-batches is shown during training. The default is False.
+            use_cont_accum: If True, Contrastive Accumulation (ContAccum) is enabled: embeddings from previous
+                steps are cached and reused as additional negatives.
+            cache_size: Maximum number of past batches of embeddings kept per column (each cache is a ``deque``
+                of batch tensors). Set this to at least ``gradient_accumulation_steps`` to emulate the full effective batch.
+            prev_cache: If True, cached embeddings are kept across optimizer steps. Otherwise, the caches are cleared
+                after each optimizer step (this requires Trainer integration or manual calls to ``on_optimizer_step``).
 
         References:
             - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://huggingface.co/papers/1705.00652
             - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://huggingface.co/papers/2101.06983
+            - A Gradient Accumulation Method for Dense Retriever under Memory Constraint: https://arxiv.org/abs/2406.12356
 
         Requirements:
            1. (anchor, positive) pairs or (anchor, positive, negative) triplets
@@ -176,6 +187,14 @@ def __init__(
         self.cross_entropy_loss = nn.CrossEntropyLoss()
         self.cache: list[list[Tensor]] | None = None
         self.random_states: list[list[RandContext]] | None = None
+        self.use_cont_accum = use_cont_accum
+        self.cache_size = 0 if cache_size is None else cache_size
+        self.prev_cache = prev_cache
+        self._query_cache: deque[Tensor] | None = None
+        self._candidate_caches: list[deque[Tensor]] | None = None
+
+        if self.use_cont_accum and self.cache_size <= 0:
+            raise ValueError("cache_size must be a positive integer when use_cont_accum is True.")
 
     def embed_minibatch(
         self,
@@ -229,20 +248,97 @@ def embed_minibatch_iter(
             )
             yield reps, random_state  # reps: (mbsz, hdim)
 
-    def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]]) -> Tensor:
+    def _init_cont_caches(self, num_candidate_columns: int) -> None:
+        if not self.use_cont_accum or self._candidate_caches is not None:
+            return
+        self._candidate_caches = [deque(maxlen=self.cache_size) for _ in range(num_candidate_columns)]
+        self._query_cache = deque(maxlen=self.cache_size)
+
+    def _cached_tensor(self, cache: deque[Tensor] | None, device: torch.device) -> Tensor | None:
+        if cache is None or len(cache) == 0:
+            return None
+        return torch.cat(list(cache), dim=0).to(device)
+
+    def _collect_cont_cache(self, reps: list[list[Tensor]]) -> tuple[Tensor | None, list[Tensor | None] | None]:
+        if not self.use_cont_accum or not self.training:
+            return None, None
+        self._init_cont_caches(num_candidate_columns=len(reps) - 1)
+
+        anchors = torch.cat(reps[0])
+        cached_anchors = self._cached_tensor(self._query_cache, anchors.device)
+
+        cached_candidates: list[Tensor | None] = []
+        assert self._candidate_caches is not None
+        for reps_col, cache in zip(reps[1:], self._candidate_caches):
+            candidates = torch.cat(reps_col)
+            cached_candidates.append(self._cached_tensor(cache, candidates.device))
+
+        return cached_anchors, cached_candidates
+
+    def _enqueue_cont_cache(self, reps: list[list[Tensor]]) -> None:
+        if not self.use_cont_accum or not self.training:
+            return
+        if self._candidate_caches is None:
+            self._init_cont_caches(num_candidate_columns=len(reps) - 1)
+        assert self._query_cache is not None and self._candidate_caches is not None
+
+        anchors = torch.cat(reps[0])
+        self._query_cache.append(anchors.detach().clone())
+
+        for reps_col, cache in zip(reps[1:], self._candidate_caches):
+            candidates = torch.cat(reps_col)
+            cache.append(candidates.detach().clone())
+
+    def reset_cont_cache(self) -> None:
+        if self._query_cache is not None:
+            self._query_cache.clear()
+        if self._candidate_caches is not None:
+            for cache in self._candidate_caches:
+                cache.clear()
+
+    def on_optimizer_step(self) -> None:
+        if self.use_cont_accum and not self.prev_cache:
+            self.reset_cont_cache()
+
+    def calculate_loss_and_cache_gradients(
+        self,
+        reps: list[list[Tensor]],
+        cached_anchors: Tensor | None = None,
+        cached_candidates: list[Tensor | None] | None = None,
+    ) -> Tensor:
         """Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
-        loss = self.calculate_loss(reps, with_backward=True)
+        loss = self.calculate_loss(
+            reps,
+            cached_anchors=cached_anchors,
+            cached_candidates=cached_candidates,
+            with_backward=True,
+        )
         loss = loss.detach().requires_grad_()
 
         self.cache = [[r.grad for r in rs] for rs in reps]  # e.g. 3 * bsz/mbsz * (mbsz, hdim)
 
         return loss
 
-    def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) -> Tensor:
+    def calculate_loss(
+        self,
+        reps: list[list[Tensor]],
+        cached_anchors: Tensor | None = None,
+        cached_candidates: list[Tensor | None] | None = None,
+        with_backward: bool = False,
+    ) -> Tensor:
         """Calculate the cross-entropy loss. No need to cache the gradients."""
         anchors = torch.cat(reps[0])  # (batch_size, embedding_dim)
+        if cached_anchors is not None:
+            anchors = torch.cat([anchors, cached_anchors], dim=0)
+
         candidates = [torch.cat(r) for r in reps[1:]]  # (1 + num_neg) tensors of shape (batch_size, embedding_dim)
-        batch_size = len(anchors)
+        if cached_candidates is not None:
+            for idx, cached in enumerate(cached_candidates):
+                if cached is None:
+                    continue
+                candidates[idx] = torch.cat([candidates[idx], cached], dim=0)
+
+        batch_size = anchors.size(0)
         offset = 0
 
         if self.gather_across_devices:
@@ -300,15 +396,23 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
             reps.append(reps_mbs)
             self.random_states.append(random_state_mbs)
 
+        cached_anchors, cached_candidates = self._collect_cont_cache(reps)
+
         if torch.is_grad_enabled():
             # Step (2): Calculate the loss, backward up to the embeddings and cache the gradients wrt. to the embeddings
-            loss = self.calculate_loss_and_cache_gradients(reps)
+            loss = self.calculate_loss_and_cache_gradients(
+                reps,
+                cached_anchors=cached_anchors,
+                cached_candidates=cached_candidates,
+            )
 
             # Step (3): A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain
             loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self))
         else:
             # If grad is not enabled (e.g. in evaluation), then we don't have to worry about the gradients or backward hook
-            loss = self.calculate_loss(reps)
+            loss = self.calculate_loss(reps, cached_anchors=cached_anchors, cached_candidates=cached_candidates)
+
+        self._enqueue_cont_cache(reps)
 
         return loss
 
@@ -318,6 +422,9 @@ def get_config_dict(self) -> dict[str, Any]:
             "similarity_fct": self.similarity_fct.__name__,
             "mini_batch_size": self.mini_batch_size,
             "gather_across_devices": self.gather_across_devices,
+            "use_cont_accum": self.use_cont_accum,
+            "cache_size": self.cache_size,
+            "prev_cache": self.prev_cache,
         }
 
     @property
diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py
index c90d0fb49..7e45034c2 100644
--- a/sentence_transformers/trainer.py
+++ b/sentence_transformers/trainer.py
@@ -57,6 +57,14 @@
     TrackioCallback = None
 
 
+class _LossCacheResetCallback(TrainerCallback):
+    def __init__(self, trainer: SentenceTransformerTrainer) -> None:
+        self.trainer = trainer
+
+    def on_step_end(self, args, state, control, **kwargs):
+        self.trainer._loss_on_optimizer_step()
+
+
 class SentenceTransformerTrainer(Trainer):
     """
     SentenceTransformerTrainer is a simple but feature-complete training and eval loop for PyTorch
@@ -312,6 +320,8 @@ def __init__(
         else:
             self.loss = self.prepare_loss(loss, model)
 
+        self.add_callback(_LossCacheResetCallback(self))
+
         # If evaluator is a list, we wrap it in a SequentialEvaluator
         if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
             evaluator = SequentialEvaluator(evaluator)
@@ -443,6 +453,12 @@ def compute_loss(
             return loss, {}
         return loss
 
+    def _loss_on_optimizer_step(self) -> None:
+        losses = self.loss.values() if isinstance(self.loss, dict) else [self.loss]
+        for loss_fn in losses:
"on_optimizer_step"): + loss_fn.on_optimizer_step() + def track_loss_components(self, loss: dict[str, torch.Tensor]) -> None: training_type = "train" if self.model.training else "eval" for key, value in loss.items(): diff --git a/tests/cross_encoder/test_cross_encoder.py b/tests/cross_encoder/test_cross_encoder.py index 31d06bab4..e1828df4e 100644 --- a/tests/cross_encoder/test_cross_encoder.py +++ b/tests/cross_encoder/test_cross_encoder.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import json import logging import re @@ -267,6 +268,11 @@ def test_push_to_hub( ) -> None: model = reranker_bert_tiny_model + def build_commit_info(**kwargs): + if "_endpoint" in inspect.signature(CommitInfo).parameters: + return CommitInfo("https://huggingface.co", **kwargs) + return CommitInfo(**kwargs) + def mock_create_repo(self, repo_id, **kwargs): return RepoUrl(f"https://huggingface.co/{repo_id}") @@ -279,7 +285,7 @@ def mock_upload_folder(self, **kwargs): revision = "123456" else: revision = "678901" - return CommitInfo( + return build_commit_info( commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/{revision}", commit_message="commit_message", commit_description="commit_description", diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index de5ce8fcf..bd1a0f467 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -5,6 +5,7 @@ from __future__ import annotations import copy +import inspect import json import logging import os @@ -112,6 +113,11 @@ def test_torch_dtype(torch_dtype) -> None: def test_push_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None: + def build_commit_info(**kwargs): + if "_endpoint" in inspect.signature(CommitInfo).parameters: + return CommitInfo("https://huggingface.co", **kwargs) + return CommitInfo(**kwargs) + def mock_create_repo(self, repo_id, **kwargs): return RepoUrl(f"https://huggingface.co/{repo_id}") @@ -121,14 +127,14 @@ def mock_upload_folder(self, **kwargs): nonlocal mock_upload_folder_kwargs mock_upload_folder_kwargs = kwargs if kwargs.get("revision") is None: - return CommitInfo( + return build_commit_info( commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/123456", commit_message="commit_message", commit_description="commit_description", oid="oid", ) else: - return CommitInfo( + return build_commit_info( commit_url=f"https://huggingface.co/{kwargs.get('repo_id')}/commit/678901", commit_message="commit_message", commit_description="commit_description",