Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/package_reference/sentence_transformer/losses.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ This allows our network to be fine-tuned to recognize the similarity of sentence
.. autoclass:: sentence_transformers.losses.CachedGISTEmbedLoss
```

## GlobalOrthogonalRegularizationLoss

```{eval-rst}
.. autoclass:: sentence_transformers.losses.GlobalOrthogonalRegularizationLoss
```

## MSELoss

```{eval-rst}
Expand Down
8 changes: 8 additions & 0 deletions docs/sentence_transformer/loss_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ For example, models trained with <a href="../package_reference/sentence_transfor
|-------|--------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `any` | `any` | <a href="../package_reference/sentence_transformer/losses.html#matryoshkaloss">`MatryoshkaLoss`</a><br><a href="../package_reference/sentence_transformer/losses.html#adaptivelayerloss">`AdaptiveLayerLoss`</a><br><a href="../package_reference/sentence_transformer/losses.html#matryoshka2dloss">`Matryoshka2dLoss`</a> |

## Regularization

These losses are designed to regularize the embedding space during training, encouraging certain properties in the learned embeddings. They can often be applied to any dataset configuration.

| Texts | Labels | Appropriate Loss Functions |
|-------|--------|---------------------------------------------------------------------------------------------------------------------------------------------|
| `any` | `none` | <a href="../package_reference/sentence_transformer/losses.html#globalorthogonalregularizationloss">`GlobalOrthogonalRegularizationLoss`</a> |

## Distillation
These loss functions are specifically designed to be used when distilling the knowledge from one model into another.
For example, when finetuning a small model to behave more like a larger & stronger one, or when finetuning a model to become multi-lingual.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
This script trains a sentence transformer using a combination of InfoNCE (via MultipleNegativesRankingLoss)
and Global Orthogonal Regularization (GOR) loss. The combination helps learn embeddings that are both
discriminative and well-distributed in the embedding space.

