# splade-bert-tiny-nq / train_script.py
from datasets import load_dataset
from sentence_transformers import (
    SparseEncoder,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
    SparseEncoderModelCardData,
)
from sentence_transformers.sparse_encoder.losses import SpladeLoss, SparseMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.models import SpladePooling, MLMTransformer
# 1. Load a model to finetune with 2. (Optional) model card data
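# SPLADE architecture sketch: the MLM head outputs a vocabulary-sized logit vector per token;
# SpladePooling applies log(1 + ReLU(logits)) and max-pools over the sequence, so each text is
# represented by one sparse |vocab|-dimensional embedding.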
mlm_transformer = MLMTransformer("prajjwal1/bert-tiny")
splade_pooling = SpladePooling(pooling_strategy="max", word_embedding_dimension=mlm_transformer.get_sentence_embedding_dimension())
model = SparseEncoder(
    modules=[mlm_transformer, splade_pooling],
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="SPLADE BERT-tiny trained on Natural-Questions tuples",
    ),
)
# 3. Load a dataset to finetune on
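# The natural-questions dataset contains (query, answer passage) pairs; the ranking loss below
# treats each query's answer as its positive and the other in-batch answers as negatives.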
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
# 4. Define a loss function
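# SpladeLoss wraps a ranking loss and adds sparsity regularization on the query and document
# embeddings; the lambda weights below trade retrieval quality against how sparse (and thus
# how cheap to index) the produced vectors are.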
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    lambda_query=5e-5,
    lambda_corpus=3e-5,
)
# 5. (Optional) Specify training arguments
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir="models/splade-bert-tiny-nq",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    logging_steps=20,
    run_name="splade-bert-tiny-nq",  # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
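# NanoBEIR datasets are small subsets of the corresponding BEIR retrieval benchmarks,
# so they give a quick sanity check of retrieval quality before and during training.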
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)
dev_evaluator(model)
# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()
# 8. Evaluate the model performance again after training
dev_evaluator(model)
# 9. Save the trained model
model.save_pretrained("models/splade-bert-tiny-nq/final")
# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub("splade-bert-tiny-nq")
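
# Minimal inference sketch (assumption: the example query/passage strings are made up, and the
# freshly trained `model` from above is still in memory). It encodes one query and two passages,
# then scores them with the model's similarity function; the relevant passage should score higher.
query_embeddings = model.encode(["who wrote the declaration of independence"])
document_embeddings = model.encode([
    "The Declaration of Independence was drafted primarily by Thomas Jefferson in 1776.",
    "Photosynthesis converts light energy into chemical energy in plant cells.",
])
scores = model.similarity(query_embeddings, document_embeddings)
print(scores)  # shape (1, 2); the first passage is expected to score higher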