File size: 539 Bytes
4ad32d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

from typing import List

def preprocess_logits_for_metrics(logits, labels):
    """Reduce raw logits to predicted ids before metrics are computed.

    The ``(logits, labels)`` signature matches the Hugging Face ``Trainer``
    ``preprocess_logits_for_metrics`` hook (presumably how this is wired in —
    confirm with the caller).

    Args:
        logits: Either an object exposing ``argmax(dim=...)`` (e.g. a torch
            tensor) or a tuple whose first element is such an object; any
            further tuple elements are ignored.
        labels: Unused; accepted only to satisfy the hook signature.

    Returns:
        The result of ``argmax(dim=-1)`` on the selected logits, i.e. the
        index of the largest score along the last dimension.
    """
    # Model outputs may arrive as a tuple; only the leading entry is scored.
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)

def dataset_split_selector(data) -> List[str]:
    """Choose which split names to use from a split-keyed dataset container.

    Args:
        data: Mapping-like object exposing ``keys()`` whose keys are split
            names (presumably a ``datasets.DatasetDict`` — confirm with callers).

    Returns:
        * A one-element list holding the sole split name when ``data`` has
          exactly one split.
        * ``['train_prefs', 'test_prefs']`` when a ``'train_prefs'`` split is
          present among several splits (preference-style datasets).
        * ``['train', 'test']`` otherwise (including when ``data`` is empty).

    The multi-split results are fixed pairs; they are not checked against
    the keys actually present in ``data``.
    """
    split_names = list(data.keys())
    if len(split_names) == 1:
        # Return the split that actually exists instead of assuming it is
        # called 'train'; e.g. a lone 'train_prefs' split would otherwise be
        # mapped to a name that is not a key of ``data``.
        return split_names
    if 'train_prefs' in split_names:
        return ['train_prefs', 'test_prefs']
    return ['train', 'test']