The script uses the GooAQ dataset (https://huggingface.co/datasets/sentence-transformers/gooaq), which contains
question-answer pairs from Google's "People Also Ask" feature. The model learns to encode questions and answers
such that matching pairs are close in embedding space.

Usage:
python training_gooaq_infonce_gor.py
"""

import logging
import random
import traceback
from collections.abc import Iterable

import torch
from datasets import Dataset, load_dataset
from torch import Tensor

from sentence_transformers import (
SentenceTransformer,
SentenceTransformerModelCardData,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses.GlobalOrthogonalRegularizationLoss import GlobalOrthogonalRegularizationLoss
from sentence_transformers.losses.MultipleNegativesRankingLoss import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.util import cos_sim

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


# Custom joint loss: InfoNCE (MultipleNegativesRankingLoss) plus Global Orthogonal Regularization
class InfoNCEGORLoss(torch.nn.Module):
    """
    Joint training objective combining MultipleNegativesRankingLoss (InfoNCE) with
    Global Orthogonal Regularization Loss.

    The two components pull in complementary directions:

    1. InfoNCE makes positive pairs more similar than in-batch negatives (discriminative embeddings)
    2. GOR spreads embeddings out across the space, counteracting mode collapse
    """

    def __init__(self, model: SentenceTransformer, similarity_fct=cos_sim, scale=20.0) -> None:
        super().__init__()
        self.model = model
        self.info_nce_loss = MultipleNegativesRankingLoss(model, similarity_fct=similarity_fct, scale=scale)
        self.gor_loss = GlobalOrthogonalRegularizationLoss(model, similarity_fct=similarity_fct)

    def forward(
        self,
        sentence_features: Iterable[dict[str, Tensor]],
        labels: Tensor | None = None,
    ) -> dict[str, Tensor]:
        # Encode each input column once; both sub-losses reuse the same embeddings.
        embeddings = []
        for features in sentence_features:
            embeddings.append(self.model(features)["sentence_embedding"])

        # Returning a dict lets the trainer log each component separately.
        losses: dict[str, Tensor] = {"info_nce": self.info_nce_loss.compute_loss_from_embeddings(embeddings, labels)}
        losses.update(self.gor_loss.compute_loss_from_embeddings(embeddings, labels))
        return losses


# Model and training parameters
model_name = "microsoft/mpnet-base"
num_train_samples = 100_000
num_eval_samples = 10_000
train_batch_size = 64
num_epochs = 1

# 1. Load a model to finetune with optional model card data
logging.info(f"Loading model: {model_name}")
model = SentenceTransformer(
    model_name,
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on GooAQ using InfoNCE + Global Orthogonal Regularization",
    ),
)

# 2. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
logging.info("Loading GooAQ dataset")
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(num_train_samples))
# Assign each row an integer id equal to its position in `dataset`; the evaluator
# setup below relies on this id doubling as a row index into `dataset`.
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=num_eval_samples, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]
logging.info(f"Train dataset size: {len(train_dataset)}")
logging.info(f"Eval dataset size: {len(eval_dataset)}")

# 3. Define the loss function
loss = InfoNCEGORLoss(model)

# 4. (Optional) Create an evaluator for use during training
# We create a small corpus for evaluation to measure retrieval performance
logging.info("Creating evaluation corpus")
# NOTE(review): `random` is seeded here but never used below — looks like a leftover; confirm before removing.
random.seed(12)
# Map query id -> question text for the held-out evaluation split
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
# Use only the answers that correspond to the evaluation queries for a focused evaluation
# (indexing `dataset[qid]` is valid because each id equals its row position in `dataset`)
corpus = {qid: dataset[qid]["answer"] for qid in queries}
# Each query's only relevant document is its own paired answer (same id)
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}

dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)

# Evaluate the base model before training
logging.info("Performance before fine-tuning:")
dev_evaluator(model)

# 5. Define the training arguments
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"{short_model_name}-gooaq-infonce-gor"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    # Use NO_DUPLICATES to ensure each batch has unique samples, which benefits MultipleNegativesRankingLoss
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=50,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. Create a trainer & start training
# The "id" column is only needed by the evaluator, so it is dropped before training.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 7. Evaluate the trained model on the test set
logging.info("Evaluating trained model")
dev_evaluator(model)

# 8. Save the trained & evaluated model locally
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)

# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
    model.push_to_hub(run_name)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
148 changes: 148 additions & 0 deletions sentence_transformers/losses/GlobalOrthogonalRegularizationLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Literal

import torch
from torch import Tensor, nn

from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.util import cos_sim


class GlobalOrthogonalRegularizationLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        similarity_fct=cos_sim,
        mean_weight: float = 1.0,
        second_moment_weight: float = 1.0,
        aggregation: Literal["mean", "sum"] = "mean",
    ) -> None:
        """
        Global Orthogonal Regularization (GOR) Loss that encourages embeddings to be well-distributed
        in the embedding space by penalizing high mean similarities and high second moments of similarities
        across unrelated inputs.

        The loss consists of two terms:

        1. Mean term: Penalizes when the mean similarity across unrelated embeddings is high
        2. Second moment term: Penalizes when the second moment of similarities is high

        A high second moment indicates that some embeddings have very high similarities, suggesting clustering
        or concentration in certain regions of the embedding space. A low second moment indicates that
        similarities are more uniformly distributed.

        The loss is called independently on each input column (e.g., queries and passages) and combines the results
        using either mean or sum aggregation. This is why the loss can be used on any dataset configuration
        (e.g., single inputs, pairs, triplets, etc.).

        It's recommended to combine this loss with a primary loss function, such as :class:`MultipleNegativesRankingLoss`.

        Args:
            model: SentenceTransformer model
            similarity_fct: Function to compute similarity between embeddings (default: cosine similarity)
            mean_weight: Weight for the mean term loss component (default: 1.0)
            second_moment_weight: Weight for the second moment term loss component (default: 1.0)
            aggregation: How to combine losses across input columns. Either "mean" or "sum" (default: "mean").
                The EmbeddingGemma paper uses "sum".

        Raises:
            ValueError: If ``aggregation`` is not one of "mean" or "sum".

        References:
            - For further details, see: https://arxiv.org/abs/1708.06320 or https://arxiv.org/abs/2509.20354.
              The latter paper uses the equivalent of GOR with ``mean_weight=0.0`` and ``aggregation="sum"``.

        Inputs:
            +-------+--------+
            | Texts | Labels |
            +=======+========+
            | any   | none   |
            +-------+--------+
        """
        super().__init__()
        self.model = model
        self.similarity_fct = similarity_fct
        self.mean_weight = mean_weight
        self.second_moment_weight = second_moment_weight
        # Validate eagerly so a bad value fails at construction time, not mid-training
        if aggregation not in ["mean", "sum"]:
            raise ValueError(f"aggregation must be 'mean' or 'sum', got '{aggregation}'")
        self.aggregation = aggregation

    def forward(
        self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor | None = None
    ) -> dict[str, Tensor]:
        # Encode each input column, then delegate to the embedding-based computation
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        return self.compute_loss_from_embeddings(embeddings)

    def compute_loss_from_embeddings(
        self, embeddings: list[Tensor], labels: Tensor | None = None
    ) -> dict[str, Tensor]:
        """
        Compute the GOR loss from pre-computed embeddings.

        Args:
            embeddings: List of embedding tensors, one for each input column (e.g., [queries, passages])
            labels: Not used, kept for compatibility

        Returns:
            Dictionary containing the weighted mean term and second moment term losses. A term is
            omitted entirely when its weight is falsy (e.g. 0.0), skipping the unnecessary computation.
        """
        # One (mean_term, second_moment_term) pair per input column
        mean_terms, second_moment_terms = zip(*[self.compute_gor(embedding) for embedding in embeddings])
        results = {}
        if self.mean_weight:
            stacked_mean = torch.stack(mean_terms)
            aggregated_mean = stacked_mean.sum() if self.aggregation == "sum" else stacked_mean.mean()
            results["gor_mean"] = self.mean_weight * aggregated_mean
        if self.second_moment_weight:
            stacked_second_moment = torch.stack(second_moment_terms)
            aggregated_second_moment = (
                stacked_second_moment.sum() if self.aggregation == "sum" else stacked_second_moment.mean()
            )
            results["gor_second_moment"] = self.second_moment_weight * aggregated_second_moment
        return results

    def compute_gor(self, embeddings: Tensor) -> tuple[Tensor, Tensor]:
        """
        Compute the Global Orthogonal Regularization terms for a batch of embeddings.

        The GOR loss encourages embeddings to be well-distributed by:
        1. Mean term (M_1^2): Penalizes high mean similarity, pushing embeddings apart
        2. Second moment term (M_2 - 1/d): Penalizes when the second moment exceeds 1/d, encouraging uniform distribution

        Args:
            embeddings: Tensor of shape (batch_size, embedding_dim)

        Returns:
            Tuple of (mean_term, second_moment_term) losses (unweighted). Both are zero when the
            batch contains fewer than two embeddings, since no off-diagonal pairs exist.
        """
        batch_size = embeddings.size(0)
        hidden_dim = embeddings.size(1)

        # With fewer than 2 embeddings there are no off-diagonal pairs: both terms are zero by
        # definition, and `num_off_diagonal` below would be 0, producing a NaN loss via 0/0.
        if batch_size < 2:
            # Multiply by 0 instead of using a fresh zeros tensor so the result stays
            # connected to the autograd graph (backward() remains safe in combined losses).
            zero = embeddings.sum() * 0.0
            return zero, zero

        # Compute pairwise similarity matrix between all embeddings, and exclude self-similarities
        sim_matrix = self.similarity_fct(embeddings, embeddings)
        sim_matrix.fill_diagonal_(0.0)
        num_off_diagonal = batch_size * (batch_size - 1)

        # Mean term: M_1^2 where M_1 = mean of off-diagonal similarities
        # Penalizes high similarities across inputs from the same column (e.g., queries vs other queries)
        mean_term = (sim_matrix.sum() / num_off_diagonal).pow(2)

        # Second moment term: M_2 - 1/d where M_2 = mean of squared off-diagonal similarities and d is embedding dimension
        # Pushes the second moment close to 1/d, encouraging a more uniform distribution
        second_moment = sim_matrix.pow(2).sum() / num_off_diagonal
        second_moment_term = torch.relu(second_moment - (1.0 / hidden_dim))

        return mean_term, second_moment_term

    @property
    def citation(self) -> str:
        return """
@misc{zhang2017learningspreadoutlocalfeature,
      title={Learning Spread-out Local Feature Descriptors},
      author={Xu Zhang and Felix X. Yu and Sanjiv Kumar and Shih-Fu Chang},
      year={2017},
      eprint={1708.06320},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/1708.06320},
}
"""
2 changes: 2 additions & 0 deletions sentence_transformers/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .DenoisingAutoEncoderLoss import DenoisingAutoEncoderLoss
from .DistillKLDivLoss import DistillKLDivLoss
from .GISTEmbedLoss import GISTEmbedLoss
from .GlobalOrthogonalRegularizationLoss import GlobalOrthogonalRegularizationLoss
from .MarginMSELoss import MarginMSELoss
from .Matryoshka2dLoss import Matryoshka2dLoss
from .MatryoshkaLoss import MatryoshkaLoss
Expand Down Expand Up @@ -64,6 +65,7 @@
"MegaBatchMarginLoss",
"DenoisingAutoEncoderLoss",
"GISTEmbedLoss",
"GlobalOrthogonalRegularizationLoss",
"BatchHardTripletLoss",
"BatchHardTripletLossDistanceFunction",
"BatchHardSoftMarginTripletLoss",
Expand Down
Loading