-
Notifications
You must be signed in to change notification settings - Fork 2.8k
[feat] Introduce GlobalOrthogonalRegularizationLoss
#3654
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
87ea682
Introduce GlobalOrthogonalRegularizationLoss
tomaarsen e4f4318
If weight is 0.0 or None, exclude from output results
tomaarsen 0d12661
Add aggregation parameter
tomaarsen 7e319ad
Remove old output_dir
tomaarsen 5ce8bc3
Fix InfoNCEGORLoss typing
tomaarsen 4212037
Update docs slightly
tomaarsen 3bf335b
Add example, get_config_dict, address comments
tomaarsen ff661d5
Add second example
tomaarsen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
175 changes: 175 additions & 0 deletions
175
examples/sentence_transformer/training/other/training_gooaq_infonce_gor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| """ | ||
| This script trains a sentence transformer using a combination of InfoNCE (via MultipleNegativesRankingLoss) | ||
| and Global Orthogonal Regularization (GOR) loss. The combination helps learn embeddings that are both | ||
| discriminative and well-distributed in the embedding space. | ||
|
|
||
| The script uses the GooAQ dataset (https://huggingface.co/datasets/sentence-transformers/gooaq), which contains | ||
| question-answer pairs from Google's "People Also Ask" feature. The model learns to encode questions and answers | ||
| such that matching pairs are close in embedding space. | ||
|
|
||
| Usage: | ||
| python training_gooaq_infonce_gor.py | ||
| """ | ||
|
|
||
| import logging | ||
| import random | ||
| import traceback | ||
| from collections.abc import Iterable | ||
|
|
||
| import torch | ||
| from datasets import Dataset, load_dataset | ||
| from torch import Tensor | ||
|
|
||
| from sentence_transformers import ( | ||
| SentenceTransformer, | ||
| SentenceTransformerModelCardData, | ||
| SentenceTransformerTrainer, | ||
| SentenceTransformerTrainingArguments, | ||
| ) | ||
| from sentence_transformers.evaluation import InformationRetrievalEvaluator | ||
| from sentence_transformers.losses.GlobalOrthogonalRegularizationLoss import GlobalOrthogonalRegularizationLoss | ||
| from sentence_transformers.losses.MultipleNegativesRankingLoss import MultipleNegativesRankingLoss | ||
| from sentence_transformers.training_args import BatchSamplers | ||
| from sentence_transformers.util import cos_sim | ||
|
|
||
# Configure root logging with timestamps so model loading, evaluation, and training
# progress messages below are visible on the console.
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
|
|
||
|
|
||
# Custom joint loss: InfoNCE (via MultipleNegativesRankingLoss) plus Global Orthogonal Regularization.
class InfoNCEGORLoss(torch.nn.Module):
    """
    Joint training objective pairing MultipleNegativesRankingLoss (InfoNCE) with
    Global Orthogonal Regularization (GOR) loss.

    The two components pull in complementary directions:
    1. InfoNCE makes matching pairs score higher than in-batch negatives (discriminative embeddings)
    2. GOR spreads embeddings out across the embedding space, counteracting mode collapse
    """

    def __init__(self, model: SentenceTransformer, similarity_fct=cos_sim, scale=20.0) -> None:
        super().__init__()
        self.model = model
        # Both sub-losses share the same model and similarity function so their scores are comparable.
        self.info_nce_loss = MultipleNegativesRankingLoss(model, similarity_fct=similarity_fct, scale=scale)
        self.gor_loss = GlobalOrthogonalRegularizationLoss(model, similarity_fct=similarity_fct)

    def forward(
        self,
        sentence_features: Iterable[dict[str, Tensor]],
        labels: Tensor | None = None,
    ) -> dict[str, Tensor]:
        # Encode every input column once; both sub-losses reuse the same embeddings.
        embeddings = []
        for features in sentence_features:
            embeddings.append(self.model(features)["sentence_embedding"])
        # Return named loss components: the InfoNCE value under "info_nce", followed by
        # whatever weighted GOR terms the regularizer emits.
        losses: dict[str, Tensor] = {"info_nce": self.info_nce_loss.compute_loss_from_embeddings(embeddings, labels)}
        losses.update(self.gor_loss.compute_loss_from_embeddings(embeddings, labels))
        return losses
|
|
||
|
|
||
# Model and training parameters
model_name = "microsoft/mpnet-base"
num_train_samples = 100_000
num_eval_samples = 10_000
train_batch_size = 64
num_epochs = 1

# 1. Load a model to finetune with optional model card data
logging.info(f"Loading model: {model_name}")
model = SentenceTransformer(
    model_name,
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on GooAQ using InfoNCE + Global Orthogonal Regularization",
    ),
)

