quockhangdev commited on
Commit
7c42acc
·
verified ·
1 Parent(s): 29b079d

train script

Browse files
Files changed (1) hide show
  1. train.py +105 -0
train.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset, Dataset
2
+ import pandas as pd
3
+ from sentence_transformers import (
4
+ SentenceTransformer,
5
+ SentenceTransformerTrainer,
6
+ SentenceTransformerTrainingArguments,
7
+ )
8
+ from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
9
+
10
+
11
+ def clean_text(x):
12
+ if x is None:
13
+ return ""
14
+ x = str(x).strip()
15
+ x = " ".join(x.split())
16
+ return x
17
+
18
+
19
+ def build_doc_fast(context):
20
+ return f"text: {context}"
21
+
22
+
23
+ def main():
24
+ dataset_name = "phamson02/large-vi-legal-queries"
25
+
26
+ # Load + clean
27
+ ds = load_dataset(dataset_name, split="train")
28
+ df = ds.to_pandas()
29
+
30
+ print("Raw shape:", df.shape)
31
+
32
+ for col in ["domain", "title", "header", "aspect", "context", "query"]:
33
+ if col not in df.columns:
34
+ df[col] = ""
35
+ df[col] = df[col].apply(clean_text)
36
+
37
+ df = df[(df["query"] != "") & (df["context"] != "")]
38
+ df = df.drop_duplicates(subset=["query", "context"]).reset_index(drop=True)
39
+
40
+ print("Cleaned rows:", len(df))
41
+
42
+ train_df = pd.DataFrame(
43
+ {
44
+ "anchor": df["query"].tolist(),
45
+ "positive": [build_doc_fast(context) for context in df["context"].tolist()],
46
+ }
47
+ )
48
+
49
+ print(train_df.head())
50
+
51
+ train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
52
+ print(train_dataset[0])
53
+
54
+ # IMPORTANT: no .to("cuda") here under torchrun / DDP
55
+ model = SentenceTransformer(
56
+ "google/embeddinggemma-300m",
57
+ model_kwargs={
58
+ # "torch_dtype": "auto",
59
+ # "attn_implementation": "flash_attention_2",
60
+ },
61
+ )
62
+ model.max_seq_length = 512
63
+
64
+ loss = CachedMultipleNegativesRankingLoss(
65
+ model,
66
+ mini_batch_size=32,
67
+ gather_across_devices=False,
68
+ )
69
+
70
+ task_name = "Retrieval"
71
+ training_args = SentenceTransformerTrainingArguments(
72
+ prompts=model.prompts[task_name],
73
+ torch_compile=False,
74
+ output_dir="./embeddinggemma-300m-vilegal",
75
+ num_train_epochs=1,
76
+ per_device_train_batch_size=1024,
77
+ gradient_accumulation_steps=1,
78
+ learning_rate=2e-5,
79
+ warmup_ratio=0.1,
80
+ bf16=True,
81
+ logging_steps=10,
82
+ save_strategy="epoch",
83
+ report_to="none",
84
+ remove_unused_columns=False,
85
+ dataloader_num_workers=8,
86
+ dataloader_persistent_workers=True,
87
+ # Often helpful for DDP stability/perf with Transformer training:
88
+ ddp_find_unused_parameters=False,
89
+ )
90
+
91
+ trainer = SentenceTransformerTrainer(
92
+ model=model,
93
+ args=training_args,
94
+ train_dataset=train_dataset,
95
+ loss=loss,
96
+ )
97
+
98
+ trainer.train()
99
+
100
+ # Save only once from the main process
101
+ trainer.save_model("./embeddinggemma-300m-vilegal")
102
+
103
+
104
+ if __name__ == "__main__":
105
+ main()