import argparse
import os
from pathlib import Path
import torch
from sentence_transformers import (
    SparseEncoder,
    SparseEncoderTrainingArguments,
    SparseEncoderTrainer,
)
from sentence_transformers.sparse_encoder.models import MLMTransformer, SpladePooling
from sentence_transformers.sparse_encoder.losses import (
    SparseMultipleNegativesRankingLoss,
    SparseCosineSimilarityLoss,
    SpladeLoss,
)
from datasets import load_from_disk, DatasetDict
from dotenv import load_dotenv
load_dotenv()
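# load_dotenv() reads a local .env file (if present) so that HF_TOKEN, which the
# hub push below relies on, is available via os.getenv without hard-coding it.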
CUDA_AVAILABLE = torch.cuda.is_available()
MPS_AVAILABLE = torch.backends.mps.is_available()
OUTPUT_MODEL_ID = "splade-en-fr-eurobert-v1"
OUTPUT_DIR = "models/splade-en-fr-eurobert-v1"


def train_splade_model(
    data_path: str = "data/en-fr-opus",
    base_model: str = "EuroBERT/EuroBERT-210m",
    output_dir: str = OUTPUT_DIR,
    epochs: int = 1,
    train_batch_size: int = 16,
    learning_rate: float = 2e-5,
    warmup_ratio: float = 0.2,
    query_reg_weight: float = 5e-5,
    doc_reg_weight: float = 3e-5,
) -> None:
"""
Trains a SPLADE model on a properly formatted dataset.
Args:
data_path (str): Path to the directory containing the processed dataset.
base_model (str): The name of the model to use as the base (must be MLM).
output_dir (str): Directory to save training checkpoints and the final model.
epochs (int): Number of training epochs.
train_batch_size (int): Batch size for training.
learning_rate (float): The learning rate for the optimizer.
warmup_ratio (float): The ratio of training steps to use for a linear warmup.
query_reg_weight (float): Sparsity regularization weight for queries.
doc_reg_weight (float): Sparsity regularization weight for documents.
"""
    ### --- Data ---
    dataset_path = Path(data_path)
    if not dataset_path.exists():
        print(f"Error: Dataset not found at '{data_path}'.")
        print("Please run the 'prepare_data.py' script first.")
        return

    print(f"Loading dataset from '{dataset_path}'...")
    training_dataset_dict = load_from_disk(dataset_path)
    if not isinstance(training_dataset_dict, DatasetDict) or "train" not in training_dataset_dict:
        print("Error: Invalid dataset format. Expected a DatasetDict with a 'train' split.")
        return

    print("Dataset loaded successfully:")
    print(training_dataset_dict)

    eval_dataset = training_dataset_dict.get("validation")
    if eval_dataset is not None:
        print("Validation set found; it will be used for evaluation during training.")

    ### --- Model ---
    print(f"Initializing model with base: '{base_model}'")
    # Initialize the SPLADE architecture: MLMTransformer + SpladePooling.
    # The base model must have a Masked Language Modeling (MLM) head.
    transformer = MLMTransformer(base_model)
    # The SpladePooling layer handles the aggregation and activation.
    pooler = SpladePooling(pooling_strategy="max")
    model = SparseEncoder(modules=[transformer, pooler])
    # Equivalently, load directly:
    #   model = SparseEncoder(base_model)
    # since MLMTransformer + SpladePooling is the default SPLADE architecture.
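    # For reference, a rough sketch of what SpladePooling("max") computes: every
    # token position i produces MLM logits over the whole vocabulary, and the
    # pooled weight for vocabulary term j is
    #     w_j = max_i log(1 + relu(logit_ij))
    # so each text becomes a |vocab|-sized sparse vector of learned term weights.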

    # Move model to GPU if available.
    if CUDA_AVAILABLE:
        print("Moving model to CUDA device.")
        model = model.to("cuda")
    elif MPS_AVAILABLE:
        print("Moving model to MPS device.")
        model = model.to("mps")
    else:
        print("Warning: No CUDA or MPS device found. Training on CPU.")

    ### --- Loss ---
    # A contrastive option that learns to separate positives from in-batch
    # negatives:
    # primary_loss = SparseMultipleNegativesRankingLoss(model=model)
    # Instead, regress a similarity score for each (sentence_A, sentence_B, score) pair:
    primary_loss = SparseCosineSimilarityLoss(model=model)
    # SpladeLoss wraps the primary loss and adds the sparsity regularization term
    # (FLOPS regularization by default). This is what makes the embeddings sparse.
    splade_loss = SpladeLoss(
        model=model,
        loss=primary_loss,
        query_regularizer_weight=query_reg_weight,
        document_regularizer_weight=doc_reg_weight,
    )
    print(f"Loss function configured with {primary_loss.__class__.__name__}.")

    ### --- Training ---
    # These arguments control every aspect of the training loop.
    args = SparseEncoderTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=train_batch_size,
        gradient_accumulation_steps=2,
        learning_rate=learning_rate,
        warmup_ratio=warmup_ratio,
        # fp16=USE_FP16,  # the base model requires fp32
        logging_steps=1000,
        # Only evaluate if a validation split was found; the Trainer errors out
        # when eval_strategy is set without an eval dataset.
        eval_strategy="steps" if eval_dataset is not None else "no",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,  # Only keep the last 2 checkpoints
        report_to="wandb",
        push_to_hub=True,
        hub_model_id=f"sofdog/{OUTPUT_MODEL_ID}",
        hub_token=os.getenv("HF_TOKEN"),
    )
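    # With the default train_batch_size of 16 and gradient_accumulation_steps=2,
    # the effective batch size is 32 per device; logging, evaluation, and
    # checkpointing all happen every 1000 optimizer steps.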
print(f"Training arguments set. Output will be saved to '{output_dir}'.")
trainer = SparseEncoderTrainer(
model=model,
args=args,
train_dataset=training_dataset_dict["train"],
eval_dataset=eval_dataset,
loss=splade_loss,
)
print("\n--- Starting Training ---")
trainer.train()
print("--- Done! ---")

    ### --- Save ---
    final_model_path = f"{output_dir}-final"
    model.save_pretrained(final_model_path)
    print(f"Final model saved to '{final_model_path}'")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a bilingual SPLADE model.")
    parser.add_argument("--data_path", type=str, default="data/en-fr-opus", help="Path to the prepared dataset.")
    parser.add_argument("--base_model", type=str, default="EuroBERT/EuroBERT-210m", help="Base MLM model.")
    parser.add_argument("--output_dir", type=str, default=OUTPUT_DIR, help="Output directory for checkpoints.")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=8, help="Training batch size.")
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate.")
    args = parser.parse_args()

    train_splade_model(
        data_path=args.data_path,
        base_model=args.base_model,
        output_dir=args.output_dir,
        epochs=args.epochs,
        train_batch_size=args.batch_size,
        learning_rate=args.lr,
    )
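
# Example invocation (the script filename here is assumed; adjust to match yours):
#   python train_splade.py --data_path data/en-fr-opus --epochs 3 --batch_size 8 --lr 2e-5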