File size: 539 Bytes
4ad32d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from typing import List
def preprocess_logits_for_metrics(logits, labels):
    """Collapse raw logits into the index of the highest score.

    Args:
        logits: Either an object exposing ``argmax(dim=...)`` or a tuple
            whose first element is such an object; for a tuple, only the
            first element is used.
        labels: Accepted as part of the signature but not read here.

    Returns:
        The result of ``argmax(dim=-1)`` on the selected logits.
    """
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)
def dataset_split_selector(data) -> List:
    """Pick the split names to use for ``data`` (may be extended later).

    Rules, checked in order:
      1. ``data`` has exactly one key        -> ``['train']``
      2. ``data`` has a ``'train_prefs'`` key -> ``['train_prefs', 'test_prefs']``
      3. anything else                        -> ``['train', 'test']``

    The returned lists are fixed; apart from the ``'train_prefs'`` lookup,
    the names are not checked against the keys actually present in ``data``.
    """
    split_names = data.keys()

    # A single-key container is always treated as train-only, whatever the
    # key is actually called.
    if len(split_names) == 1:
        return ['train']

    if 'train_prefs' in split_names:
        return ['train_prefs', 'test_prefs']

    return ['train', 'test']