Skip to content
89 changes: 83 additions & 6 deletions sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from collections.abc import Callable, Iterable, Iterator
from contextlib import nullcontext
from functools import partial
Expand All @@ -15,6 +16,8 @@
from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.util import all_gather_with_grad

logger = logging.getLogger(__name__)


class RandContext:
"""
Expand Down Expand Up @@ -80,6 +83,8 @@ def __init__(
] = ("query_to_doc",),
partition_mode: Literal["joint", "per_direction"] = "joint",
show_progress_bar: bool = False,
hardness_mode: Literal["in_batch_negatives", "hard_negatives", "all_negatives"] | None = None,
hardness_strength: float = 0.0,
) -> None:
"""
Boosted version of :class:`MultipleNegativesRankingLoss` (https://huggingface.co/papers/1705.00652) by GradCache (https://huggingface.co/papers/2101.06983).
Expand Down Expand Up @@ -130,6 +135,26 @@ def __init__(
- "per_direction": One softmax per direction. A loss is computed for each direction and then averaged.
Not compatible with ``"query_to_query"`` or ``"doc_to_doc"`` directions.
show_progress_bar: If True, a progress bar for the mini-batches is shown during training. The default is False.
hardness_mode: Strategy for applying hardness weighting. ``None`` (default) disables hardness
weighting entirely. Options:

- ``"in_batch_negatives"``: Adds ``hardness_strength * stop_grad(cos_sim)`` to every in-batch negative
logit inside the softmax (`Lan et al. 2025 <https://huggingface.co/papers/2503.04812>`_, Eq. 5). The
in-batch negatives are all positives and hard negatives from other samples in the batch.
- ``"hard_negatives"``: Applies ``hardness_strength * stop_grad(cos_sim)`` only to the logits of
explicit hard negatives, leaving in-batch negatives unpenalized. Only active when explicit
negatives are provided. As used in
`Lan et al. 2025 <https://huggingface.co/papers/2509.20354>`_ (EmbeddingGemma).
- ``"all_negatives"``: Applies ``hardness_strength * stop_grad(cos_sim)`` to every negative logit,
both in-batch negatives and explicit hard negatives, leaving only the positive unpenalized.
Combines the effect of ``"in_batch_negatives"`` and ``"hard_negatives"``.

hardness_strength: Strength of the hardness weighting. The meaning depends on ``hardness_mode``:

- For ``"in_batch_negatives"``: acts as ``alpha`` in the hardness penalty, `Lan et al. 2025 <https://huggingface.co/papers/2503.04812>`_ uses 9.
- For ``"hard_negatives"``: acts as ``alpha`` in the hardness penalty, `Lan et al. 2025 <https://huggingface.co/papers/2509.20354>`_ uses 5.

Must be non-negative. Ignored when ``hardness_mode`` is ``None``.

References:
- Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://huggingface.co/papers/1705.00652
Expand Down Expand Up @@ -216,6 +241,19 @@ def __init__(
self.partition_mode = partition_mode
self.show_progress_bar = show_progress_bar

valid_hardness_modes = {None, "in_batch_negatives", "hard_negatives", "all_negatives"}
if hardness_mode not in valid_hardness_modes:
raise ValueError(f"hardness_mode must be one of {valid_hardness_modes}, got {hardness_mode!r}")
self.hardness_mode = hardness_mode
if hardness_strength < 0.0:
raise ValueError("hardness_strength must be non-negative.")
self.hardness_strength = hardness_strength
if hardness_mode is not None and hardness_strength == 0.0:
logger.warning(
f"hardness_mode={hardness_mode!r} is set but hardness_strength=0.0, so hardness weighting has no "
"effect. Set hardness_strength to a positive value to enable hardness weighting."
)

self.cache: list[list[Tensor]] | None = None
self.random_states: list[list[RandContext]] | None = None

Expand Down Expand Up @@ -325,25 +363,61 @@ def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False)

sim_matrices = {}
# (mbs, bs * ws * (1 + nn))
sim_matrices["query_to_doc"] = self.similarity_fct(local_queries, docs_all) * self.scale
sim_matrices["query_to_doc"] = self.similarity_fct(local_queries, docs_all)

if "query_to_query" in self.directions:
# (mbs, bs * ws)
sim_matrices["query_to_query"] = self.similarity_fct(local_queries, queries) * self.scale
sim_matrices["query_to_query"] = self.similarity_fct(local_queries, queries)
# Remove self-similarity entries q_i -> q_i
sim_matrices["query_to_query"][row_indices, local_batch] = -torch.inf

if "doc_to_query" in self.directions:
# (mbs, bs * ws)
sim_matrices["doc_to_query"] = (self.similarity_fct(queries, local_docs) * self.scale).T
sim_matrices["doc_to_query"] = self.similarity_fct(queries, local_docs).T

if "doc_to_doc" in self.directions:
# (mbs, bs * ws * (1 + nn))
sim_matrices["doc_to_doc"] = (self.similarity_fct(docs_all, local_docs) * self.scale).T
sim_matrices["doc_to_doc"] = self.similarity_fct(docs_all, local_docs).T
# Remove d_i_a -> d_i_b for all documents belonging to the same query
same_query_doc_mask = identity[local_batch].repeat(1, num_docs).bool()
sim_matrices["doc_to_doc"].masked_fill_(same_query_doc_mask, -torch.inf)

# Compute hardness penalties on the unscaled similarities (cosine by default; Lan et al. 2025, Eq. 5).
# penalty = alpha * stop_grad(sim), making harder negatives contribute more to the
# softmax denominator. Computed before temperature scaling so no rescaling is needed.
penalties = {}
if (
self.hardness_mode in ("in_batch_negatives", "hard_negatives", "all_negatives")
and self.hardness_strength > 0.0
):
penalty = self.hardness_strength * sim_matrices["query_to_doc"].detach()

# True where the document belongs to the same query (own positive + own hard negatives)
own_doc_mask = torch.eye(world_batch_size, device=queries.device, dtype=torch.bool)[local_batch]
own_doc_mask = own_doc_mask.repeat(1, num_docs)

if self.hardness_mode == "hard_negatives":
# Exclude positives and in-batch negatives, keeping only own hard negatives
penalty_exclusion_mask = ~own_doc_mask
penalty_exclusion_mask[:, :world_batch_size] = True
elif self.hardness_mode == "in_batch_negatives":
# Exclude own positives and hard negatives, keeping only in-batch negatives
penalty_exclusion_mask = own_doc_mask
elif self.hardness_mode == "all_negatives":
# Exclude positives only, keeping both in-batch and hard negatives
penalty_exclusion_mask = own_doc_mask
penalty_exclusion_mask[:, world_batch_size:] = False

penalty[penalty_exclusion_mask] = 0.0
penalties["query_to_doc"] = penalty

# Apply temperature scaling (scale = 1/temperature) and add hardness penalties.
# Final logit = cos_sim * scale + alpha * cos_sim (penalty is not temperature-scaled).
for key in sim_matrices:
sim_matrices[key] = sim_matrices[key] * self.scale
for key, pen in penalties.items():
sim_matrices[key] = sim_matrices[key] + pen

# Positive scores (always from query_to_doc)
positive_scores = sim_matrices["query_to_doc"][row_indices, local_batch]

Expand All @@ -358,8 +432,9 @@ def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False)
log_z += torch.logsumexp(sim_matrix, dim=1)
log_z /= len(sim_matrices)

loss_mbatch = -(positive_scores - log_z).mean()
loss_mbatch = loss_mbatch * len(local_batch) / batch_size
per_sample_loss = -(positive_scores - log_z)
loss_mbatch = per_sample_loss.mean() * len(local_batch) / batch_size

if with_backward:
loss_mbatch.backward()
loss_mbatch = loss_mbatch.detach()
Expand Down Expand Up @@ -408,6 +483,8 @@ def get_config_dict(self) -> dict[str, Any]:
"gather_across_devices": self.gather_across_devices,
"directions": self.directions,
"partition_mode": self.partition_mode,
"hardness_mode": self.hardness_mode,
"hardness_strength": self.hardness_strength,
}

