liwii's picture
Training in progress, epoch 1
2e36ddd verified
raw
history blame
4.31 kB
import json
import pandas as pd
import datasets
import numpy as np
import evaluate
import torch
from transformers import AutoModel, DistilBertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional
# Separator placed between the two sentences of a pair when they are joined
# into a single input string.
SEP_TOKEN = '[SEP]'

# Class name -> class index for the three inference labels.
LABEL2ID = dict(entailment=2, neutral=1, contradiction=0)

# Class index -> class name; the exact inverse of LABEL2ID (same entry order).
ID2LABEL = {index: name for name, index in LABEL2ID.items()}
def format_dataset(arr):
    """Turn raw sentence-pair records into a shuffled two-column DataFrame.

    Args:
        arr: Records indexable by the keys ``'sentence1'``, ``'sentence2'``
            and ``'label'``; the label value must be a key of ``LABEL2ID``.

    Returns:
        A ``pandas.DataFrame`` with a ``'text'`` column (both sentences
        joined by ``SEP_TOKEN``) and a ``'label'`` column (integer class
        index), rows shuffled with a fixed seed and re-indexed from 0.

    Raises:
        KeyError: If a record lacks one of the keys or carries a label that
            is not in ``LABEL2ID``.
    """
    # First pass: build the joined sentence pairs.
    texts = []
    for record in arr:
        texts.append(record['sentence1'] + SEP_TOKEN + record['sentence2'])

    # Second pass: map label names to their integer ids.
    label_ids = []
    for record in arr:
        label_ids.append(LABEL2ID[record['label']])

    frame = pd.DataFrame({'text': texts, 'label': label_ids})
    # frac=1 samples every row, i.e. a full shuffle; the seed keeps it
    # reproducible between runs.
    shuffled = frame.sample(frac=1, random_state=42)
    return shuffled.reset_index(drop=True)
# Load dataset
def load_dataset(path, test_size=512):
    """Read a JSON-lines file and build a train/test ``DatasetDict``.

    Every non-blank line of the file is parsed as one JSON object. The
    records are converted and shuffled by ``format_dataset``; the first
    ``test_size`` rows of the shuffled frame become the ``"test"`` split and
    all remaining rows the ``"train"`` split. The first ten rows of each
    split are printed.

    Args:
        path: Path of the JSON-lines file to read.
        test_size: Number of rows held out for the ``"test"`` split.
            Defaults to 512.

    Returns:
        A ``datasets`` ``DatasetDict`` with ``"train"`` and ``"test"``
        entries, each built from the ``"text"`` and ``"label"`` columns.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    train_array = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Lines keep their trailing newline, so a plain truthiness test
            # would never skip a blank line; strip before testing.
            if line.strip():
                train_array.append(json.loads(line))
    df = format_dataset(train_array)

    # Split dataset into train and test.
    df_train = df.iloc[test_size:, :]
    # We do not need much test data
    df_test = df.iloc[:test_size, :]
    print(df_train[:10])
    print(df_test[:10])

    factual_consistency_dataset = datasets.dataset_dict.DatasetDict()
    factual_consistency_dataset["train"] = datasets.dataset_dict.Dataset.from_pandas(
        df_train[["text", "label"]])
    factual_consistency_dataset["test"] = datasets.dataset_dict.Dataset.from_pandas(
        df_test[["text", "label"]])
    return factual_consistency_dataset
class ConsistentSentenceClassifier(DistilBertForSequenceClassification):
    """Three-way sentence-pair classifier on a pretrained DistilBERT encoder.

    The encoder weights are loaded from
    ``line-corporation/line-distilbert-base-japanese``. The classification
    head (``pre_classifier``, ``dropout``, ``classifier``) comes from the
    parent class and is fed the mean of the encoder's token states.
    """

    def __init__(self, freeze_bert=True):
        """Load the pretrained encoder and configure the label set.

        Args:
            freeze_bert: When true, gradients are disabled for every
                parameter of the encoder.
        """
        encoder = AutoModel.from_pretrained(
            'line-corporation/line-distilbert-base-japanese', num_labels=3)
        config = encoder.config
        super().__init__(config=config)

        # Label metadata is written onto the config object shared with the
        # pretrained encoder.
        config.num_labels = 3
        config.id2label = ID2LABEL
        config.label2id = LABEL2ID
        config.problem_type = "single_label_classification"

        # Swap the encoder created by the parent constructor for the
        # pretrained one.
        self.distilbert = encoder

        if freeze_bert:
            for weight in self.distilbert.parameters():
                weight.requires_grad = False

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode the inputs, mean-pool the token states and classify them.

        All arguments except ``labels`` are forwarded unchanged to the
        encoder. When ``labels`` is given, a cross-entropy loss over the
        logits is computed as well.

        Returns:
            A ``SequenceClassifierOutput`` when ``return_dict`` resolves to
            true; otherwise a tuple ``(logits, *extra_encoder_outputs)``,
            prefixed with the loss when one was computed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_out = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_states = encoder_out[0]  # (bs, seq_len, dim)
        # The average runs over every sequence position; attention_mask is
        # not applied to it.
        features = token_states.mean(dim=1)
        features = self.pre_classifier(features)  # (bs, dim)
        activation = torch.nn.ReLU()
        features = activation(features)  # (bs, dim)
        features = self.dropout(features)  # (bs, dim)
        logits = self.classifier(features)  # (bs, num_labels)

        loss = None
        if labels is not None:
            criterion = torch.nn.CrossEntropyLoss()
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            loss = criterion(flat_logits, flat_labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_out.hidden_states,
                attentions=encoder_out.attentions,
            )

        # Tuple form: logits first, then whatever else the encoder returned.
        outputs = (logits,) + encoder_out[1:]
        if loss is None:
            return outputs
        return (loss,) + outputs
# Set up evaluation metric
def get_metrics():
    """Build a ``compute_metrics`` callback that reports accuracy.

    Returns:
        A function that takes a ``(predictions, labels)`` pair, takes the
        argmax over axis 1 of the logits, and returns the result of the
        ``"accuracy"`` metric's ``compute`` for those predictions against
        ``labels``.
    """
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        # The logits may arrive bare, as one 2-D (samples, classes) array,
        # or as the first element of a tuple/stack of model outputs. Accept
        # both; indexing a bare 2-D array with [0] would pick a single row
        # and make argmax(axis=1) fail.
        is_sequence = isinstance(predictions, (tuple, list))
        if not is_sequence and np.ndim(predictions) == 2:
            logits = predictions
        else:
            logits = predictions[0]
        preds = logits.argmax(axis=1)
        return metric.compute(predictions=preds, references=labels)

    return compute_metrics