# 2. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
# An "id" column (equal to the row index in the truncated dataset) is added BEFORE the
# split so that queries, corpus entries, and relevance judgments can be matched up later.
logging.info("Loading GooAQ dataset")
dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(num_train_samples))
dataset = dataset.add_column("id", range(len(dataset)))
dataset_dict = dataset.train_test_split(test_size=num_eval_samples, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]
logging.info(f"Train dataset size: {len(train_dataset)}")
logging.info(f"Eval dataset size: {len(eval_dataset)}")

# 3. Define the loss function: InfoNCE + Global Orthogonal Regularization (see class above)
loss = InfoNCEGORLoss(model)

# 4. (Optional) Create an evaluator for use during training
# We create a small corpus for evaluation to measure retrieval performance
logging.info("Creating evaluation corpus")
random.seed(12)
queries = dict(zip(eval_dataset["id"], eval_dataset["question"]))
# Use only the answers that correspond to the evaluation queries for a focused evaluation
# NOTE(review): indexing `dataset[qid]` relies on "id" equaling the row index of the
# pre-split dataset (guaranteed by add_column above) — keep that invariant if refactoring.
corpus = {qid: dataset[qid]["answer"] for qid in queries}
# Each query has exactly one relevant document: the answer sharing its id.
relevant_docs = {qid: {qid} for qid in eval_dataset["id"]}

dev_evaluator = InformationRetrievalEvaluator(
    corpus=corpus,
    queries=queries,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gooaq-dev",
)

# Evaluate the base model before training to establish a baseline
logging.info("Performance before fine-tuning:")
dev_evaluator(model)

# 5. Define the training arguments
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"{short_model_name}-gooaq-infonce-gor"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    # Use NO_DUPLICATES to ensure each batch has unique samples, which benefits MultipleNegativesRankingLoss
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=50,
    logging_first_step=True,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. Create a trainer & start training
# The "id" column is only needed for the evaluator; drop it so the trainer sees
# just the (question, answer) text columns.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.remove_columns("id"),
    eval_dataset=eval_dataset.remove_columns("id"),
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 7. Evaluate the trained model on the test set
logging.info("Evaluating trained model")
dev_evaluator(model)

# 8. Save the trained & evaluated model locally
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)

# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
    model.push_to_hub(run_name)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
