import json
import pandas as pd
import datasets
import numpy as np
import evaluate
import torch
from transformers import AutoModel, DistilBertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional

SEP_TOKEN = '[SEP]'
LABEL2ID = {'entailment': 2, 'neutral': 1, 'contradiction': 0}
ID2LABEL = {2: 'entailment', 1: 'neutral', 0: 'contradiction'}

def format_dataset(arr):
    text = [el['sentence1'] + SEP_TOKEN + el['sentence2'] for el in arr]
    label = [LABEL2ID[el['label']] for el in arr]
    new_df = pd.DataFrame({'text': text, 'label': label})
    return new_df.sample(frac=1, random_state=42).reset_index(drop=True)
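

# Illustrative sketch (an addition, not part of the original script): the
# record schema format_dataset expects. The sentences are made-up
# placeholders; only the keys mirror the fields accessed above.
def _format_dataset_example():
    records = [{'sentence1': 'A man is sleeping.',
                'sentence2': 'A person is awake.',
                'label': 'contradiction'}]
    # Returns a shuffled DataFrame with columns ['text', 'label'].
    return format_dataset(records)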

# Load a JSONL file (one JSON object per line) and split it into
# train/test Datasets.
def load_dataset(path):
    train_array = []
    with open(path) as f:
        for line in f:
            if line.strip():
                train_array.append(json.loads(line))
    df = format_dataset(train_array)
    # Hold out the first 512 shuffled rows for evaluation and train on the
    # rest; a small held-out set is enough here.
    df_train = df.iloc[512:, :]
    df_test = df.iloc[:512, :]
    # Quick sanity check on the split.
    print(df_train[:10])
    print(df_test[:10])

    factual_consistency_dataset = datasets.DatasetDict()
    # preserve_index=False keeps the pandas index from leaking into the
    # Arrow table as an extra column.
    factual_consistency_dataset["train"] = datasets.Dataset.from_pandas(
        df_train[["text", "label"]], preserve_index=False)
    factual_consistency_dataset["test"] = datasets.Dataset.from_pandas(
        df_test[["text", "label"]], preserve_index=False)

    return factual_consistency_dataset
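

# Illustrative sketch (an addition): a typical call site for load_dataset.
# 'train.jsonl' is a hypothetical path, not one from the original script.
def _load_dataset_example():
    splits = load_dataset('train.jsonl')
    # The first 512 shuffled rows form the test split.
    print(len(splits['train']), len(splits['test']))
    return splits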


class ConsistentSentenceClassifier(DistilBertForSequenceClassification):

    def __init__(self, freeze_bert=True):
        # Load the pretrained Japanese DistilBERT backbone and reuse its
        # config (with a 3-way NLI head) for the parent constructor.
        base_model = AutoModel.from_pretrained(
            'line-corporation/line-distilbert-base-japanese')

        config = base_model.config
        config.num_labels = 3
        config.id2label = ID2LABEL
        config.label2id = LABEL2ID
        config.problem_type = "single_label_classification"
        super().__init__(config=config)

        # Swap the randomly initialised backbone created by the parent
        # constructor for the pretrained one.
        self.distilbert = base_model

        if freeze_bert:
            # Train only the classification head.
            for param in self.distilbert.parameters():
                param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        # Mean-pool over the sequence instead of taking the [CLS] token as
        # the stock DistilBertForSequenceClassification does; note that
        # padding positions are included in the mean.
        pooled_output = torch.mean(hidden_state, dim=1)  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = torch.nn.ReLU()(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )



# Set up the evaluation metric.
def get_metrics():
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        # Trainer may pass the logits directly or a tuple of model outputs
        # with the logits first; unwrap before taking the argmax.
        logits = predictions[0] if isinstance(predictions, tuple) else predictions
        preds = logits.argmax(axis=-1)
        return metric.compute(predictions=preds, references=labels)

    return compute_metrics
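

# Illustrative sketch (an addition): wiring the pieces into a Trainer run.
# The path, output directory, and hyperparameters are placeholders, not
# values from the original script.
def _training_example():
    from transformers import AutoTokenizer, Trainer, TrainingArguments

    dataset = load_dataset('train.jsonl')  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained(
        'line-corporation/line-distilbert-base-japanese',
        trust_remote_code=True)

    def tokenize(batch):
        return tokenizer(batch['text'], truncation=True,
                         padding='max_length', max_length=128)

    dataset = dataset.map(tokenize, batched=True)

    trainer = Trainer(
        model=ConsistentSentenceClassifier(freeze_bert=True),
        args=TrainingArguments(output_dir='out',
                               per_device_train_batch_size=32,
                               num_train_epochs=1,
                               evaluation_strategy='epoch'),
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        compute_metrics=get_metrics(),
        tokenizer=tokenizer,
    )
    trainer.train()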