burtenshaw's picture
burtenshaw HF staff
Upload folder using huggingface_hub
4ad32d0 verified
raw
history blame contribute delete
539 Bytes
from typing import List
def preprocess_logits_for_metrics(logits, labels):
    """Collapse logits to the index of the highest score along the last axis.

    Args:
        logits: An object exposing ``argmax(dim=...)`` (presumably a
            ``torch.Tensor`` — confirm against callers), or a tuple whose
            first element is such an object; remaining tuple items are ignored.
        labels: Unused here; accepted only so the signature matches the
            caller's hook (presumably Hugging Face ``Trainer``'s
            ``preprocess_logits_for_metrics`` argument — confirm).

    Returns:
        Whatever ``argmax(dim=-1)`` returns for the selected logits.
    """
    # When a tuple is passed, only its first element is reduced.
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)
def dataset_split_selector(data) -> List[str]:
    """Choose which split names of ``data`` to use.

    Args:
        data: Any object whose ``keys()`` returns split names (presumably a
            ``datasets.DatasetDict`` — confirm against callers).

    Returns:
        * A one-element list holding the name of the sole split when ``data``
          has exactly one split. The actual key is returned rather than a
          hard-coded ``'train'``, which might not exist.
        * ``['train_prefs', 'test_prefs']`` when a ``'train_prefs'`` split is
          present.
        * ``['train', 'test']`` otherwise, including when ``data`` is empty.

    Note:
        Only the single-split result is guaranteed to exist in ``data``; the
        two-split results are returned without checking that both names are
        present.
    """
    # Read the keys once; only ``keys()`` is relied on, as in the original.
    split_names = list(data.keys())
    if len(split_names) == 1:
        # Use the split that actually exists instead of assuming 'train'.
        return [split_names[0]]
    if 'train_prefs' in split_names:
        return ['train_prefs', 'test_prefs']
    return ['train', 'test']