Gagan Bhatia commited on
Commit
22ddb9e
·
1 Parent(s): 55f1dff

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +5 -0
src/models/model.py CHANGED
@@ -79,3 +79,8 @@ class DataModule(Dataset):
79
  text=data_row["text"],
80
  keywords_input_ids=input_encoding["input_ids"].flatten(),
81
  keywords_attention_mask=input_encoding["attention_mask"].flatten(),
 
 
 
 
 
 
79
  text=data_row["text"],
80
  keywords_input_ids=input_encoding["input_ids"].flatten(),
81
  keywords_attention_mask=input_encoding["attention_mask"].flatten(),
82
+ labels_attention_mask=output_encoding["attention_mask"].flatten(),
83
+ )
84
+
85
+
86
+ class PLDataModule(LightningDataModule):