File size: 1,427 Bytes
c2c125c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

# Label map returned by get_label_dict("AFQMC"): each of the two label
# strings maps to itself.
AFQMC_LABELS = {label: label for label in ('0', '1')}

# Label map returned by get_label_dict("CSL"): each of the three label
# strings maps to itself.
CSL_LABELS = {label: label for label in ('0', '1', '2')}

# Label map returned by get_label_dict("IFLYTEK"): the 119 label strings
# '0'..'118', each mapped to itself.  Built with a comprehension so no loop
# variable is left behind in the module namespace.
IFLYTEK_LABELS = {str(index): str(index) for index in range(119)}

# Label map returned by get_label_dict("OCNLI"): each label name maps to
# the string form of its position in the tuple below.
OCNLI_LABELS = {
    name: str(position)
    for position, name in enumerate(('contradiction', 'entailment', 'neutral'))
}

# Label map returned by get_label_dict("CMNLI"); same three name -> index
# string pairs as the OCNLI map, but held in a separate dict object.
CMNLI_LABELS = dict(contradiction='0', entailment='1', neutral='2')

# Label map returned by get_label_dict("TNEWS").
# tnews_list holds the offsets 0..16 with 5 and 11 left out (15 values).
# Each offset is keyed as str(100 + offset) and mapped to the string form of
# its position in tnews_list, so keys '100'..'116' (without '105' and '111')
# map onto the contiguous range '0'..'14'.
tnews_list = [offset for offset in range(17) if offset not in (5, 11)]
TNEWS_LABELS = {
    str(100 + offset): str(position)
    for position, offset in enumerate(tnews_list)
}

# Label map returned by get_label_dict("WSC").  Note the order: 'true' is
# index '0' and 'false' is index '1'.
WSC_LABELS = {
    name: str(position)
    for position, name in enumerate(('true', 'false'))
}

# Label map returned by get_label_dict("ZC"): 'negative' -> '0',
# 'positive' -> '1'.
ZC_LABELS = dict(negative='0', positive='1')

def get_label_dict(task_name, write2file=False):
    """Return the label mapping for the given task.

    Args:
        task_name: One of "AFQMC", "CSL", "IFLYTEK", "OCNLI", "TNEWS",
            "WSC", "CMNLI" or "ZC".  The comparison is exact
            (case-sensitive).
        write2file: When true, return the mapping inverted, i.e. with the
            original values as keys and the original keys as values.

    Returns:
        When ``write2file`` is false, the module-level ``*_LABELS`` dict for
        the task (the shared object itself, not a copy).  When ``write2file``
        is true, a newly built dict with keys and values swapped.

    Raises:
        ValueError: If ``task_name`` is not one of the supported names.
    """
    if task_name == "AFQMC":
        label_dict = AFQMC_LABELS
    elif task_name == "CSL":
        label_dict = CSL_LABELS
    elif task_name == "IFLYTEK":
        label_dict = IFLYTEK_LABELS
    elif task_name == "OCNLI":
        label_dict = OCNLI_LABELS
    elif task_name == "TNEWS":
        label_dict = TNEWS_LABELS
    elif task_name == "WSC":
        label_dict = WSC_LABELS
    elif task_name == "CMNLI":
        label_dict = CMNLI_LABELS
    elif task_name == "ZC":
        label_dict = ZC_LABELS
    else:
        # Fail with a clear error instead of opening a debugger and then
        # tripping over an unbound ``label_dict`` further down.
        raise ValueError(
            "Unsupported task name: {!r}".format(task_name)
        )

    if write2file:
        # Swap direction: mapped value -> original label.
        label_dict = {value: key for key, value in label_dict.items()}

    return label_dict