diff --git a/examples/sentence_transformer/training/ms_marco/train_bi-encoder_margin-mse.py b/examples/sentence_transformer/training/ms_marco/train_bi-encoder_margin-mse.py
index c2b9f219b..d41729ce8 100644
--- a/examples/sentence_transformer/training/ms_marco/train_bi-encoder_margin-mse.py
+++ b/examples/sentence_transformer/training/ms_marco/train_bi-encoder_margin-mse.py
@@ -1,254 +1,158 @@
-import argparse
-import gzip
-import json
 import logging
-import os
-import pickle
 import random
-import sys
-import tarfile
-from datetime import datetime
-from shutil import copyfile

 import tqdm
-from torch.utils.data import DataLoader, Dataset
+from datasets import Dataset, load_dataset

-from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer, losses, models, util
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.losses import MarginMSELoss
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments

-#### Just some code to print debug information to stdout
-logging.basicConfig(
-    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
-)
-#### /print debug information to stdout
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--train_batch_size", default=64, type=int)
-parser.add_argument("--max_seq_length", default=300, type=int)
-parser.add_argument("--model_name", required=True)
-parser.add_argument("--max_passages", default=0, type=int)
-parser.add_argument("--epochs", default=30, type=int)
-parser.add_argument("--pooling", default="mean")
-parser.add_argument(
-    "--negs_to_use",
-    default=None,
-    help="From which systems should negatives be used? Multiple systems separated by comma. None = all",
-)
-parser.add_argument("--warmup_steps", default=1000, type=int)
-parser.add_argument("--lr", default=2e-5, type=float)
-parser.add_argument("--num_negs_per_system", default=5, type=int)
-parser.add_argument("--use_pre_trained_model", default=False, action="store_true")
-parser.add_argument("--use_all_queries", default=False, action="store_true")
-args = parser.parse_args()
-
-logging.info(str(args))
+# Just some code to print debug information to stdout
+logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
+logging.getLogger("httpx").setLevel(logging.WARNING)  # Quiet httpx logs

-# The model we want to fine-tune
-train_batch_size = (
-    args.train_batch_size
-)  # Increasing the train batch size improves the model performance, but requires more GPU memory
-model_name = args.model_name
-max_passages = args.max_passages
-max_seq_length = args.max_seq_length  # Max length for passages. Increasing it, requires more GPU memory
+train_batch_size = 64
+max_seq_length = 300  # Max length for passages. Increasing it requires more GPU memory
+model_name = "microsoft/mpnet-base"
+num_epochs = 1
+max_steps = -1
+lr = 2e-5

-num_negs_per_system = (
-    args.num_negs_per_system
-)  # We used different systems to mine hard negatives. Number of hard negatives to add from each system
-num_epochs = args.epochs  # Number of epochs we want to train
+# We used different systems to mine hard negatives. Number of hard negatives to add from each system
+num_negs_per_system = 5
+num_negatives = 5  # Number of negatives to sample per query for training

 # Load our embedding model
-if args.use_pre_trained_model:
-    logging.info("use pretrained SBERT model")
-    model = SentenceTransformer(model_name)
-    model.max_seq_length = max_seq_length
-else:
-    logging.info("Create new SBERT model")
-    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
-    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling)
-    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
-
-model_save_path = f"output/train_bi-encoder-margin_mse-{model_name.replace('/', '-')}-batch_size_{train_batch_size}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
-
-
-# Write self to path
-os.makedirs(model_save_path, exist_ok=True)
-
-train_script_path = os.path.join(model_save_path, "train_script.py")
-copyfile(__file__, train_script_path)
-with open(train_script_path, "a") as fOut:
-    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
-
-
-### Now we read the MS Marco dataset
-data_folder = "msmarco-data"
-
-#### Read the corpus files, that contain all the passages. Store them in the corpus dict
-corpus = {}  # dict in the format: passage_id -> passage. Stores all existent passages
-collection_filepath = os.path.join(data_folder, "collection.tsv")
-if not os.path.exists(collection_filepath):
-    tar_filepath = os.path.join(data_folder, "collection.tar.gz")
-    if not os.path.exists(tar_filepath):
-        logging.info("Download collection.tar.gz")
-        util.http_get("https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz", tar_filepath)
-
-    with tarfile.open(tar_filepath, "r:gz") as tar:
-        tar.extractall(path=data_folder)
-
-logging.info("Read corpus: collection.tsv")
-with open(collection_filepath, encoding="utf8") as fIn:
-    for line in fIn:
-        pid, passage = line.strip().split("\t")
-        pid = int(pid)
-        corpus[pid] = passage
-
-
-### Read the train queries, store in queries dict
-queries = {}  # dict in the format: query_id -> query. Stores all training queries
-queries_filepath = os.path.join(data_folder, "queries.train.tsv")
-if not os.path.exists(queries_filepath):
-    tar_filepath = os.path.join(data_folder, "queries.tar.gz")
-    if not os.path.exists(tar_filepath):
-        logging.info("Download queries.tar.gz")
-        util.http_get("https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz", tar_filepath)
-
-    with tarfile.open(tar_filepath, "r:gz") as tar:
-        tar.extractall(path=data_folder)
-
-
-with open(queries_filepath, encoding="utf8") as fIn:
-    for line in fIn:
-        qid, query = line.strip().split("\t")
-        qid = int(qid)
-        queries[qid] = query
-
-
-# Load a dict (qid, pid) -> ce_score that maps query-ids (qid) and paragraph-ids (pid)
-# to the CrossEncoder score computed by the cross-encoder/ms-marco-MiniLM-L6-v2 model
-ce_scores_file = os.path.join(data_folder, "cross-encoder-ms-marco-MiniLM-L6-v2-scores.pkl.gz")
-if not os.path.exists(ce_scores_file):
-    logging.info("Download cross-encoder scores file")
-    util.http_get(
-        "https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz",
-        ce_scores_file,
-    )
-
+logging.info("Using pretrained SBERT model")
+model = SentenceTransformer(model_name)
+model.max_seq_length = max_seq_length
+
+# Map PID -> text
+corpus = load_dataset("sentence-transformers/msmarco-corpus", "passage", split="train")
+corpus_dict = dict(zip(corpus["pid"], corpus["text"]))
+
+# Map QID -> query text
+queries = load_dataset("sentence-transformers/msmarco-corpus", "query", split="train")
+query_dict = dict(zip(queries["qid"], queries["text"]))
+
+# Map QID -> {PID: CE score}
+scores = load_dataset("sentence-transformers/msmarco-scores-ms-marco-MiniLM-L6-v2", "list", split="train")
+ce_scores = {
+    qid: dict(zip(cids, sc)) for qid, cids, sc in zip(scores["query_id"], scores["corpus_id"], scores["score"])
+}
 logging.info("Load CrossEncoder scores dict")
-with gzip.open(ce_scores_file, "rb") as fIn:
-    ce_scores = pickle.load(fIn)
-
-# As training data we use hard-negatives that have been mined using various systems
-hard_negatives_filepath = os.path.join(data_folder, "msmarco-hard-negatives.jsonl.gz")
-if not os.path.exists(hard_negatives_filepath):
-    logging.info("Download cross-encoder scores file")
-    util.http_get(
-        "https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/msmarco-hard-negatives.jsonl.gz",
-        hard_negatives_filepath,
-    )
-
-logging.info("Read hard negatives train file")
-train_queries = {}
-negs_to_use = None
-with gzip.open(hard_negatives_filepath, "rt") as fIn:
-    for line in tqdm.tqdm(fIn):
-        if max_passages > 0 and len(train_queries) >= max_passages:
-            break
-        data = json.loads(line)
-
-        # Get the positive passage ids
-        pos_pids = data["pos"]
-
-        # Get the hard negatives
-        neg_pids = set()
-        if negs_to_use is None:
-            if args.negs_to_use is not None:  # Use specific system for negatives
-                negs_to_use = args.negs_to_use.split(",")
-            else:  # Use all systems
-                negs_to_use = list(data["neg"].keys())
-            logging.info("Using negatives from the following systems: {}".format(", ".join(negs_to_use)))
-
-        for system_name in negs_to_use:
-            if system_name not in data["neg"]:
+# Datasets with 50 hard negatives mined per query using different models
+SYSTEMS = {
+    "bm25": "sentence-transformers/msmarco-bm25",
+    "msmarco-distilbert-base-tas-b": "sentence-transformers/msmarco-msmarco-distilbert-base-tas-b",
+    "msmarco-distilbert-base-v3": "sentence-transformers/msmarco-msmarco-distilbert-base-v3",
+    "msmarco-MiniLM-L-6-v3": "sentence-transformers/msmarco-msmarco-MiniLM-L6-v3",
+    "distilbert-margin_mse-cls-dot-v2": "sentence-transformers/msmarco-distilbert-margin-mse-cls-dot-v2",
+    "distilbert-margin_mse-cls-dot-v1": "sentence-transformers/msmarco-distilbert-margin-mse-cls-dot-v1",
+    "distilbert-margin_mse-mean-dot-v1": "sentence-transformers/msmarco-distilbert-margin-mse-mean-dot-v1",
+    "mpnet-margin_mse-mean-v1": "sentence-transformers/msmarco-mpnet-margin-mse-mean-v1",
+    "co-condenser-margin_mse-cls-v1": "sentence-transformers/msmarco-co-condenser-margin-mse-cls-v1",
+    "distilbert-margin_mse-mnrl-mean-v1": "sentence-transformers/msmarco-distilbert-margin-mse-mnrl-mean-v1",
+    "distilbert-margin_mse-sym_mnrl-mean-v1": "sentence-transformers/msmarco-distilbert-margin-mse-sym-mnrl-mean-v1",
+    "distilbert-margin_mse-sym_mnrl-mean-v2": "sentence-transformers/msmarco-distilbert-margin-mse-sym-mnrl-mean-v2",
+    "co-condenser-margin_mse-sym_mnrl-mean-v1": "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
+}
+
+train_data = {}
+for system_key, repo_id in SYSTEMS.items():
+    print(f"Loading {system_key}...")
+    dataset = load_dataset(repo_id, "triplet-50-ids", split="train")
+
+    for row in tqdm.tqdm(dataset, desc=f"Processing {system_key}"):
+        qid = row.pop("query")
+        pos_pid = row.pop("positive")
+        neg_pids = list(row.values())  # All remaining columns are negatives
+        existing_neg_pids = set(train_data[qid]["neg_pids"]) if qid in train_data else set()
+        pos_ce_score = ce_scores[qid][pos_pid]
+        valid_neg_pids = []
+        valid_neg_labels = []
+
+        for neg_pid in neg_pids:
+            if neg_pid in existing_neg_pids or neg_pid not in ce_scores[qid]:
                 continue
-            system_negs = data["neg"][system_name]
-            negs_added = 0
-            for pid in system_negs:
-                if pid not in neg_pids:
-                    neg_pids.add(pid)
-                    negs_added += 1
-                    if negs_added >= num_negs_per_system:
-                        break
-
-        if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):
-            train_queries[data["qid"]] = {
-                "qid": data["qid"],
-                "query": queries[data["qid"]],
-                "pos": pos_pids,
-                "neg": neg_pids,
-            }
-
-logging.info(f"Train queries: {len(train_queries)}")
-
-
-# We create a custom MSMARCO dataset that returns triplets (query, positive, negative)
-# on-the-fly based on the information from the mined-hard-negatives jsonl file.
-class MSMARCODataset(Dataset):
-    def __init__(self, queries, corpus, ce_scores):
-        self.queries = queries
-        self.queries_ids = list(queries.keys())
-        self.corpus = corpus
-        self.ce_scores = ce_scores
-
-        for qid in self.queries:
-            self.queries[qid]["pos"] = list(self.queries[qid]["pos"])
-            self.queries[qid]["neg"] = list(self.queries[qid]["neg"])
-            random.shuffle(self.queries[qid]["neg"])
-
-    def __getitem__(self, item):
-        query = self.queries[self.queries_ids[item]]
-        query_text = query["query"]
-        qid = query["qid"]
-
-        if len(query["pos"]) > 0:
-            pos_id = query["pos"].pop(0)  # Pop positive and add at end
-            pos_text = self.corpus[pos_id]
-            query["pos"].append(pos_id)
-        else:  # We only have negatives, use two negs
-            pos_id = query["neg"].pop(0)  # Pop negative and add at end
-            pos_text = self.corpus[pos_id]
-            query["neg"].append(pos_id)
-
-        # Get a negative passage
-        neg_id = query["neg"].pop(0)  # Pop negative and add at end
-        neg_text = self.corpus[neg_id]
-        query["neg"].append(neg_id)
-
-        pos_score = self.ce_scores[qid][pos_id]
-        neg_score = self.ce_scores[qid][neg_id]
-
-        return InputExample(texts=[query_text, pos_text, neg_text], label=pos_score - neg_score)
-
-    def __len__(self):
-        return len(self.queries)
-
-
-# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
-train_dataset = MSMARCODataset(queries=train_queries, corpus=corpus, ce_scores=ce_scores)
-train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, drop_last=True)
-train_loss = losses.MarginMSELoss(model=model)
+            valid_neg_pids.append(neg_pid)
+            valid_neg_labels.append(pos_ce_score - ce_scores[qid][neg_pid])
+            existing_neg_pids.add(neg_pid)
+            if len(valid_neg_pids) >= num_negs_per_system:
+                break
+        if qid not in train_data:
+            train_data[qid] = {"qid": qid, "pid": pos_pid, "neg_pids": valid_neg_pids, "neg_labels": valid_neg_labels}
+        else:
+            train_data[qid]["neg_pids"].extend(valid_neg_pids)
+            train_data[qid]["neg_labels"].extend(valid_neg_labels)
+
+# We sample num_negatives negatives per query in the transform below, so only keep queries with enough mined negatives
+train_data = {qid: data for qid, data in train_data.items() if len(data["neg_pids"]) >= num_negatives}
+logging.info(f"Kept {len(train_data)} queries with >= {num_negatives} negatives")
+
+train_dataset = Dataset.from_list(list(train_data.values()))
+
+
+def ids_to_text_transform(batch):
+    sampled = [
+        random.sample(list(zip(neg_pids, neg_labels)), num_negatives)
+        for neg_pids, neg_labels in zip(batch["neg_pids"], batch["neg_labels"])
+    ]
+    neg_pid_lists, label_lists = zip(*[zip(*s) for s in sampled])
+    return {
+        "anchor": [query_dict[qid] for qid in batch["qid"]],
+        "positive": [corpus_dict[pid] for pid in batch["pid"]],
+        **{
+            f"negative_{idx}": [corpus_dict[pid] for pid in neg_ids] for idx, neg_ids in enumerate(zip(*neg_pid_lists))
+        },
+        "label": list(label_lists),
+    }
+
+
+train_dataset.set_transform(ids_to_text_transform)
+
+# Loss function
+loss = MarginMSELoss(model)
+
+# Prepare training arguments
+short_model_name = model_name.split("/")[-1] if "/" in model_name else model_name
+run_name = f"{short_model_name}-msmarco-margin-mse"
+args = SentenceTransformerTrainingArguments(
+    output_dir=f"output/{run_name}",
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=train_batch_size,
+    warmup_ratio=0.1,
+    learning_rate=lr,
+    max_steps=max_steps,
+    save_strategy="steps",
+    save_steps=0.1,
+    logging_steps=0.01,
+    batch_sampler=BatchSamplers.NO_DUPLICATES,
+)

 # Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    epochs=num_epochs,
-    warmup_steps=args.warmup_steps,
-    use_amp=True,
-    checkpoint_path=model_save_path,
-    checkpoint_save_steps=10000,
-    optimizer_params={"lr": args.lr},
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    loss=loss,
 )
-
-# Train latest model
-model.save(model_save_path)
+trainer.train()
+
+final_model_path = f"output/{run_name}/final"
+model.save_pretrained(final_model_path)
+
+# (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+try:
+    model.push_to_hub(f"{run_name}")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\nTo upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_model_path!r})` "
+        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
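
As a quick sanity check of the new margin-MSE data pipeline, here is a minimal, self-contained sketch of the same `ids_to_text_transform` idea on toy data. All IDs, texts, and margin labels below are made up for illustration; only the transform logic mirrors the script above:

import random

from datasets import Dataset

num_negatives = 2  # Sample two negatives per query in this toy example

corpus_dict = {10: "passage ten", 11: "passage eleven", 12: "passage twelve", 13: "passage thirteen"}
query_dict = {1: "toy query"}

# One training row: a query, its positive, and mined negatives with precomputed
# margins CE(query, positive) - CE(query, negative)
rows = [{"qid": 1, "pid": 10, "neg_pids": [11, 12, 13], "neg_labels": [4.2, 5.0, 6.1]}]
toy_dataset = Dataset.from_list(rows)


def ids_to_text_transform(batch):
    # Sample num_negatives (pid, margin) pairs per query, then map all IDs to their texts
    sampled = [
        random.sample(list(zip(neg_pids, neg_labels)), num_negatives)
        for neg_pids, neg_labels in zip(batch["neg_pids"], batch["neg_labels"])
    ]
    neg_pid_lists, label_lists = zip(*[zip(*s) for s in sampled])
    return {
        "anchor": [query_dict[qid] for qid in batch["qid"]],
        "positive": [corpus_dict[pid] for pid in batch["pid"]],
        **{f"negative_{i}": [corpus_dict[pid] for pid in pids] for i, pids in enumerate(zip(*neg_pid_lists))},
        "label": list(label_lists),
    }


toy_dataset.set_transform(ids_to_text_transform)
print(toy_dataset[0])
# e.g. {'anchor': 'toy query', 'positive': 'passage ten',
#       'negative_0': 'passage twelve', 'negative_1': 'passage eleven', 'label': (5.0, 4.2)}

Because `set_transform` runs on every access, the negatives are re-sampled each time a row is fetched, so repeated epochs see different subsets of the mined negatives.
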
None = all", -) -parser.add_argument("--warmup_steps", default=1000, type=int) -parser.add_argument("--lr", default=2e-5, type=float) -parser.add_argument("--num_negs_per_system", default=5, type=int) -parser.add_argument("--use_pre_trained_model", default=False, action="store_true") -parser.add_argument("--use_all_queries", default=False, action="store_true") -parser.add_argument("--ce_score_margin", default=3.0, type=float) -args = parser.parse_args() - -print(args) - -# The model we want to fine-tune -model_name = args.model_name - -train_batch_size = ( - args.train_batch_size -) # Increasing the train batch size improves the model performance, but requires more GPU memory -max_seq_length = args.max_seq_length # Max length for passages. Increasing it, requires more GPU memory -ce_score_margin = args.ce_score_margin # Margin for the CrossEncoder score between negative and positive passages -num_negs_per_system = ( - args.num_negs_per_system -) # We used different systems to mine hard negatives. Number of hard negatives to add from each system -num_epochs = args.epochs # Number of epochs we want to train +# Just some code to print debug information to stdout +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) +logging.getLogger("httpx").setLevel(logging.WARNING) # Quiet httpx logs -# Load our embedding model -if args.use_pre_trained_model: - logging.info("use pretrained SBERT model") - model = SentenceTransformer(model_name) - model.max_seq_length = max_seq_length -else: - logging.info("Create new SBERT model") - word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) - pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling) - model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -model_save_path = "output/train_bi-encoder-mnrl-{}-margin_{:.1f}-{}".format( - model_name.replace("/", "-"), ce_score_margin, datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -) +train_batch_size = 512 +train_mini_batch_size = 16 +max_seq_length = 300 # Max length for passages. Increasing it, requires more GPU memory +model_name = "microsoft/mpnet-base" +max_passages = 0 +num_epochs = 1 +max_steps = -1 +negs_to_use = None +lr = 2e-5 -### Now we read the MS Marco dataset -data_folder = "msmarco-data" - -#### Read the corpus files, that contain all the passages. Store them in the corpus dict -corpus = {} # dict in the format: passage_id -> passage. Stores all existent passages -collection_filepath = os.path.join(data_folder, "collection.tsv") -if not os.path.exists(collection_filepath): - tar_filepath = os.path.join(data_folder, "collection.tar.gz") - if not os.path.exists(tar_filepath): - logging.info("Download collection.tar.gz") - util.http_get("https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz", tar_filepath) - - with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) - -logging.info("Read corpus: collection.tsv") -with open(collection_filepath, encoding="utf8") as fIn: - for line in fIn: - pid, passage = line.strip().split("\t") - pid = int(pid) - corpus[pid] = passage - - -### Read the train queries, store in queries dict -queries = {} # dict in the format: query_id -> query. 
Stores all training queries -queries_filepath = os.path.join(data_folder, "queries.train.tsv") -if not os.path.exists(queries_filepath): - tar_filepath = os.path.join(data_folder, "queries.tar.gz") - if not os.path.exists(tar_filepath): - logging.info("Download queries.tar.gz") - util.http_get("https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz", tar_filepath) - - with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) - - -with open(queries_filepath, encoding="utf8") as fIn: - for line in fIn: - qid, query = line.strip().split("\t") - qid = int(qid) - queries[qid] = query - - -# Load a dict (qid, pid) -> ce_score that maps query-ids (qid) and paragraph-ids (pid) -# to the CrossEncoder score computed by the cross-encoder/ms-marco-MiniLM-L6-v2 model -ce_scores_file = os.path.join(data_folder, "cross-encoder-ms-marco-MiniLM-L6-v2-scores.pkl.gz") -if not os.path.exists(ce_scores_file): - logging.info("Download cross-encoder scores file") - util.http_get( - "https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz", - ce_scores_file, - ) +# We used different systems to mine hard negatives. Number of hard negatives to load from each system +num_negs_per_system = 5 +# Number of negatives to use per query. 1 means triplets, and more means training with multiple negatives per query +num_negatives = 16 +# Margin in CE score between the positive and negative passage. If the negative passage has a CE score that is higher than the positive passage minus this margin, we don't use it for training, as it might be a false negative. +ce_score_margin = 3.0 +# Load our embedding model +logging.info("Using pretrained SBERT model") +model = SentenceTransformer(model_name) +model.max_seq_length = max_seq_length + +# Map PID -> text +corpus = load_dataset("sentence-transformers/msmarco-corpus", "passage", split="train") +corpus_dict = dict(zip(corpus["pid"], corpus["text"])) + +# Map QID -> query text +queries = load_dataset("sentence-transformers/msmarco-corpus", "query", split="train") +query_dict = dict(zip(queries["qid"], queries["text"])) + +# Map QID -> {PID: CE score} +scores = load_dataset("sentence-transformers/msmarco-scores-ms-marco-MiniLM-L6-v2", "list", split="train") +ce_scores = { + qid: dict(zip(cids, sc)) for qid, cids, sc in zip(scores["query_id"], scores["corpus_id"], scores["score"]) +} logging.info("Load CrossEncoder scores dict") -with gzip.open(ce_scores_file, "rb") as fIn: - ce_scores = pickle.load(fIn) - -# As training data we use hard-negatives that have been mined using various systems -hard_negatives_filepath = os.path.join(data_folder, "msmarco-hard-negatives.jsonl.gz") -if not os.path.exists(hard_negatives_filepath): - logging.info("Download cross-encoder scores file") - util.http_get( - "https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/msmarco-hard-negatives.jsonl.gz", - hard_negatives_filepath, - ) - -logging.info("Read hard negatives train file") -train_queries = {} -negs_to_use = None -with gzip.open(hard_negatives_filepath, "rt") as fIn: - for line in tqdm.tqdm(fIn): - data = json.loads(line) - - # Get the positive passage ids - qid = data["qid"] - pos_pids = data["pos"] - - if len(pos_pids) == 0: # Skip entries without positives passages - continue - - pos_min_ce_score = min([ce_scores[qid][pid] for pid in data["pos"]]) - ce_score_threshold = pos_min_ce_score - ce_score_margin - - # Get the hard negatives - 
neg_pids = set() - if negs_to_use is None: - if args.negs_to_use is not None: # Use specific system for negatives - negs_to_use = args.negs_to_use.split(",") - else: # Use all systems - negs_to_use = list(data["neg"].keys()) - logging.info("Using negatives from the following systems: {}".format(", ".join(negs_to_use))) - - for system_name in negs_to_use: - if system_name not in data["neg"]: +# Datasets with 50 hard negatives mined per query using different models +SYSTEMS = { + "bm25": "sentence-transformers/msmarco-bm25", + "msmarco-distilbert-base-tas-b": "sentence-transformers/msmarco-msmarco-distilbert-base-tas-b", + "msmarco-distilbert-base-v3": "sentence-transformers/msmarco-msmarco-distilbert-base-v3", + "msmarco-MiniLM-L-6-v3": "sentence-transformers/msmarco-msmarco-MiniLM-L6-v3", + "distilbert-margin_mse-cls-dot-v2": "sentence-transformers/msmarco-distilbert-margin-mse-cls-dot-v2", + "distilbert-margin_mse-cls-dot-v1": "sentence-transformers/msmarco-distilbert-margin-mse-cls-dot-v1", + "distilbert-margin_mse-mean-dot-v1": "sentence-transformers/msmarco-distilbert-margin-mse-mean-dot-v1", + "mpnet-margin_mse-mean-v1": "sentence-transformers/msmarco-mpnet-margin-mse-mean-v1", + "co-condenser-margin_mse-cls-v1": "sentence-transformers/msmarco-co-condenser-margin-mse-cls-v1", + "distilbert-margin_mse-mnrl-mean-v1": "sentence-transformers/msmarco-distilbert-margin-mse-mnrl-mean-v1", + "distilbert-margin_mse-sym_mnrl-mean-v1": "sentence-transformers/msmarco-distilbert-margin-mse-sym-mnrl-mean-v1", + "distilbert-margin_mse-sym_mnrl-mean-v2": "sentence-transformers/msmarco-distilbert-margin-mse-sym-mnrl-mean-v2", + "co-condenser-margin_mse-sym_mnrl-mean-v1": "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", +} + +train_data = {} +for system_key, repo_id in SYSTEMS.items(): + print(f"Loading {system_key}...") + dataset = load_dataset(repo_id, "triplet-50-ids", split="train") + + for row in tqdm.tqdm(dataset, desc=f"Processing {system_key}"): + qid = row.pop("query") + pos_pid = row.pop("positive") + neg_pids = list(row.values()) # All remaining columns are negatives + existing_neg_pids = train_data[qid]["neg_pids"] if qid in train_data else set() + valid_neg_pids = set() + pos_ce_score = ce_scores[qid][pos_pid] + + for neg_pid in neg_pids: + if neg_pid in existing_neg_pids or neg_pid not in ce_scores[qid]: continue - system_negs = data["neg"][system_name] - negs_added = 0 - for pid in system_negs: - if ce_scores[qid][pid] > ce_score_threshold: - continue - - if pid not in neg_pids: - neg_pids.add(pid) - negs_added += 1 - if negs_added >= num_negs_per_system: - break - - if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0): - train_queries[data["qid"]] = { - "qid": data["qid"], - "query": queries[data["qid"]], - "pos": pos_pids, - "neg": neg_pids, - } - -del ce_scores - -logging.info(f"Train queries: {len(train_queries)}") - - -# We create a custom MSMARCO dataset that returns triplets (query, positive, negative) -# on-the-fly based on the information from the mined-hard-negatives jsonl file. 
-class MSMARCODataset(Dataset): - def __init__(self, queries, corpus): - self.queries = queries - self.queries_ids = list(queries.keys()) - self.corpus = corpus - - for qid in self.queries: - self.queries[qid]["pos"] = list(self.queries[qid]["pos"]) - self.queries[qid]["neg"] = list(self.queries[qid]["neg"]) - random.shuffle(self.queries[qid]["neg"]) - - def __getitem__(self, item): - query = self.queries[self.queries_ids[item]] - query_text = query["query"] - - pos_id = query["pos"].pop(0) # Pop positive and add at end - pos_text = self.corpus[pos_id] - query["pos"].append(pos_id) - - neg_id = query["neg"].pop(0) # Pop negative and add at end - neg_text = self.corpus[neg_id] - query["neg"].append(neg_id) - - return InputExample(texts=[query_text, pos_text, neg_text]) - - def __len__(self): - return len(self.queries) - - -# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training. -train_dataset = MSMARCODataset(train_queries, corpus=corpus) -train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size) -train_loss = losses.MultipleNegativesRankingLoss(model=model) - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - epochs=num_epochs, - warmup_steps=args.warmup_steps, - use_amp=True, - checkpoint_path=model_save_path, - checkpoint_save_steps=len(train_dataloader), - optimizer_params={"lr": args.lr}, + neg_ce_score = ce_scores[qid][neg_pid] + + if neg_ce_score < pos_ce_score - ce_score_margin: + valid_neg_pids.add(neg_pid) + existing_neg_pids.add(neg_pid) + if len(valid_neg_pids) >= num_negs_per_system: + break + if qid not in train_data: + train_data[qid] = {"qid": qid, "pid": pos_pid, "neg_pids": valid_neg_pids} + else: + train_data[qid]["neg_pids"].update(valid_neg_pids) + +train_data = {qid: data for qid, data in train_data.items() if len(data["neg_pids"]) >= num_negatives} +logging.info(f"Kept {len(train_data)} queries with >= {num_negatives} negatives") + +train_dataset = Dataset.from_list(list(train_data.values())) + + +def ids_to_text_transform(batch): + negative_ids = [random.sample(neg_pids, num_negatives) for neg_pids in batch["neg_pids"]] + return { + "query": [query_dict[qid] for qid in batch["qid"]], + "positive": [corpus_dict[pid] for pid in batch["pid"]], + **{f"negative_{idx}": [corpus_dict[pid] for pid in neg_ids] for idx, neg_ids in enumerate(zip(*negative_ids))}, + } + + +train_dataset.set_transform(ids_to_text_transform) + +loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=train_mini_batch_size) + +# Prepare training arguments +short_model_name = model_name.split("/")[-1] if "/" in model_name else model_name +run_name = f"{short_model_name}-msmarco-mnrl" +args = SentenceTransformerTrainingArguments( + output_dir=f"output/{run_name}", + num_train_epochs=num_epochs, + per_device_train_batch_size=train_batch_size, + warmup_ratio=0.1, + learning_rate=lr, + max_steps=max_steps, + save_strategy="steps", + save_steps=0.001, + logging_steps=0.01, + batch_sampler=BatchSamplers.NO_DUPLICATES, ) -# Save the model -model.save(model_save_path) +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + loss=loss, +) +trainer.train() + +final_model_path = f"output/{run_name}/final" +model.save_pretrained(final_model_path) + +# (Optional) save the model to the Hugging Face Hub! 
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +try: + model.push_to_hub(f"{run_name}") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\nTo upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_model_path!r})` " + f"and saving it using `model.push_to_hub('{run_name}')`." + )
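
To close, a tiny worked example of the `ce_score_margin` filter used above, followed by a minimal inference sketch. The CE scores and texts are made up, and the model path simply assumes the default `run_name` from this script; adjust it to your own run:

from sentence_transformers import SentenceTransformer

# The filter keeps a mined negative only if CE(query, negative) < CE(query, positive) - margin;
# negatives that score close to the positive are treated as likely false negatives and skipped.
pos_ce_score = 9.0
ce_score_margin = 3.0
candidate_negs = {"A": 8.5, "B": 6.2, "C": 5.9}  # PID -> CE score (made-up values)
kept = [pid for pid, score in candidate_negs.items() if score < pos_ce_score - ce_score_margin]
print(kept)  # ['C']: only C falls below the 9.0 - 3.0 = 6.0 threshold

# Minimal retrieval check with the freshly trained model
model = SentenceTransformer("output/mpnet-base-msmarco-mnrl/final")
query_embedding = model.encode(["what is the capital of france"])
passage_embeddings = model.encode(
    ["Paris is the capital and largest city of France.", "Berlin is the capital of Germany."]
)
print(model.similarity(query_embedding, passage_embeddings))  # The Paris passage should score highest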