148 changes: 148 additions & 0 deletions
148
sentence_transformers/losses/GlobalOrthogonalRegularizationLoss.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Iterable | ||
| from typing import Literal | ||
|
|
||
| import torch | ||
| from torch import Tensor, nn | ||
|
|
||
| from sentence_transformers.SentenceTransformer import SentenceTransformer | ||
| from sentence_transformers.util import cos_sim | ||
|
|
||
|
|
||
class GlobalOrthogonalRegularizationLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        similarity_fct=cos_sim,
        mean_weight: float = 1.0,
        second_moment_weight: float = 1.0,
        aggregation: Literal["mean", "sum"] = "mean",
    ) -> None:
        """
        Global Orthogonal Regularization (GOR) Loss that encourages embeddings to be well-distributed
        in the embedding space by penalizing high mean similarities and high second moments of similarities
        across unrelated inputs.

        The loss consists of two terms:

        1. Mean term: Penalizes when the mean similarity across unrelated embeddings is high
        2. Second moment term: Penalizes when the second moment of similarities is high

        A high second moment indicates that some embeddings have very high similarities, suggesting clustering
        or concentration in certain regions of the embedding space. A low second moment indicates that
        similarities are more uniformly distributed.

        The loss is called independently on each input column (e.g., queries and passages) and combines the results
        using either mean or sum aggregation. This is why the loss can be used on any dataset configuration
        (e.g., single inputs, pairs, triplets, etc.).

        It's recommended to combine this loss with a primary loss function, such as :class:`MultipleNegativesRankingLoss`.

        Args:
            model: SentenceTransformer model
            similarity_fct: Function to compute similarity between embeddings (default: cosine similarity)
            mean_weight: Weight for the mean term loss component (default: 1.0). A falsy weight
                (0.0 or None) disables the term and excludes it from the output dictionary.
            second_moment_weight: Weight for the second moment term loss component (default: 1.0).
                A falsy weight (0.0 or None) disables the term and excludes it from the output dictionary.
            aggregation: How to combine losses across input columns. Either "mean" or "sum" (default: "mean").
                The EmbeddingGemma paper uses "sum".

        Raises:
            ValueError: If ``aggregation`` is not "mean" or "sum".

        References:
            - For further details, see: https://arxiv.org/abs/1708.06320 or https://arxiv.org/abs/2509.20354.
              The latter paper uses the equivalent of GOR with ``mean_weight=0.0`` and ``aggregation="sum"``.

        Inputs:
            +-------+--------+
            | Texts | Labels |
            +=======+========+
            | any   | none   |
            +-------+--------+
        """
        super().__init__()
        self.model = model
        self.similarity_fct = similarity_fct
        self.mean_weight = mean_weight
        self.second_moment_weight = second_moment_weight
        if aggregation not in ["mean", "sum"]:
            raise ValueError(f"aggregation must be 'mean' or 'sum', got '{aggregation}'")
        self.aggregation = aggregation

    def forward(
        self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor | None = None
    ) -> dict[str, Tensor]:
        # Encode each input column into sentence embeddings, then delegate to the
        # embedding-based computation (which is also callable directly by wrapper losses).
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        return self.compute_loss_from_embeddings(embeddings)

    def compute_loss_from_embeddings(
        self, embeddings: list[Tensor], labels: Tensor | None = None
    ) -> dict[str, Tensor]:
        """
        Compute the GOR loss from pre-computed embeddings.

        Args:
            embeddings: List of embedding tensors, one for each input column (e.g., [queries, passages])
            labels: Not used, kept for compatibility

        Returns:
            Dictionary containing the weighted mean term ("gor_mean") and second moment term
            ("gor_second_moment") losses. Terms whose weight is falsy (0.0 or None) are omitted.
        """
        # One (mean_term, second_moment_term) pair per input column.
        mean_terms, second_moment_terms = zip(*[self.compute_gor(embedding) for embedding in embeddings])
        results: dict[str, Tensor] = {}
        if self.mean_weight:
            stacked_mean = torch.stack(mean_terms)
            aggregated_mean = stacked_mean.sum() if self.aggregation == "sum" else stacked_mean.mean()
            results["gor_mean"] = self.mean_weight * aggregated_mean
        if self.second_moment_weight:
            stacked_second_moment = torch.stack(second_moment_terms)
            aggregated_second_moment = (
                stacked_second_moment.sum() if self.aggregation == "sum" else stacked_second_moment.mean()
            )
            results["gor_second_moment"] = self.second_moment_weight * aggregated_second_moment
        return results

    def compute_gor(self, embeddings: Tensor) -> tuple[Tensor, Tensor]:
        """
        Compute the Global Orthogonal Regularization terms for a batch of embeddings.

        The GOR loss encourages embeddings to be well-distributed by:
        1. Mean term (M_1^2): Penalizes high mean similarity, pushing embeddings apart
        2. Second moment term (M_2 - 1/d): Penalizes when the second moment exceeds 1/d, encouraging uniform distribution

        Args:
            embeddings: Tensor of shape (batch_size, embedding_dim)

        Returns:
            Tuple of (mean_term, second_moment_term) losses (unweighted)
        """
        batch_size = embeddings.size(0)
        hidden_dim = embeddings.size(1)

        # With fewer than two embeddings there are no off-diagonal pairs to regularize;
        # return zeros instead of dividing by zero below (which would yield NaN and
        # poison the aggregated training loss).
        if batch_size < 2:
            zero = embeddings.sum() * 0.0  # stays on the right device/dtype and in the autograd graph
            return zero, zero

        # Compute pairwise similarity matrix between all embeddings, and exclude self-similarities
        sim_matrix = self.similarity_fct(embeddings, embeddings)
        sim_matrix.fill_diagonal_(0.0)
        num_off_diagonal = batch_size * (batch_size - 1)

        # Mean term: M_1^2 where M_1 = mean of off-diagonal similarities
        # Penalizes high similarities across inputs from the same column (e.g., queries vs other queries)
        mean_term = (sim_matrix.sum() / num_off_diagonal).pow(2)

        # Second moment term: M_2 - 1/d where M_2 = mean of squared off-diagonal similarities and d is embedding dimension
        # Pushes the second moment close to 1/d, encouraging a more uniform distribution
        second_moment = sim_matrix.pow(2).sum() / num_off_diagonal
        second_moment_term = torch.relu(second_moment - (1.0 / hidden_dim))

        return mean_term, second_moment_term

    def get_config_dict(self) -> dict[str, object]:
        """
        Return the loss configuration, following the sentence-transformers convention of
        exposing constructor hyperparameters for model cards and saved configurations.
        """
        return {
            "similarity_fct": self.similarity_fct.__name__,
            "mean_weight": self.mean_weight,
            "second_moment_weight": self.second_moment_weight,
            "aggregation": self.aggregation,
        }

    @property
    def citation(self) -> str:
        return """
@misc{zhang2017learningspreadoutlocalfeature,
    title={Learning Spread-out Local Feature Descriptors},
    author={Xu Zhang and Felix X. Yu and Sanjiv Kumar and Shih-Fu Chang},
    year={2017},
    eprint={1708.06320},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/1708.06320},
}
"""
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.