@@ -14,7 +14,7 @@
 With a distilbert-base-uncased model, it should achieve a performance of about 33.79 MRR@10 on the MSMARCO Passages Dev-Corpus
 
 Running this script:
-python train_bi-encoder-v3.py
+python train_bi-encoder_mnrl.py --model_name distilbert-base-uncased
 """
 
 import argparse
@@ -28,9 +28,11 @@
 from datetime import datetime
 
 import tqdm
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
 
 from sentence_transformers import InputExample, LoggingHandler, SentenceTransformer, losses, models, util
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments
 
 #### Just some code to print debug information to stdout
 logging.basicConfig(
@@ -243,21 +245,34 @@ def __len__(self):
         return len(self.queries)
 
 
-# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
+# For training the SentenceTransformer model, we need a dataset and a loss function
 train_dataset = MSMARCODataset(train_queries, corpus=corpus)
-train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
 train_loss = losses.MultipleNegativesRankingLoss(model=model)
 
-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    epochs=num_epochs,
+# Define the training arguments
+args_training = SentenceTransformerTrainingArguments(
+    output_dir=model_save_path,
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=train_batch_size,
     warmup_steps=args.warmup_steps,
-    use_amp=True,
-    checkpoint_path=model_save_path,
-    checkpoint_save_steps=len(train_dataloader),
-    optimizer_params={"lr": args.lr},
+    learning_rate=args.lr,
+    fp16=True,
+    bf16=False,
+    save_strategy="steps",
+    save_steps=len(train_dataset) // train_batch_size,
+    save_total_limit=2,
+    logging_steps=100,
 )
 
+# Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args_training,
+    train_dataset=train_dataset,
+    loss=train_loss,
+)
+trainer.train()
+
 # Save the model
-model.save(model_save_path)
+final_output_dir = f"{model_save_path}/final"
+model.save(final_output_dir)
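
For context, the Trainer-based flow this diff adopts looks roughly like the following once the script's CLI plumbing and MS MARCO data loading are stripped away. This is a minimal sketch, not the script itself: the toy triplet dataset, output paths, and hyperparameter values are placeholder assumptions, and it feeds the trainer a Hugging Face `datasets.Dataset` (the canonical v3 input) rather than the script's custom `MSMARCODataset`. Only the class names and argument names mirror the updated script.

```python
# Minimal sketch of the v3 Trainer-based loop adopted in this diff.
# Dataset contents, paths, and hyperparameters are illustrative placeholders.
from datasets import Dataset

from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

model = SentenceTransformer("distilbert-base-uncased")

# (query, positive, negative) triplets: MultipleNegativesRankingLoss treats the
# first column as the anchor, the second as the positive, and any further
# columns as hard negatives; in-batch positives serve as additional negatives.
train_dataset = Dataset.from_dict(
    {
        "query": ["what is a passage ranker", "who wrote faust"],
        "positive": [
            "A passage ranker orders text passages by relevance to a query.",
            "Faust is a tragic play written by Johann Wolfgang von Goethe.",
        ],
        "negative": [
            "Bananas are rich in potassium.",
            "The Eiffel Tower is located in Paris.",
        ],
    }
)
train_loss = losses.MultipleNegativesRankingLoss(model=model)

args = SentenceTransformerTrainingArguments(
    output_dir="output/bi-encoder-demo",  # placeholder for model_save_path
    num_train_epochs=1,
    per_device_train_batch_size=2,
    warmup_steps=10,
    learning_rate=2e-5,
    fp16=True,  # requires a CUDA GPU; set to False when running on CPU
    save_strategy="steps",
    save_steps=max(1, len(train_dataset) // 2),  # about one checkpoint per epoch
    save_total_limit=2,
    logging_steps=100,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=train_loss,
)
trainer.train()

# Save the final model alongside the step checkpoints, as the diff does.
model.save("output/bi-encoder-demo/final")
```

Note that the new arguments preserve the old checkpoint cadence: `model.fit` saved via `checkpoint_save_steps=len(train_dataloader)`, i.e. once per epoch, and `save_steps=len(train_dataset) // train_batch_size` reproduces roughly the same steps-per-epoch count directly from the dataset size, since the `DataLoader` no longer exists to measure it.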