Dinesh310 committed on
Commit 88d9f70 · verified · 1 Parent(s): aee05e1

Create training/train.py

Files changed (1)
  1. training/train.py +71 -0
training/train.py ADDED
@@ -0,0 +1,71 @@
+ import numpy as np
+ import evaluate
+ from datasets import load_dataset
+ from transformers import (
+     MT5ForConditionalGeneration,
+     MT5Tokenizer,
+     Seq2SeqTrainingArguments,
+     Seq2SeqTrainer,
+     DataCollatorForSeq2Seq
+ )
+
+ # Load metrics
+ cer_metric = evaluate.load("cer")
+ wer_metric = evaluate.load("wer")
+
+ model_nm = "google/mt5-small"
+ tokenizer = MT5Tokenizer.from_pretrained(model_nm)
+ model = MT5ForConditionalGeneration.from_pretrained(model_nm)
+
+ def compute_metrics(eval_preds):
+     preds, labels = eval_preds
+     if isinstance(preds, tuple):
+         preds = preds[0]
+
+     decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+     # Replace -100 in labels as we can't decode them
+     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+     cer = cer_metric.compute(predictions=decoded_preds, references=decoded_labels)
+     wer = wer_metric.compute(predictions=decoded_preds, references=decoded_labels)
+
+     return {"cer": cer, "wer": wer}
+
+ def tokenize_fn(batch):
+     inputs = tokenizer(batch['source'], padding="max_length", truncation=True, max_length=64)
+     labels = tokenizer(batch['target'], padding="max_length", truncation=True, max_length=64)
+     inputs["labels"] = labels["input_ids"]
+     return inputs
+
+ # Load and process data
+ dataset = load_dataset('csv', data_files={'train': 'train.csv', 'test': 'val.csv'})
+ tokenized_dataset = dataset.map(tokenize_fn, batched=True)
+
+ args = Seq2SeqTrainingArguments(
+     output_dir="./translit-results",
+     evaluation_strategy="epoch",
+     learning_rate=2e-4,
+     per_device_train_batch_size=16,
+     per_device_eval_batch_size=16,
+     weight_decay=0.01,
+     save_total_limit=2,
+     num_train_epochs=3,
+     predict_with_generate=True,
+     fp16=True,  # Set to False if not using GPU
+     logging_steps=100,
+ )
+
+ trainer = Seq2SeqTrainer(
+     model=model,
+     args=args,
+     train_dataset=tokenized_dataset["train"],
+     eval_dataset=tokenized_dataset["test"],
+     tokenizer=tokenizer,
+     data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
+     compute_metrics=compute_metrics
+ )
+
+ trainer.train()
+ model.save_pretrained("./final_model")
+ tokenizer.save_pretrained("./final_model")
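
The script reads train.csv and val.csv with 'source' and 'target' columns and writes the fine-tuned checkpoint to ./final_model. A minimal inference sketch for that saved checkpoint is shown below; the input string and generation settings are illustrative assumptions, not part of this commit.

# Load the checkpoint saved by train.py and generate a transliteration.
# Assumes ./final_model exists; the example input and max_length are placeholders.
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

tokenizer = MT5Tokenizer.from_pretrained("./final_model")
model = MT5ForConditionalGeneration.from_pretrained("./final_model")

inputs = tokenizer("namaste", return_tensors="pt")  # hypothetical source string
output_ids = model.generate(**inputs, max_length=64)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))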