sujeshpadhi committed on
Commit
5519365
·
1 Parent(s): f3d3f76

Upload T5.py

Browse files
Files changed (1) hide show
  1. T5.py +119 -0
T5.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import T5TokenizerFast, T5ForConditionalGeneration
3
+ from torch import Tensor
4
+ from torch.nn import Module
5
+ from typing import List, Optional, Tuple
6
+ import torch, os
7
+ import torch.nn.functional as F
8
+
9
+
10
class T5(Module):
    '''
    Thin training/inference wrapper around a pretrained Hugging Face T5 model.

    T5 model from: https://huggingface.co/docs/transformers/model_doc/t5

    Bundles the tokenizer, the sequence-to-sequence model and an AdamW
    optimizer over the model parameters, so callers can work directly with
    lists of strings.
    '''

    # Checkpoints this wrapper accepts.
    SUPPORTED_VARIANTS = ("t5-small", "t5-base", "t5-large")

    def __init__(self,
                 variant: str = "t5-small",
                 max_source_length: int = 256,
                 max_target_length: int = 128,
                 optimizer_config: Optional[dict] = None,
                 ):
        '''
        Load the pretrained tokenizer/model and build the optimizer.

        Args:
            variant: Name of the pretrained checkpoint; must be one of
                SUPPORTED_VARIANTS.
            max_source_length: Truncation length (in tokens) for inputs.
            max_target_length: Truncation length (in tokens) for labels and
                the cap on newly generated tokens.
            optimizer_config: Extra keyword arguments forwarded to
                torch.optim.AdamW (e.g. {"lr": 1e-4}). None means defaults.

        Raises:
            AssertionError: If variant is not a supported checkpoint name.
        '''

        # Kept as an assertion so callers relying on AssertionError still work.
        assert variant in self.SUPPORTED_VARIANTS, (
            f"unsupported variant {variant!r}; "
            f"expected one of {self.SUPPORTED_VARIANTS}"
        )

        super().__init__()

        self.variant = variant
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

        # Tokenizer & model
        self.tokenizer = T5TokenizerFast.from_pretrained(
            self.variant, model_max_length=self.max_source_length)
        self.model = T5ForConditionalGeneration.from_pretrained(self.variant)

        # Optimizer. A None default replaces the former shared mutable `{}`.
        if optimizer_config is None:
            optimizer_config = {}
        self.optimizer = torch.optim.AdamW(self.parameters(), **optimizer_config)

        # Scheduler (none configured here; callers may assign one later)
        self.scheduler = None

    def tokenize(self, input: List[str],
                 max_length: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        '''
        Tokenize a batch of strings into padded id and attention-mask tensors.

        Args:
            input: Batch of raw strings.
            max_length: Truncation length in tokens; defaults to
                max_source_length when None.

        Returns:
            (input_ids, attention_mask), both placed on the model's device.
        '''

        limit = self.max_source_length if max_length is None else max_length
        out = self.tokenizer(input, max_length=limit,
                             truncation=True, padding=True,
                             return_tensors="pt")

        # Follow the model's device instead of hard-coding CUDA, so the
        # wrapper also works on CPU-only hosts and on non-default GPUs.
        device = self.model.device
        return out.input_ids.to(device), out.attention_mask.to(device)

    def forward(self, input: List[str],
                label: Optional[List[str]] = None) -> Tuple[Tensor, Optional[Tensor]]:
        '''
        Run the model on a batch of strings.

        Args:
            input: Batch of source strings.
            label: Optional batch of target strings.

        Returns:
            When label is given: (logits, loss) from a teacher-forced pass.
            Otherwise: (generated token ids, None); the ids are not decoded.
        '''

        input_ids, input_masks = self.tokenize(input)

        if label is not None:
            # Targets are bounded by the target length, not the source length.
            label_ids, _ = self.tokenize(label, max_length=self.max_target_length)
            # -100 is the index the loss ignores, so padding does not
            # contribute to the training signal.
            label_ids = label_ids.masked_fill(
                label_ids == self.tokenizer.pad_token_id, -100)
            output = self.model(input_ids=input_ids,
                                attention_mask=input_masks,
                                labels=label_ids)
            return output.logits, output.loss

        # Pass the mask so padded positions in a batch are not attended to.
        generated_ids = self.model.generate(input_ids=input_ids,
                                            attention_mask=input_masks,
                                            max_new_tokens=self.max_target_length)
        return generated_ids, None

    def predict(self, input: List[str]) -> List[str]:
        '''
        Generate the target output for each input string and decode it.

        Args:
            input: Batch of source strings.

        Returns:
            One decoded string per input, with special tokens removed.
        '''

        # Without labels, forward returns generated token ids (not logits).
        generated_ids, _ = self.forward(input=input)
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
79
+
80
+
81
+
82
if __name__ == '__main__':

    # TODO: implement a tester class similar to T5-old.py to test if it works.

    # Smoke test: one supervised forward pass followed by free generation.
    t5 = T5('t5-small')
    t5.to('cuda')

    # Larger sample batch, kept for manual experiments:
    #inputs = [
    #"translate English to German: Thank you so much, Chris.",
    #"translate English to German: I have been blown away by this conference, and I want to thank all of you for the many nice comments about what I had to say the other night.",
    #"translate German to English: Es ist mir wirklich eine Ehre, zweimal auf dieser Bühne stehen zu dürfen. Tausend Dank dafür.",
    #]

    #targets = [
    #"Vielen Dank, Chris.",
    #"Ich bin wirklich begeistert von dieser Konferenz, und ich danke Ihnen allen für die vielen netten Kommentare zu meiner Rede vorgestern Abend.",
    #"And it's truly a great honor to have the opportunity to come to this stage twice; I'm extremely grateful.",
    #]

    inputs = ["Good Morning, How are you?"]
    targets = ["Buongiorno, come stai?"]

    # Teacher-forced pass: returns the logits and the loss.
    logits, loss = t5.forward(inputs, targets)
    print('Model forward')
    print('logits: ', logits)
    print('loss: ', loss)

    # Generation pass: returns decoded strings.
    outputs = t5.predict(inputs)

    # Show each source sentence next to its prediction and reference.
    for inp, out, tar in zip(inputs, outputs, targets):
        print(f"Input: \n{inp}\n\nOutput: \n{out}\n\nTarget: \n{tar}\n\n")
117
+
118
+
119
+