Spaces:
Runtime error
Runtime error
Gagan Bhatia
commited on
Commit
•
6e46269
1
Parent(s):
9ed09a4
Update model.py
Browse files- src/models/model.py +18 -0
src/models/model.py
CHANGED
@@ -150,3 +150,21 @@ class LightningModel(LightningModule):
|
|
150 |
|
151 |
def __init__(self, tokenizer, model, output: str = "outputs"):
|
152 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
def __init__(self, tokenizer, model, output: str = "outputs"):
|
152 |
"""
|
153 |
+
initiates a PyTorch Lightning Model
|
154 |
+
Args:
|
155 |
+
tokenizer : T5 tokenizer
|
156 |
+
model : T5 model
|
157 |
+
output (str, optional): output directory to save model checkpoints. Defaults to "outputs".
|
158 |
+
"""
|
159 |
+
super().__init__()
|
160 |
+
self.model = model
|
161 |
+
self.tokenizer = tokenizer
|
162 |
+
self.output = output
|
163 |
+
# self.val_acc = Accuracy()
|
164 |
+
# self.train_acc = Accuracy()
|
165 |
+
|
166 |
+
def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
|
167 |
+
""" forward step """
|
168 |
+
output = self.model(
|
169 |
+
input_ids,
|
170 |
+
attention_mask=attention_mask,
|