tomaarsen committed
Commit c7c84c3 · verified · 1 Parent(s): 41dadcc

Create train_script.py

Files changed (1): train_script.py (+159, -0)
train_script.py ADDED
@@ -0,0 +1,159 @@
import logging
import traceback

from datasets import load_dataset

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation.CENanoBEIREvaluator import (
    CENanoBEIREvaluator,
)
from sentence_transformers.cross_encoder.losses import ListNetLoss
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import (
    CrossEncoderTrainingArguments,
)


def main():
    model_name = "microsoft/MiniLM-L12-H384-uncased"

    # Set the log level to INFO to get more information
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
    )
    # The batch size is lower than usual because we have to process multiple documents per query,
    # so the effective batch size is multiplied by max_docs
    train_batch_size = 8
    num_epochs = 1
    max_docs = 10
    pad_value = -1
    loss_name = "listnet"
    num_labels = 1
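    # e.g. 8 queries * 10 docs = 80 (query, document) pairs scored per device per step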

    # 1. Define our CrossEncoder model
    model = CrossEncoder(model_name, num_labels=num_labels)
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2. Load the MS MARCO dataset: https://huggingface.co/datasets/microsoft/ms_marco
    logging.info("Read train dataset")
    dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train")
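    # Each example contains a query plus a "passages" dict with parallel "passage_text" and binary "is_selected" lists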

    def listwise_mapper(batch, max_docs: int = 10, pad_value: int = -1):
        processed_queries = []
        processed_docs = []
        processed_labels = []

        for query, passages_info in zip(batch["query"], batch["passages"]):
            # Extract passages and labels
            passages = passages_info["passage_text"]
            labels = passages_info["is_selected"]

            # Pair passages with labels and sort descending by label (positives first)
            paired = sorted(zip(passages, labels), key=lambda x: x[1], reverse=True)

            # Separate back into passages and labels
            sorted_passages, sorted_labels = zip(*paired) if paired else ((), ())

            # Skip queries without any passages or without any positive labels
            if not sorted_labels or max(sorted_labels) < 1.0:
                continue

            # Truncate to max_docs
            truncated_passages = list(sorted_passages[:max_docs])
            truncated_labels = list(sorted_labels[:max_docs])

            # Pad with empty documents and pad_value labels up to max_docs
            pad_count = max_docs - len(truncated_passages)
            processed_docs.append(truncated_passages + [""] * pad_count)
            processed_labels.append(truncated_labels + [pad_value] * pad_count)
            processed_queries.append(query)

        return {
            "query": processed_queries,
            "docs": processed_docs,
            "labels": processed_labels,
        }
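    # For illustration (hypothetical values), a single mapped example looks like:
    #   {"query": "what is listwise reranking",
    #    "docs": ["relevant passage", "another passage", "", ...],  # padded with "" up to max_docs
    #    "labels": [1, 0, -1, ...]}                                 # pad_value (-1) marks the padding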

    dataset = dataset.map(
        lambda batch: listwise_mapper(batch=batch, max_docs=max_docs, pad_value=pad_value),
        batched=True,
        remove_columns=dataset.column_names,
        desc="Processing listwise samples",
    )
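    # With batched=True, the mapper may return fewer examples than it receives,
    # which is how queries without positive passages are dropped from the dataset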

    dataset = dataset.train_test_split(test_size=10_000)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]
    logging.info(train_dataset)

    # 3. Define our training loss
    loss = ListNetLoss(model, pad_value=pad_value)
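    # ListNet turns the scores and the labels of each docs list into probability distributions
    # via softmax and minimizes the cross-entropy between them; labels equal to pad_value mark
    # padded documents and are excluded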

    # 4. Define the evaluator. We use CENanoBEIREvaluator, a lightweight evaluator for English reranking
    evaluator = CENanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=train_batch_size)
    evaluator(model)
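    # Evaluating before training establishes a baseline to compare the final results against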

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-msmarco-v1.1-{short_model_name}-{loss_name}"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        # Reload the checkpoint with the best NanoBEIR mean NDCG@10 once training finishes
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_mean_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1600,
        save_strategy="steps",
        save_steps=1600,
        save_total_limit=2,
        logging_steps=200,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model; useful to include these results in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) Push the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
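
Usage note: once training completes, the saved checkpoint can be loaded back as a reranker. A minimal inference sketch, assuming the final output directory produced by step 8 above (the query and passages are illustrative):

    from sentence_transformers.cross_encoder import CrossEncoder

    # Load the trained reranker from the local output directory written in step 8
    model = CrossEncoder("models/reranker-msmarco-v1.1-MiniLM-L12-H384-uncased-listnet/final")

    # Score (query, passage) pairs; higher scores indicate higher relevance
    query = "how do cross encoders work"
    passages = [
        "Cross encoders jointly encode a query and a passage to produce a relevance score.",
        "The weather tomorrow is expected to be sunny.",
    ]
    scores = model.predict([(query, passage) for passage in passages])
    print(scores)

    # Or rank all passages for the query in one call
    print(model.rank(query, passages))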