diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py index 500d9da6f..54e6f20c1 100644 --- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from collections.abc import Callable, Iterable, Iterator from contextlib import nullcontext from functools import partial @@ -15,6 +16,8 @@ from sentence_transformers.SentenceTransformer import SentenceTransformer from sentence_transformers.util import all_gather_with_grad +logger = logging.getLogger(__name__) + class RandContext: """ @@ -80,6 +83,8 @@ def __init__( ] = ("query_to_doc",), partition_mode: Literal["joint", "per_direction"] = "joint", show_progress_bar: bool = False, + hardness_mode: Literal["in_batch_negatives", "hard_negatives", "all_negatives"] | None = None, + hardness_strength: float = 0.0, ) -> None: """ Boosted version of :class:`MultipleNegativesRankingLoss` (https://huggingface.co/papers/1705.00652) by GradCache (https://huggingface.co/papers/2101.06983). @@ -130,6 +135,26 @@ def __init__( - "per_direction": One softmax per direction. A loss is computed for each direction and then averaged. Not compatible with ``"query_to_query"`` or ``"doc_to_doc"`` directions. show_progress_bar: If True, a progress bar for the mini-batches is shown during training. The default is False. + hardness_mode: Strategy for applying hardness weighting. ``None`` (default) disables hardness + weighting entirely. Options: + + - ``"in_batch_negatives"``: Adds ``hardness_strength * stop_grad(cos_sim)`` to every in-batch negative + logit inside the softmax (`Lan et al. 2025 `_, Eq. 5). The + in-batch negatives are all positives and hard negatives from other samples in the batch. + - ``"hard_negatives"``: Applies ``hardness_strength * stop_grad(cos_sim)`` only to the logits of + explicit hard negatives, leaving in-batch negatives unpenalized. Only active when explicit + negatives are provided. As used in + `Lan et al. 2025 `_ (EmbeddingGemma). + - ``"all_negatives"``: Applies ``hardness_strength * stop_grad(cos_sim)`` to every negative logit, + both in-batch negatives and explicit hard negatives, leaving only the positive unpenalized. + Combines the effect of ``"in_batch_negatives"`` and ``"hard_negatives"``. + + hardness_strength: Strength of the hardness weighting. The meaning depends on ``hardness_mode``: + + - For ``"in_batch_negatives"``: acts as ``alpha`` in the hardness penalty, `Lan et al. 2025 `_ uses 9. + - For ``"hard_negatives"``: acts as ``alpha`` in the hardness penalty, `Lan et al. 2025 `_ uses 5. + + Must be non-negative. Ignored when ``hardness_mode`` is ``None``. References: - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://huggingface.co/papers/1705.00652 @@ -216,6 +241,19 @@ def __init__( self.partition_mode = partition_mode self.show_progress_bar = show_progress_bar + valid_hardness_modes = {None, "in_batch_negatives", "hard_negatives", "all_negatives"} + if hardness_mode not in valid_hardness_modes: + raise ValueError(f"hardness_mode must be one of {valid_hardness_modes}, got {hardness_mode!r}") + self.hardness_mode = hardness_mode + if hardness_strength < 0.0: + raise ValueError("hardness_strength must be non-negative.") + self.hardness_strength = hardness_strength + if hardness_mode is not None and hardness_strength == 0.0: + logger.warning( + f"hardness_mode={hardness_mode!r} is set but hardness_strength=0.0, so hardness weighting has no " + "effect. Set hardness_strength to a positive value to enable hardness weighting." + ) + self.cache: list[list[Tensor]] | None = None self.random_states: list[list[RandContext]] | None = None @@ -325,25 +363,61 @@ def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) sim_matrices = {} # (mbs, bs * ws * (1 + nn)) - sim_matrices["query_to_doc"] = self.similarity_fct(local_queries, docs_all) * self.scale + sim_matrices["query_to_doc"] = self.similarity_fct(local_queries, docs_all) if "query_to_query" in self.directions: # (mbs, bs * ws) - sim_matrices["query_to_query"] = self.similarity_fct(local_queries, queries) * self.scale + sim_matrices["query_to_query"] = self.similarity_fct(local_queries, queries) # Remove self-similarity entries q_i -> q_i sim_matrices["query_to_query"][row_indices, local_batch] = -torch.inf if "doc_to_query" in self.directions: # (mbs, bs * ws) - sim_matrices["doc_to_query"] = (self.similarity_fct(queries, local_docs) * self.scale).T + sim_matrices["doc_to_query"] = self.similarity_fct(queries, local_docs).T if "doc_to_doc" in self.directions: # (mbs, bs * ws * (1 + nn)) - sim_matrices["doc_to_doc"] = (self.similarity_fct(docs_all, local_docs) * self.scale).T + sim_matrices["doc_to_doc"] = self.similarity_fct(docs_all, local_docs).T # Remove d_i_a -> d_i_b for all documents belonging to the same query same_query_doc_mask = identity[local_batch].repeat(1, num_docs).bool() sim_matrices["doc_to_doc"].masked_fill_(same_query_doc_mask, -torch.inf) + # Compute hardness penalties on the unscaled (raw cosine) similarities (Lan et al. 2025, Eq. 5). + # penalty = alpha * stop_grad(cos_sim), making harder negatives contribute more to the + # softmax denominator. Computed before temperature scaling so no rescaling is needed. + penalties = {} + if ( + self.hardness_mode in ("in_batch_negatives", "hard_negatives", "all_negatives") + and self.hardness_strength > 0.0 + ): + penalty = self.hardness_strength * sim_matrices["query_to_doc"].detach() + + # True where the document belongs to the same query (own positive + own hard negatives) + own_doc_mask = torch.eye(world_batch_size, device=queries.device, dtype=torch.bool)[local_batch] + own_doc_mask = own_doc_mask.repeat(1, num_docs) + + if self.hardness_mode == "hard_negatives": + # Exclude positives and in-batch negatives, keeping only own hard negatives + penalty_exclusion_mask = ~own_doc_mask + penalty_exclusion_mask[:, :world_batch_size] = True + elif self.hardness_mode == "in_batch_negatives": + # Exclude own positives and hard negatives, keeping only in-batch negatives + penalty_exclusion_mask = own_doc_mask + elif self.hardness_mode == "all_negatives": + # Exclude positives only, keeping both in-batch and hard negatives + penalty_exclusion_mask = own_doc_mask + penalty_exclusion_mask[:, world_batch_size:] = False + + penalty[penalty_exclusion_mask] = 0.0 + penalties["query_to_doc"] = penalty + + # Apply temperature scaling (scale = 1/temperature) and add hardness penalties. + # Final logit = cos_sim * scale + alpha * cos_sim (penalty is not temperature-scaled). + for key in sim_matrices: + sim_matrices[key] = sim_matrices[key] * self.scale + for key, pen in penalties.items(): + sim_matrices[key] = sim_matrices[key] + pen + # Positive scores (always from query_to_doc) positive_scores = sim_matrices["query_to_doc"][row_indices, local_batch] @@ -358,8 +432,9 @@ def calculate_loss(self, reps: list[list[Tensor]], with_backward: bool = False) log_z += torch.logsumexp(sim_matrix, dim=1) log_z /= len(sim_matrices) - loss_mbatch = -(positive_scores - log_z).mean() - loss_mbatch = loss_mbatch * len(local_batch) / batch_size + per_sample_loss = -(positive_scores - log_z) + loss_mbatch = per_sample_loss.mean() * len(local_batch) / batch_size + if with_backward: loss_mbatch.backward() loss_mbatch = loss_mbatch.detach() @@ -408,6 +483,8 @@ def get_config_dict(self) -> dict[str, Any]: "gather_across_devices": self.gather_across_devices, "directions": self.directions, "partition_mode": self.partition_mode, + "hardness_mode": self.hardness_mode, + "hardness_strength": self.hardness_strength, } @property diff --git a/sentence_transformers/losses/MultipleNegativesRankingLoss.py b/sentence_transformers/losses/MultipleNegativesRankingLoss.py index 4d4a1cfa3..dd5f21883 100644 --- a/sentence_transformers/losses/MultipleNegativesRankingLoss.py +++ b/sentence_transformers/losses/MultipleNegativesRankingLoss.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from collections.abc import Callable, Iterable from typing import Any, Literal @@ -10,6 +11,8 @@ from sentence_transformers.SentenceTransformer import SentenceTransformer from sentence_transformers.util import all_gather_with_grad +logger = logging.getLogger(__name__) + class MultipleNegativesRankingLoss(nn.Module): def __init__( @@ -23,6 +26,8 @@ def __init__( ..., ] = ("query_to_doc",), partition_mode: Literal["joint", "per_direction"] = "joint", + hardness_mode: Literal["in_batch_negatives", "hard_negatives", "all_negatives"] | None = None, + hardness_strength: float = 0.0, ) -> None: """ Given a dataset of (anchor, positive) pairs, (anchor, positive, negative) triplets, or (anchor, positive, negative_1, ..., negative_n) @@ -78,6 +83,27 @@ def __init__( - "joint": One joint softmax over all selected directions. - "per_direction": One softmax per direction. A loss is computed for each direction and then averaged. Not compatible with ``"query_to_query"`` or ``"doc_to_doc"`` directions. + hardness_mode: Strategy for applying hardness weighting. ``None`` (default) disables hardness + weighting entirely. Options: + + - ``"in_batch_negatives"``: Adds ``hardness_strength * stop_grad(cos_sim)`` to every in-batch negative + logit inside the softmax (`Lan et al. 2025 `_, Eq. 5). The + in-batch negatives are all positives and hard negatives from other samples in the batch. + Works with all data formats including pairs-only. + - ``"hard_negatives"``: Applies ``hardness_strength * stop_grad(cos_sim)`` only to the logits of + explicit hard negatives, leaving in-batch negatives unpenalized. Only active when explicit + negatives are provided. As used in + `Lan et al. 2025 `_ (EmbeddingGemma). + - ``"all_negatives"``: Applies ``hardness_strength * stop_grad(cos_sim)`` to every negative logit, + both in-batch negatives and explicit hard negatives, leaving only the positive unpenalized. + Combines the effect of ``"in_batch_negatives"`` and ``"hard_negatives"``. + + hardness_strength: Strength of the hardness weighting. The meaning depends on ``hardness_mode``: + + - For ``"in_batch_negatives"``: acts as ``alpha`` in the hardness penalty, `Lan et al. 2025 `_ uses 9. + - For ``"hard_negatives"``: acts as ``alpha`` in the hardness penalty, `Lan et al. 2025 `_ uses 5. + + Must be non-negative. Ignored when ``hardness_mode`` is ``None``. Requirements: 1. (anchor, positive) pairs, (anchor, positive, negative) triplets, or (anchor, positive, negative_1, ..., negative_n) n-tuples @@ -188,6 +214,19 @@ def __init__( ) self.partition_mode = partition_mode + valid_hardness_modes = {None, "in_batch_negatives", "hard_negatives", "all_negatives"} + if hardness_mode not in valid_hardness_modes: + raise ValueError(f"hardness_mode must be one of {valid_hardness_modes}, got {hardness_mode!r}") + self.hardness_mode = hardness_mode + if hardness_strength < 0.0: + raise ValueError("hardness_strength must be non-negative.") + self.hardness_strength = hardness_strength + if hardness_mode is not None and hardness_strength == 0.0: + logger.warning( + f"hardness_mode={hardness_mode!r} is set but hardness_strength=0.0, so hardness weighting has no " + "effect. Set hardness_strength to a positive value to enable hardness weighting." + ) + def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor: # Compute the embeddings and distribute them to anchor and candidates (positive and optionally negatives) embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features] @@ -223,26 +262,62 @@ def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) sim_matrices = {} # (bs, bs * ws * (1 + nn)) - sim_matrices["query_to_doc"] = self.similarity_fct(local_queries, docs_all) * self.scale + sim_matrices["query_to_doc"] = self.similarity_fct(local_queries, docs_all) if "query_to_query" in self.directions: # (bs, bs * ws) - sim_matrices["query_to_query"] = self.similarity_fct(local_queries, queries) * self.scale + sim_matrices["query_to_query"] = self.similarity_fct(local_queries, queries) # Remove self-similarity entries q_i -> q_i sim_matrices["query_to_query"][row_indices, local_indices] = -torch.inf if "doc_to_query" in self.directions: # (bs, bs * ws) - sim_matrices["doc_to_query"] = (self.similarity_fct(queries, local_docs) * self.scale).T + sim_matrices["doc_to_query"] = self.similarity_fct(queries, local_docs).T if "doc_to_doc" in self.directions: # (bs, bs * ws * (1 + nn)) - sim_matrices["doc_to_doc"] = (self.similarity_fct(docs_all, local_docs) * self.scale).T + sim_matrices["doc_to_doc"] = self.similarity_fct(docs_all, local_docs).T # Remove d_i_a -> d_i_b for all documents belonging to the same query same_query_doc_mask = torch.eye(world_batch_size, device=queries.device)[local_indices] same_query_doc_mask = same_query_doc_mask.repeat(1, len(docs)).bool() sim_matrices["doc_to_doc"].masked_fill_(same_query_doc_mask, -torch.inf) + # Compute hardness penalties on the unscaled (raw cosine) similarities (Lan et al. 2025, Eq. 5). + # penalty = alpha * stop_grad(cos_sim), making harder negatives contribute more to the + # softmax denominator. Computed before temperature scaling so no rescaling is needed. + penalties = {} + if ( + self.hardness_mode in ("in_batch_negatives", "hard_negatives", "all_negatives") + and self.hardness_strength > 0.0 + ): + penalty = self.hardness_strength * sim_matrices["query_to_doc"].detach() + + # True where the document belongs to the same query (own positive + own hard negatives) + own_doc_mask = torch.eye(world_batch_size, device=queries.device, dtype=torch.bool)[local_indices] + own_doc_mask = own_doc_mask.repeat(1, len(docs)) + + if self.hardness_mode == "hard_negatives": + # Exclude positives and in-batch negatives, keeping only own hard negatives + penalty_exclusion_mask = ~own_doc_mask + penalty_exclusion_mask[:, :world_batch_size] = True + elif self.hardness_mode == "in_batch_negatives": + # Exclude own positives and hard negatives, keeping only in-batch negatives + penalty_exclusion_mask = own_doc_mask + elif self.hardness_mode == "all_negatives": + # Exclude positives only, keeping both in-batch and hard negatives + penalty_exclusion_mask = own_doc_mask + penalty_exclusion_mask[:, world_batch_size:] = False + + penalty[penalty_exclusion_mask] = 0.0 + penalties["query_to_doc"] = penalty + + # Apply temperature scaling (scale = 1/temperature) and add hardness penalties. + # Final logit = cos_sim * scale + alpha * cos_sim (penalty is not temperature-scaled). + for key in sim_matrices: + sim_matrices[key] = sim_matrices[key] * self.scale + for key, pen in penalties.items(): + sim_matrices[key] = sim_matrices[key] + pen + # Positive scores (always from query_to_doc) positive_scores = sim_matrices["query_to_doc"][row_indices, local_indices] @@ -259,7 +334,6 @@ def compute_loss_from_embeddings(self, embeddings: list[Tensor], labels: Tensor) log_z /= len(sim_matrices) loss = -(positive_scores - log_z).mean() - return loss def get_config_dict(self) -> dict[str, Any]: @@ -269,6 +343,8 @@ def get_config_dict(self) -> dict[str, Any]: "gather_across_devices": self.gather_across_devices, "directions": self.directions, "partition_mode": self.partition_mode, + "hardness_mode": self.hardness_mode, + "hardness_strength": self.hardness_strength, } @property diff --git a/sentence_transformers/sparse_encoder/losses/SparseMultipleNegativesRankingLoss.py b/sentence_transformers/sparse_encoder/losses/SparseMultipleNegativesRankingLoss.py index 91c66ab31..a8c303b69 100644 --- a/sentence_transformers/sparse_encoder/losses/SparseMultipleNegativesRankingLoss.py +++ b/sentence_transformers/sparse_encoder/losses/SparseMultipleNegativesRankingLoss.py @@ -22,6 +22,8 @@ def __init__( ..., ] = ("query_to_doc",), partition_mode: Literal["joint", "per_direction"] = "joint", + hardness_mode: Literal["in_batch_negatives", "hard_negatives", "all_negatives"] | None = None, + hardness_strength: float = 0.0, ) -> None: """ Given a dataset of (anchor, positive) pairs, (anchor, positive, negative) triplets, or (anchor, positive, negative_1, ..., negative_n) @@ -132,6 +134,8 @@ def __init__( gather_across_devices=gather_across_devices, directions=directions, partition_mode=partition_mode, + hardness_mode=hardness_mode, + hardness_strength=hardness_strength, ) def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor: diff --git a/tests/sparse_encoder/test_model_card.py b/tests/sparse_encoder/test_model_card.py index 9852a01e2..dfed68575 100644 --- a/tests/sparse_encoder/test_model_card.py +++ b/tests/sparse_encoder/test_model_card.py @@ -58,7 +58,7 @@ def dummy_dataset(): "| details |
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|", " | anchor 1 | positive 1 | negative 1 |", "* Loss: [SpladeLoss](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:", - ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\')",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', + ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\', hardness_mode=None, hardness_strength=0.0)",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', ], ), ( @@ -68,7 +68,7 @@ def dummy_dataset(): "This is a [SPLADE Sparse Encoder](https://www.sbert.net/docs/sparse_encoder/usage/usage.html) model finetuned from [sparse-encoder-testing/splade-bert-tiny-nq](https://huggingface.co/sparse-encoder-testing/splade-bert-tiny-nq) on the train_0 dataset using the [sentence-transformers](https://www.SBERT.net) library.", "#### train_0", "* Loss: [SpladeLoss](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:", - ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\')",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', + ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\', hardness_mode=None, hardness_strength=0.0)",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', ], ), ( @@ -79,7 +79,7 @@ def dummy_dataset(): "#### train_0", "#### train_1", "* Loss: [SpladeLoss](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:", - ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\')",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', + ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\', hardness_mode=None, hardness_strength=0.0)",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', ], ), ( @@ -92,7 +92,7 @@ def dummy_dataset(): "\n
train_9", "#### train_9", "* Loss: [SpladeLoss](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:", - ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\')",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', + ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\', hardness_mode=None, hardness_strength=0.0)",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', ], ), # We start using "50 datasets" when the ", "-joined dataset name exceed 200 characters @@ -106,7 +106,7 @@ def dummy_dataset(): "
\n
train_49", "#### train_49", "* Loss: [SpladeLoss](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:", - ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\')",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', + ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\', hardness_mode=None, hardness_strength=0.0)",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', ], ), ( @@ -125,7 +125,7 @@ def dummy_dataset(): "| details |
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|", " | anchor 1 | positive 1 | negative 1 |", "* Loss: [SpladeLoss](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:", - ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\')",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', + ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\', hardness_mode=None, hardness_strength=0.0)",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', ], ), ( @@ -146,7 +146,7 @@ def dummy_dataset(): "| details |
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|
  • min: 4 tokens
  • mean: 4.0 tokens
  • max: 4 tokens
|", " | anchor 1 | positive 1 | negative 1 |", "* Loss: [SpladeLoss](https://sbert.net/docs/package_reference/sparse_encoder/losses.html#spladeloss) with these parameters:", - ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\')",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', + ' ```json\n {\n "loss": "SparseMultipleNegativesRankingLoss(scale=1.0, similarity_fct=\'dot_score\', gather_across_devices=False, directions=(\'query_to_doc\',), partition_mode=\'joint\', hardness_mode=None, hardness_strength=0.0)",\n "document_regularizer_weight": 3e-05,\n "query_regularizer_weight": 5e-05\n }\n ```', ], ), ],