diff --git a/examples/sentence_transformer/training/ms_marco/train_bi-encoder_mnrl.py b/examples/sentence_transformer/training/ms_marco/train_bi-encoder_mnrl.py
index 5e7652f22..3df16fc0e 100644
--- a/examples/sentence_transformer/training/ms_marco/train_bi-encoder_mnrl.py
+++ b/examples/sentence_transformer/training/ms_marco/train_bi-encoder_mnrl.py
@@ -14,7 +14,7 @@
 With a distilbert-base-uncased model, it should achieve a performance of about 33.79 MRR@10 on the MSMARCO Passages Dev-Corpus

 Running this script:
-python train_bi-encoder-v3.py
+python train_bi-encoder_mnrl.py --model_name distilbert-base-uncased
 """

 import argparse
@@ -28,9 +28,11 @@
 from datetime import datetime

 import tqdm
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset

 from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer, losses, models, util
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments

 #### Just some code to print debug information to stdout
 logging.basicConfig(
@@ -243,21 +245,34 @@
     def __len__(self):
         return len(self.queries)

-# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
+# For training the SentenceTransformer model, we need a dataset and a loss function
 train_dataset = MSMARCODataset(train_queries, corpus=corpus)
-train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
 train_loss = losses.MultipleNegativesRankingLoss(model=model)

-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    epochs=num_epochs,
+# Define the training arguments
+args_training = SentenceTransformerTrainingArguments(
+    output_dir=model_save_path,
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=train_batch_size,
     warmup_steps=args.warmup_steps,
-    use_amp=True,
-    checkpoint_path=model_save_path,
-    checkpoint_save_steps=len(train_dataloader),
-    optimizer_params={"lr": args.lr},
+    learning_rate=args.lr,
+    fp16=True,
+    bf16=False,
+    save_strategy="steps",
+    save_steps=len(train_dataset) // train_batch_size,
+    save_total_limit=2,
+    logging_steps=100,
 )

+# Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args_training,
+    train_dataset=train_dataset,
+    loss=train_loss,
+)
+trainer.train()
+
 # Save the model
-model.save(model_save_path)
+final_output_dir = f"{model_save_path}/final"
+model.save(final_output_dir)
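
For context, the Trainer-based flow this patch migrates to can be exercised in isolation. Below is a minimal sketch of that flow: the tiny in-memory (anchor, positive) dataset, model checkpoint, and output paths are illustrative assumptions, not part of the patch. The documented input for `SentenceTransformerTrainer` is a Hugging Face `datasets.Dataset` whose columns match what the loss expects, whereas the patched script feeds its own `MSMARCODataset` and CLI-configured hyperparameters.

```python
# Minimal sketch of training with SentenceTransformerTrainer (sentence-transformers v3+).
# The toy dataset and hyperparameter values here are assumptions for illustration.
from datasets import Dataset
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

model = SentenceTransformer("distilbert-base-uncased")

# (anchor, positive) pairs; MultipleNegativesRankingLoss treats the other
# positives within each batch as in-batch negatives.
train_dataset = Dataset.from_dict(
    {
        "anchor": ["what is python", "capital of france"],
        "positive": [
            "Python is a high-level programming language.",
            "Paris is the capital and largest city of France.",
        ],
    }
)
loss = losses.MultipleNegativesRankingLoss(model=model)

args = SentenceTransformerTrainingArguments(
    output_dir="output/demo-bi-encoder",  # checkpoints are written here
    num_train_epochs=1,
    per_device_train_batch_size=2,
    warmup_steps=10,
    learning_rate=2e-5,
    fp16=True,  # set to False on hardware without mixed-precision support
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

# Save the final model separately from the intermediate checkpoints,
# mirroring the patch's f"{model_save_path}/final" convention.
model.save("output/demo-bi-encoder/final")
```

Because MultipleNegativesRankingLoss draws its negatives from the rest of the batch, `per_device_train_batch_size` directly affects training signal quality, which is why the patch threads `train_batch_size` through to the training arguments rather than a DataLoader.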