alexkueck committed on
Commit
a7ac19b
·
1 Parent(s): 03f2529

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +12 -7
utils.py CHANGED
@@ -237,15 +237,20 @@ def convert_to_markdown(text):
237
  return markdown_text
238
 
239
 
240
- class State:
241
- interrupted = False
 
 
 
242
 
243
- def interrupt(self):
244
- self.interrupted = True
 
 
 
245
 
246
- def recover(self):
247
- self.interrupted = False
248
- shared_state = State()
249
 
250
 
251
  #######################################################
 
237
  return markdown_text
238
 
239
 
240
# Encode datasets - used for both the train and the val sets.
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of encodings with optional labels.

    Args:
        encodings: Mapping from feature name to an indexable sequence of
            per-example values. It must contain the key ``"input_ids"``,
            which ``__len__`` uses to determine the number of examples.
        labels: Optional indexable sequence of per-example labels. When
            it is ``None`` or empty, items carry no ``"labels"`` entry.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, one per feature."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Explicit None/length check instead of truth-testing the container:
        # `if labels:` raises for multi-element numpy arrays and tensors.
        if self.labels is not None and len(self.labels) > 0:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, taken from ``"input_ids"``."""
        return len(self.encodings["input_ids"])
 
254
 
255
 
256
  #######################################################