File size: 933 Bytes
fb4a3c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import Trainer, TrainingArguments, T5ForConditionalGeneration

def train_model(
    tok_ds,
    num_train_epochs,
    batch_size,
    *,
    tokenizer=None,
    rouge_fn=None,
    model_name="t5-base",
    output_dir="./output",
    learning_rate=2e-5,
    fp16=True,
    generation_max_length=None,
):
    """Fine-tune a T5 model on a pre-tokenized dataset and return the trainer.

    Training runs for ``num_train_epochs`` epochs and a checkpoint is saved at
    the end of every epoch (at most two are kept in ``output_dir``).  No
    evaluation is run during training; call ``trainer.evaluate()`` on the
    returned trainer to score ``tok_ds["validation"]``.

    Args:
        tok_ds: Mapping with ``"train"`` and ``"validation"`` splits, already
            tokenized for ``model_name`` (presumably a ``datasets.DatasetDict``
            with ``input_ids``/``labels`` columns -- confirm against caller).
        num_train_epochs: Number of training epochs.
        batch_size: Per-device batch size for both training and evaluation.
        tokenizer: Tokenizer used to pad batches and to decode generated ids
            for metrics.  Loaded from ``model_name`` when omitted.
        rouge_fn: Optional ``rouge_fn(predictions, references)`` callable that
            receives two lists of decoded strings and returns a metrics dict.
            When omitted, evaluation reports only the loss.
        model_name: Hub id or local path of the T5 checkpoint to fine-tune.
        output_dir: Directory where checkpoints are written.
        learning_rate: Learning rate passed to the optimizer.
        fp16: Enable fp16 mixed precision (needs a supported accelerator,
            typically CUDA).  T5 checkpoints are reported to overflow to
            inf/NaN under fp16; set this to False if the loss diverges.
        generation_max_length: Maximum length of sequences generated during
            evaluation; ``None`` keeps the model's generation default.

    Returns:
        The ``Seq2SeqTrainer`` (a ``Trainer`` subclass) after training.
    """
    # Imported here so the module-level import line stays unchanged.
    import numpy as np
    from transformers import (
        AutoTokenizer,
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
    )

    model = T5ForConditionalGeneration.from_pretrained(model_name)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    # The Seq2Seq variants make evaluation call model.generate(), so
    # compute_metrics receives generated token ids instead of raw logits.
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        save_total_limit=2,
        num_train_epochs=num_train_epochs,
        save_strategy="epoch",
        learning_rate=learning_rate,
        weight_decay=0.01,
        fp16=fp16,
        predict_with_generate=True,
        generation_max_length=generation_max_length,
    )

    def compute_metrics(eval_pred):
        """Decode predictions and references and score them with rouge_fn."""
        preds, labels = eval_pred.predictions, eval_pred.label_ids
        if isinstance(preds, tuple):
            preds = preds[0]
        # -100 marks ignored label positions and the padding the Trainer adds
        # when concatenating ragged batches; it is not a decodable token id.
        pad_id = tokenizer.pad_token_id
        preds = np.where(preds != -100, preds, pad_id)
        labels = np.where(labels != -100, labels, pad_id)
        return rouge_fn(
            tokenizer.batch_decode(preds, skip_special_tokens=True),
            tokenizer.batch_decode(labels, skip_special_tokens=True),
        )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tok_ds["train"],
        eval_dataset=tok_ds["validation"],
        # Pads inputs per batch and pads any label positions it adds with
        # -100, so variable-length examples can be batched.
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
        compute_metrics=compute_metrics if rouge_fn is not None else None,
    )
    trainer.train()
    return trainer