import json

import pandas as pd
import datasets
import numpy as np
import evaluate
import torch

from transformers import AutoModel, DistilBertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

from typing import Optional

SEP_TOKEN = '[SEP]'
LABEL2ID = {'entailment': 2, 'neutral': 1, 'contradiction': 0}
ID2LABEL = {2: 'entailment', 1: 'neutral', 0: 'contradiction'}


def format_dataset(arr):
    """Join each sentence pair with [SEP], map string labels to ids, and shuffle."""
    text = [el['sentence1'] + SEP_TOKEN + el['sentence2'] for el in arr]
    label = [LABEL2ID[el['label']] for el in arr]
    new_df = pd.DataFrame({'text': text, 'label': label})
    # Fixed seed so the train/test split taken downstream is reproducible.
    return new_df.sample(frac=1, random_state=42).reset_index(drop=True)


def load_dataset(path):
    """Read JSONL sentence pairs and return a DatasetDict with a 512-example test split."""
    train_array = []
    with open(path) as f:
        for line in f:
            # Skip blank lines (a bare `if line:` is always true here because
            # every line read from the file keeps its trailing newline).
            if line.strip():
                train_array.append(json.loads(line))
    df = format_dataset(train_array)

    # Hold out the first 512 shuffled examples for evaluation.
    df_train = df.iloc[512:, :]
    df_test = df.iloc[:512, :]
    print(df_train[:10])
    print(df_test[:10])

    # preserve_index=False keeps the pandas index from leaking in as an extra column.
    factual_consistency_dataset = datasets.DatasetDict()
    factual_consistency_dataset["train"] = datasets.Dataset.from_pandas(
        df_train[["text", "label"]], preserve_index=False)
    factual_consistency_dataset["test"] = datasets.Dataset.from_pandas(
        df_test[["text", "label"]], preserve_index=False)

    return factual_consistency_dataset
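
# Hedged example of the JSONL line format the loader above assumes: the field
# names come from format_dataset, the label values from LABEL2ID; the sentence
# contents are placeholders.
#   {"sentence1": "...", "sentence2": "...", "label": "entailment"}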


class ConsistentSentenceClassifier(DistilBertForSequenceClassification):
    """Three-way NLI classifier (entailment / neutral / contradiction) on LINE DistilBERT."""

    def __init__(self, freeze_bert=True):
        # Load the pretrained Japanese DistilBERT encoder and reuse its config
        # for the classification head built by the parent class.
        base_model = AutoModel.from_pretrained(
            'line-corporation/line-distilbert-base-japanese', num_labels=3)

        config = base_model.config
        config.num_labels = 3
        config.id2label = ID2LABEL
        config.label2id = LABEL2ID
        config.problem_type = "single_label_classification"
        super().__init__(config=config)

        # The parent constructor created a freshly initialised encoder;
        # swap in the pretrained one loaded above.
        self.distilbert = base_model

        if not freeze_bert:
            return

        # Freeze the encoder so only the classification head is trained.
        for param in self.distilbert.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = distilbert_output[0]  # (batch, seq_len, dim)
        # Mean-pool over all token positions instead of taking the first token,
        # then run the standard DistilBERT classification head.
        pooled_output = torch.mean(hidden_state, dim=1)
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = torch.nn.ReLU()(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )


def get_metrics():
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        # The Trainer hands back either the raw logits array or a tuple whose
        # first element is the logits, depending on what the model returned.
        logits = predictions[0] if isinstance(predictions, tuple) else predictions
        preds = np.argmax(logits, axis=-1)
        return metric.compute(predictions=preds, references=labels)

    return compute_metrics
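

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): one plausible way to
# wire load_dataset, ConsistentSentenceClassifier and get_metrics together
# with the Hugging Face Trainer. The JSONL path, tokenizer settings and
# hyperparameters below are illustrative assumptions, not the author's setup.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer, Trainer, TrainingArguments

    # Assumption: the tokenizer of the same checkpoint is used; this repo ships
    # a custom Japanese tokenizer, so trust_remote_code may be required.
    tokenizer = AutoTokenizer.from_pretrained(
        'line-corporation/line-distilbert-base-japanese', trust_remote_code=True)

    dataset = load_dataset('nli_pairs.jsonl')  # hypothetical input path

    def tokenize(batch):
        # Pad to a fixed length so the default data collator can batch examples.
        return tokenizer(batch['text'], truncation=True,
                         padding='max_length', max_length=256)

    tokenized = dataset.map(tokenize, batched=True)

    trainer = Trainer(
        model=ConsistentSentenceClassifier(freeze_bert=True),
        args=TrainingArguments(output_dir='consistency-classifier',
                               num_train_epochs=3,
                               per_device_train_batch_size=32),
        train_dataset=tokenized['train'],
        eval_dataset=tokenized['test'],
        compute_metrics=get_metrics(),
    )
    trainer.train()
    print(trainer.evaluate())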