File size: 4,505 Bytes
77961b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d65e913
 
 
77961b6
 
 
 
 
 
d65e913
 
 
77961b6
d65e913
 
 
77961b6
 
d65e913
77961b6
d65e913
77961b6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import datasets

import logging

import pandas as pd


def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Look up the model label that equals ``label`` when case is ignored.

    Args:
        id2label_mapping: Mapping whose keys are the model's label names.
        label: Dataset label name to look for.

    Returns:
        A ``(model_label, label)`` tuple. ``model_label`` is the first key of
        ``id2label_mapping`` whose upper-cased form equals the upper-cased
        ``label``, or ``None`` when no key matches. ``label`` is returned
        unchanged.
    """
    matched = next(
        (
            candidate
            for candidate in id2label_mapping
            if candidate.upper() == label.upper()
        ),
        None,
    )
    return matched, label


def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Pair each model label with a dataset label of the same name.

    Every ``datasets.ClassLabel`` feature whose number of names equals the
    number of model labels is examined. A dataset label is paired with a model
    label when the names are identical, or otherwise when they are equal
    ignoring case.

    Args:
        id2label: Mapping from model class id to model label name.
        dataset_features: Mapping from column name to dataset feature.

    Returns:
        A ``(id2label_mapping, dataset_labels)`` tuple. ``id2label_mapping``
        maps every model label to the dataset label it was paired with, or to
        ``None`` when no pairing was found. ``dataset_labels`` is the list of
        names of the last examined ``ClassLabel`` feature, or ``None`` when no
        feature qualified.
    """
    label_pairs = dict.fromkeys(id2label.values())
    expected_count = len(label_pairs)
    names_found = None

    for feature in dataset_features.values():
        qualifies = (
            isinstance(feature, datasets.ClassLabel)
            and len(feature.names) == expected_count
        )
        if not qualifies:
            continue

        names_found = feature.names

        for name in names_found:
            # Exact match first, then fall back to a case-insensitive lookup.
            model_label = (
                name
                if name in label_pairs
                else text_classificaiton_match_label_case_unsensative(label_pairs, name)[0]
            )
            if model_label is not None:
                label_pairs[model_label] = name

    return label_pairs, names_found


def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
    """Complete ``column_mapping`` for a text-classification model and dataset.

    The function loads ``split`` of the dataset ``d_id``/``config``, makes sure
    ``column_mapping["text"]`` names an existing column (falling back to the
    first string-typed feature), runs ``ppl`` once on the first row as a smoke
    test, and finally tries to align the model's ``id2label`` labels with the
    labels of the dataset.

    Note that ``column_mapping`` is modified in place.

    Args:
        column_mapping: Dict that may contain a ``"text"`` entry (dataset
            column name) and a ``"label"`` entry (dict whose keys are indices
            into the dataset labels and whose values are label names).
        ppl: Callable pipeline exposing ``ppl.model.config.id2label``; it is
            called with ``{"text": ...}`` and ``top_k=None``.
        d_id: Dataset identifier passed to ``datasets.load_dataset``.
        config: Dataset configuration name passed to ``datasets.load_dataset``.
        split: Name of the split to use.

    Returns:
        A 3-tuple ``(column_mapping, prediction_result, id2label_df)``:

        * ``(None, None, None)`` when the dataset exposes no ``features``.
        * ``(column_mapping, None, None)`` when no text column could be found
          or the test prediction failed.
        * ``(column_mapping, prediction_result, None)`` when the labels could
          not be aligned; if no ``"label"`` entry was provided, a template
          ``{model_id: None}`` is stored in ``column_mapping["label"]``.
        * ``(column_mapping, prediction_result, id2label_df)`` on success,
          where ``prediction_result`` maps ``"<label>(<id>)"`` to the score of
          the test prediction and ``id2label_df`` is a ``pandas.DataFrame``
          with the columns ``ID``, ``Labels`` and ``Labels in original model``.
    """
    # We assume dataset is ok here
    ds = datasets.load_dataset(d_id, config)[split]

    try:
        dataset_features = ds.features
    except AttributeError:
        # Dataset does not have features, need to provide everything
        return None, None, None

    # Check whether we need to infer the text input column
    infer_text_input_column = True
    if "text" in column_mapping:
        dataset_text_column = column_mapping["text"]
        if dataset_text_column in dataset_features:
            infer_text_input_column = False
        else:
            logging.warning(f"Provided {dataset_text_column} is not in Dataset columns")

    if infer_text_input_column:
        # Not every feature type exposes ``dtype`` (e.g. nested dict features),
        # so read it defensively instead of raising AttributeError.
        candidates = [
            name
            for name, feature in dataset_features.items()
            if getattr(feature, "dtype", None) == "string"
        ]
        if not candidates:
            # Not found a text feature
            return column_mapping, None, None
        logging.debug(f"Candidates are {candidates}")
        column_mapping["text"] = candidates[0]

    # Retrieve all labels
    id2label = ppl.model.config.id2label
    label2id = {v: k for k, v in id2label.items()}
    try:
        # Use the first item to test prediction; only that single row is read
        # instead of converting the whole dataset to a DataFrame.
        first_text = ds[0][column_mapping["text"]]
        results = ppl({"text": first_text}, top_k=None)
        prediction_result = {
            f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
        }
    except Exception:
        # Pipeline prediction failed, need to provide labels. This stays
        # best-effort, but the cause is logged rather than silently dropped.
        logging.warning("Test prediction on the first dataset row failed", exc_info=True)
        return column_mapping, None, None

    # Infer labels
    id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
    if "label" in column_mapping:
        provided_labels = column_mapping["label"]
        # ``dataset_labels`` is None when no suitable ClassLabel feature exists;
        # nothing can be validated against it in that case.
        if (
            dataset_labels is None
            or not isinstance(provided_labels, dict)
            or set(provided_labels.values()) != set(dataset_labels)
        ):
            logging.warning(f'Provided {column_mapping["label"]} does not match labels in Dataset')
            return column_mapping, prediction_result, None

        # Use the column mapping passed by user
        for i, model_label in provided_labels.items():
            id2label_mapping[model_label] = dataset_labels[int(i)]
    elif dataset_labels is None or None in id2label_mapping.values():
        # Could not pair every model label: hand back a template to fill in.
        column_mapping["label"] = {
            i: None for i in id2label.keys()
        }
        return column_mapping, prediction_result, None

    # Invert to look up the model label from a dataset label.
    id2label_mapping = {
        v: k for k, v in id2label_mapping.items()
    }
    id2label_df = pd.DataFrame({
        "ID": list(range(len(dataset_labels))),
        "Labels": dataset_labels,
        "Labels in original model": [
            f"{id2label_mapping[label]}({label2id[id2label_mapping[label]]})" for label in dataset_labels
        ],
    })
    if "label" not in column_mapping:
        # Column mapping should contain original model labels
        column_mapping["label"] = {
            str(i): id2label_mapping[label] for i, label in zip(id2label.keys(), dataset_labels)
        }

    return column_mapping, prediction_result, id2label_df