"""Train script (vilegal): fine-tune google/embeddinggemma-300m on Vietnamese
legal query-context pairs from phamson02/large-vi-legal-queries."""
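
# Intended for multi-GPU DDP training via torchrun; example launch
# (the GPU count here is illustrative):
#   torchrun --nproc_per_node=4 train.py
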
from datasets import load_dataset, Dataset
import pandas as pd
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
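

# Normalize a raw text field: cast to str, strip, and collapse internal whitespace; None becomes "".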
def clean_text(x):
    if x is None:
        return ""
    x = str(x).strip()
    x = " ".join(x.split())
    return x
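

# Build the positive "document" text: prefix the cleaned context with "text: ".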
def build_doc_fast(context):
    return f"text: {context}"


def main():
    dataset_name = "phamson02/large-vi-legal-queries"

    # Load + clean
    ds = load_dataset(dataset_name, split="train")
    df = ds.to_pandas()
    print("Raw shape:", df.shape)
    for col in ["domain", "title", "header", "aspect", "context", "query"]:
        if col not in df.columns:
            df[col] = ""
        df[col] = df[col].apply(clean_text)
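
    # Keep only rows with a non-empty query and context, then drop duplicate (query, context) pairs.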
    df = df[(df["query"] != "") & (df["context"] != "")]
    df = df.drop_duplicates(subset=["query", "context"]).reset_index(drop=True)
    print("Cleaned rows:", len(df))
    train_df = pd.DataFrame(
        {
            "anchor": df["query"].tolist(),
            "positive": [build_doc_fast(context) for context in df["context"].tolist()],
        }
    )
    print(train_df.head())

    train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
    print(train_dataset[0])

    # IMPORTANT: no .to("cuda") here under torchrun / DDP
    model = SentenceTransformer(
        "google/embeddinggemma-300m",
        model_kwargs={
            # "torch_dtype": "auto",
            # "attn_implementation": "flash_attention_2",
        },
    )
    model.max_seq_length = 512
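
    # MultipleNegativesRankingLoss with gradient caching (GradCache): embeddings are computed
    # in mini-batches of 32, so the large per-device batch still fits in memory; in-batch
    # negatives stay local to each GPU (gather_across_devices=False).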
    loss = CachedMultipleNegativesRankingLoss(
        model,
        mini_batch_size=32,
        gather_across_devices=False,
    )
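
    # Prepend the model's built-in prompt for the "Retrieval" task to the training texts.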
    task_name = "Retrieval"
    training_args = SentenceTransformerTrainingArguments(
        prompts=model.prompts[task_name],
        torch_compile=False,
        output_dir="./embeddinggemma-300m-vilegal",
        num_train_epochs=1,
        per_device_train_batch_size=1024,
        gradient_accumulation_steps=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        report_to="none",
        remove_unused_columns=False,
        dataloader_num_workers=8,
        dataloader_persistent_workers=True,
        # Often helpful for DDP stability/perf with Transformer training:
        ddp_find_unused_parameters=False,
    )
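
    # The trainer handles device placement and DDP wrapping under torchrun; no manual .to("cuda") needed.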
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()

    # Save only once from the main process
    trainer.save_model("./embeddinggemma-300m-vilegal")


if __name__ == "__main__":
    main()