@property
Expand Down
86 changes: 81 additions & 5 deletions sentence_transformers/losses/MultipleNegativesRankingLoss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from collections.abc import Callable, Iterable
from typing import Any, Literal

Expand All @@ -10,6 +11,8 @@
from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.util import all_gather_with_grad

logger = logging.getLogger(__name__)


class MultipleNegativesRankingLoss(nn.Module):
def __init__(
Expand All @@ -23,6 +26,8 @@ def __init__(
...,
] = ("query_to_doc",),
partition_mode: Literal["joint", "per_direction"] = "joint",
hardness_mode: Literal["in_batch_negatives", "hard_negatives", "all_negatives"] | None = None,
hardness_strength: float = 0.0,
) -> None:
"""
Given a dataset of (anchor, positive) pairs, (anchor, positive, negative) triplets, or (anchor, positive, negative_1, ..., negative_n)
Expand Down Expand Up @@ -78,6 +83,27 @@ def __init__(
- "joint": One joint softmax over all selected directions.
- "per_direction": One softmax per direction. A loss is computed for each direction and then averaged.
Not compatible with ``"query_to_query"`` or ``"doc_to_doc"`` directions.
hardness_mode: Strategy for applying hardness weighting. ``None`` (default) disables hardness
weighting entirely. Options:

- ``"in_batch_negatives"``: Adds ``hardness_strength * stop_grad(cos_sim)`` to every in-batch negative
logit inside the softmax (`Lan et al. 2025 <https://huggingface.co/papers/2503.04812>`_, Eq. 5). The
in-batch negatives are all positives and hard negatives from other samples in the batch.
Works with all data formats including pairs-only.
- ``"hard_negatives"``: Applies ``hardness_strength * stop_grad(cos_sim)`` only to the logits of
explicit hard negatives, leaving in-batch negatives unpenalized. Only active when explicit
negatives are provided. As used in
`Lan et al. 2025 <https://huggingface.co/papers/2509.20354>`_ (EmbeddingGemma).
- ``"all_negatives"``: Applies ``hardness_strength * stop_grad(cos_sim)`` to every negative logit,
both in-batch negatives and explicit hard negatives, leaving only the positive unpenalized.
Combines the effect of ``"in_batch_negatives"`` and ``"hard_negatives"``.

hardness_strength: Strength of the hardness weighting. The meaning depends on ``hardness_mode``:

- For ``"in_batch_negatives"``: acts as ``alpha`` in the hardness penalty, `Lan et al. 2025 <https://huggingface.co/papers/2503.04812>`_ uses 9.
- For ``"hard_negatives"``: acts as ``alpha`` in the hardness penalty, `Lan et al. 2025 <https://huggingface.co/papers/2509.20354>`_ uses 5.

Must be non-negative. Ignored when ``hardness_mode`` is ``None``.

Requirements:
1. (anchor, positive) pairs, (anchor, positive, negative) triplets, or (anchor, positive, negative_1, ..., negative_n) n-tuples
Expand Down Expand Up @@ -188,6 +214,19 @@ def __init__(
)
self.partition_mode = partition_mode

valid_hardness_modes = {None, "in_batch_negatives", "hard_negatives", "all_negatives"}
if hardness_mode not in valid_hardness_modes:
raise ValueError(f"hardness_mode must be one of {valid_hardness_modes}, got {hardness_mode!r}")
self.hardness_mode = hardness_mode
if hardness_strength < 0.0:
raise ValueError("hardness_strength must be non-negative.")
self.hardness_strength = hardness_strength
if hardness_mode is not None and hardness_strength == 0.0:
logger.warning(
f"hardness_mode={hardness_mode!r} is set but hardness_strength=0.0, so hardness weighting has no "
"effect. Set hardness_strength to a positive value to enable hardness weighting."
)

def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
# Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives)
embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
Expand Down Expand Up @@ -223,26 +262,62 @@ def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor)

sim_matrices = {}
# (bs, bs * ws * (1 + nn))
sim_matrices["query_to_doc"] = self.similarity_fct(local_queries, docs_all) * self.scale
sim_matrices["query_to_doc"] = self.similarity_fct(local_queries, docs_all)

if "query_to_query" in self.directions:
# (bs, bs * ws)
sim_matrices["query_to_query"] = self.similarity_fct(local_queries, queries) * self.scale
sim_matrices["query_to_query"] = self.similarity_fct(local_queries, queries)
# Remove self-similarity entries q_i -> q_i
sim_matrices["query_to_query"][row_indices, local_indices] = -torch.inf

if "doc_to_query" in self.directions:
# (bs, bs * ws)
sim_matrices["doc_to_query"] = (self.similarity_fct(queries, local_docs) * self.scale).T
sim_matrices["doc_to_query"] = self.similarity_fct(queries, local_docs).T

if "doc_to_doc" in self.directions:
# (bs, bs * ws * (1 + nn))
sim_matrices["doc_to_doc"] = (self.similarity_fct(docs_all, local_docs) * self.scale).T
sim_matrices["doc_to_doc"] = self.similarity_fct(docs_all, local_docs).T
# Remove d_i_a -> d_i_b for all documents belonging to the same query
same_query_doc_mask = torch.eye(world_batch_size, device=queries.device)[local_indices]
same_query_doc_mask = same_query_doc_mask.repeat(1, len(docs)).bool()
sim_matrices["doc_to_doc"].masked_fill_(same_query_doc_mask, -torch.inf)

# Compute hardness penalties on the unscaled similarities (cosine by default; Lan et al. 2025, Eq. 5).
# penalty = alpha * stop_grad(sim), making harder negatives contribute more to the
# softmax denominator. Computed before temperature scaling so no rescaling is needed.
penalties = {}
if (
self.hardness_mode in ("in_batch_negatives", "hard_negatives", "all_negatives")
and self.hardness_strength > 0.0
):
penalty = self.hardness_strength * sim_matrices["query_to_doc"].detach()

# True where the document belongs to the same query (own positive + own hard negatives)
own_doc_mask = torch.eye(world_batch_size, device=queries.device, dtype=torch.bool)[local_indices]
own_doc_mask = own_doc_mask.repeat(1, len(docs))

if self.hardness_mode == "hard_negatives":
# Exclude positives and in-batch negatives, keeping only own hard negatives
penalty_exclusion_mask = ~own_doc_mask
penalty_exclusion_mask[:, :world_batch_size] = True
elif self.hardness_mode == "in_batch_negatives":
# Exclude own positives and hard negatives, keeping only in-batch negatives
penalty_exclusion_mask = own_doc_mask
elif self.hardness_mode == "all_negatives":
# Exclude positives only, keeping both in-batch and hard negatives
penalty_exclusion_mask = own_doc_mask
penalty_exclusion_mask[:, world_batch_size:] = False

penalty[penalty_exclusion_mask] = 0.0
penalties["query_to_doc"] = penalty

# Apply temperature scaling (scale = 1/temperature) and add hardness penalties.
# Final logit = cos_sim * scale + alpha * cos_sim (penalty is not temperature-scaled).
for key in sim_matrices:
sim_matrices[key] = sim_matrices[key] * self.scale
for key, pen in penalties.items():
sim_matrices[key] = sim_matrices[key] + pen

# Positive scores (always from query_to_doc)
positive_scores = sim_matrices["query_to_doc"][row_indices, local_indices]

Expand All @@ -259,7 +334,6 @@ def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor)
log_z /= len(sim_matrices)

loss = -(positive_scores - log_z).mean()

return loss

def get_config_dict(self) -> dict[str, Any]:
Expand All @@ -269,6 +343,8 @@ def get_config_dict(self) -> dict[str, Any]:
"gather_across_devices": self.gather_across_devices,
"directions": self.directions,
"partition_mode": self.partition_mode,
"hardness_mode": self.hardness_mode,
"hardness_strength": self.hardness_strength,
}

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(
...,
] = ("query_to_doc",),
partition_mode: Literal["joint", "per_direction"] = "joint",
hardness_mode: Literal["in_batch_negatives", "hard_negatives", "all_negatives"] | None = None,
hardness_strength: float = 0.0,
) -> None:
"""
Given a dataset of (anchor, positive) pairs, (anchor, positive, negative) triplets, or (anchor, positive, negative_1, ..., negative_n)
Expand Down Expand Up @@ -132,6 +134,8 @@ def __init__(
gather_across_devices=gather_across_devices,
directions=directions,
partition_mode=partition_mode,
hardness_mode=hardness_mode,
hardness_strength=hardness_strength,
)

def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
Expand Down
Loading