Spaces:
Paused
Paused
Update utils.py
Browse files
utils.py
CHANGED
@@ -237,15 +237,20 @@ def convert_to_markdown(text):
|
|
237 |
return markdown_text
|
238 |
|
239 |
|
240 |
-
|
241 |
-
|
|
|
|
|
|
|
242 |
|
243 |
-
def
|
244 |
-
|
|
|
|
|
|
|
245 |
|
246 |
-
def
|
247 |
-
self.
|
248 |
-
shared_state = State()
|
249 |
|
250 |
|
251 |
#######################################################
|
|
|
237 |
return markdown_text
|
238 |
|
239 |
|
240 |
+
#Datasets encodieren - in train und val Sets
|
241 |
+
class Dataset(torch.utils.data.Dataset):
|
242 |
+
def __init__(self, encodings, labels=None):
|
243 |
+
self.encodings = encodings
|
244 |
+
self.labels = labels
|
245 |
|
246 |
+
def __getitem__(self, idx):
|
247 |
+
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
248 |
+
if self.labels:
|
249 |
+
item["labels"] = torch.tensor(self.labels[idx])
|
250 |
+
return item
|
251 |
|
252 |
+
def __len__(self):
|
253 |
+
return len(self.encodings["input_ids"])
|
|
|
254 |
|
255 |
|
256 |
#######################################################
|