[feat] Add ADRMSELoss #3690

Open

sky-2002 wants to merge 4 commits into huggingface:main from sky-2002:feat/add_adr_mse_loss

Conversation

@sky-2002
Contributor

Add ADRMSELoss: Approx Discounted Rank MSE Listwise Ranking Loss

Hey @tomaarsen , I came across this paper which proposes ADRMSELoss, a listwise learning-to-rank loss for cross-encoders based on the Approx Discounted Rank Mean Squared Error (ADR-MSE) objective:

Rank-DistiLLM: Closing the Effectiveness Gap Between Cross-Encoders and LLMs for Passage Re-Ranking
https://arxiv.org/abs/2405.07920

Motivation

Current ranking losses in sentence_transformers.cross_encoder.losses train the model to produce scores that imply the correct ordering.

ADR-MSE provides an alternative listwise objective that directly optimizes rank positions, making it useful for:

  • passage reranking
  • LLM ranking distillation
  • listwise training with multiple documents per query

Method

Given predicted relevance scores:

s_i = f(q, d_i)

ADR-MSE computes a differentiable approximation of the document rank using pairwise sigmoid comparisons:

approx_rank(d_i) = 1 + sum_{j ≠ i} sigmoid(s_j - s_i)

where sigmoid is a smooth approximation of the indicator function:

1 if s_j > s_i
0 otherwise

The final loss minimizes the squared error between the true rank and the predicted (approximate) rank, with an nDCG-style discount:

L = (1/n) * Σ_i [ (1 / log2(i + 1)) * (i − approx_rank(d_i))² ]
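To make the objective concrete, here is a minimal self-contained PyTorch sketch of the formulas above (my illustration, not this PR's actual implementation), assuming `scores` has shape `(batch, n_docs)` and that documents arrive sorted by true relevance, so the true rank of d_i is i:

```python
import torch


def adr_mse_loss(scores: torch.Tensor) -> torch.Tensor:
    """ADR-MSE over a batch of score lists, one row of document scores per query."""
    n = scores.shape[1]
    # Pairwise differences: diffs[b, i, j] = s_j - s_i
    diffs = scores.unsqueeze(1) - scores.unsqueeze(2)
    pairwise = torch.sigmoid(diffs)
    # Exclude the j == i self-comparison before summing
    eye = torch.eye(n, dtype=torch.bool, device=scores.device)
    pairwise = pairwise.masked_fill(eye.unsqueeze(0), 0.0)
    # approx_rank(d_i) = 1 + sum_{j != i} sigmoid(s_j - s_i), shape (batch, n)
    approx_rank = 1.0 + pairwise.sum(dim=2)
    # True rank of d_i is i (1-based), since documents are sorted by relevance
    true_rank = torch.arange(1, n + 1, dtype=scores.dtype, device=scores.device)
    # nDCG-style discount 1 / log2(i + 1), then mean squared rank error
    discount = 1.0 / torch.log2(true_rank + 1.0)
    return (discount * (true_rank - approx_rank) ** 2).mean()
```

With well-separated scores in the correct order, the approximate ranks land close to 1, 2, ..., n and the loss is near zero; a reversed ordering yields a much larger loss.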

Implementation Highlights

  • Differentiable rank approximation via pairwise sigmoid comparisons
  • nDCG-style log discount weighting
  • Support for variable number of documents per query
  • Compatible with existing cross-encoder training pipelines

Example usage:

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.losses import ADRMSELoss

model = CrossEncoder("microsoft/mpnet-base")
loss = ADRMSELoss(model)

Reference

Rank-DistiLLM: Closing the Effectiveness Gap Between Cross-Encoders and LLMs for Passage Re-Ranking
https://arxiv.org/abs/2405.07920

Tests

I did not see a directory for cross-encoder loss tests. Do we want to add tests for these losses?

@tomaarsen
Member

Hello!

I had the chance to run a script like https://github.com/huggingface/sentence-transformers/blob/main/examples/cross_encoder/training/ms_marco/training_ms_marco_ranknet.py except with ADRMSELoss last night. Sadly, the results weren't great.

So I think there is likely an issue in the current implementation, although it might be smart to rerun the baseline to make sure there are no other changes. Here is the training script I used:

from __future__ import annotations

import logging
import traceback
from datetime import datetime

import torch
from datasets import load_dataset

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from sentence_transformers.cross_encoder.losses import ADRMSELoss
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments


def main():
    model_name = "microsoft/MiniLM-L12-H384-uncased"

    # Set the log level to INFO to get more information
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    # train_batch_size and eval_batch_size inform the size of the batches, while mini_batch_size is used by the loss
    # to subdivide the batch into smaller parts. This mini_batch_size largely informs the training speed and memory usage.
    # Keep in mind that the loss does not process `train_batch_size` pairs, but `train_batch_size * num_docs` pairs.
    train_batch_size = 16
    eval_batch_size = 16
    mini_batch_size = 16
    num_epochs = 1
    max_docs = None

    dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # 1. Define our CrossEncoder model
    # Set the seed so the new classifier weights are identical in subsequent runs
    torch.manual_seed(12)
    model = CrossEncoder(model_name, num_labels=1)
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2. Load the MS MARCO dataset: https://huggingface.co/datasets/microsoft/ms_marco
    logging.info("Read train dataset")
    dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")

    def listwise_mapper(batch, max_docs: int | None = 10):
        processed_queries = []
        processed_docs = []
        processed_labels = []

        for query, passages_info in zip(batch["query"], batch["passages"]):
            # Extract passages and labels
            passages = passages_info["passage_text"]
            labels = passages_info["is_selected"]

            # Pair passages with labels and sort descending by label (positives first)
            paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True)

            # Separate back to passages and labels
            sorted_passages, sorted_labels = zip(*paired) if paired else ((), ())

            # Filter queries without any passages or without any positive labels
            if not sorted_labels or max(sorted_labels) < 1.0:
                continue

            # Truncate to max_docs
            if max_docs is not None:
                sorted_passages = list(sorted_passages[:max_docs])
                sorted_labels = list(sorted_labels[:max_docs])

            processed_queries.append(query)
            processed_docs.append(sorted_passages)
            processed_labels.append(sorted_labels)

        return {
            "query": processed_queries,
            "docs": processed_docs,
            "labels": processed_labels,
        }

    # Create a dataset with a "query" column with strings, a "docs" column with lists of strings,
    # and a "labels" column with lists of floats
    dataset = dataset.map(
        lambda batch: listwise_mapper(batch=batch, max_docs=max_docs),
        batched=True,
        remove_columns=dataset.column_names,
        desc="Processing listwise samples",
    )

    dataset = dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]
    logging.info(train_dataset)

    # 3. Define our training loss
    loss = ADRMSELoss(
        model=model,
        mini_batch_size=mini_batch_size,
    )

    # 4. Define the evaluator. We use CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
    evaluator = CrossEncoderNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=eval_batch_size)
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-msmarco-v1.1-{short_model_name}-adrmse-pr3690"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}_{dt}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        logging_steps=250,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}_{dt}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()

When looking for my baselines to compare training losses against (I have a bunch of models under https://huggingface.co/tomaarsen/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-...), I also stumbled upon this: https://huggingface.co/tomaarsen/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-adrmse

It scores 0.5160 NDCG@10 (+0.0607 compared to BM25), similar to RankNet, and according to the README it was trained with Loss: ApproxDiscountedRankMSE. Sadly, I can't find a PR with that text, nor a local git stash, commit, or archived file. Normally I keep pretty good track of my previous experiments and uncommitted work, but this is a bit of a mystery to me.

  • Tom Aarsen

@sky-2002
Contributor Author

Hey @tomaarsen, thanks for pointing this out. You were right: there was a bug, and I think I found it. The score computation was basically reversed. I fixed it and tested it with a similar script.

# Before (buggy): s_i - s_j → high scores get high (bad) ranks
score_diffs = scores.unsqueeze(2) - scores.unsqueeze(1)

# After (fixed): s_j - s_i → high scores get low (good) ranks
score_diffs = scores.unsqueeze(1) - scores.unsqueeze(2)
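A quick way to sanity-check the fixed direction is to run the approximate-rank computation on a toy example (a standalone snippet of mine, assuming the same `(batch, n_docs)` broadcasting layout as the diff above):

```python
import torch

# Three documents for one query; doc 0 has the highest score
scores = torch.tensor([[3.0, 1.0, 2.0]])

# Fixed direction: score_diffs[b, i, j] = s_j - s_i
score_diffs = scores.unsqueeze(1) - scores.unsqueeze(2)
pairwise = torch.sigmoid(score_diffs)

# Zero out self-comparisons, then approx_rank(d_i) = 1 + sum_{j != i} sigmoid(s_j - s_i)
mask = ~torch.eye(scores.shape[1], dtype=torch.bool).unsqueeze(0)
approx_ranks = 1 + (pairwise * mask).sum(dim=2)

# Highest score now gets the lowest (best) approximate rank
print(approx_ranks)  # roughly [[1.39, 2.61, 2.00]]
```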

I trained a model with the fixed loss on MS MARCO (58K samples, using a similar script) and evaluated it on the NanoBEIR MS MARCO subset, getting an NDCG@10 of 53.8.
