import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW

from accelerate import Accelerator
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    PreTrainedModel,
    PretrainedConfig,
    get_scheduler,
)
from transformers.modeling_outputs import SequenceClassifierOutput

from LUKE_pipe import generate
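
# Trains a reranker over beam-search QA predictions: LUKE_pipe.generate
# produces up to MAX_BEAM candidate answers (with logits) per question, a
# span-classification BERT scores each (question, candidate) pair, and a
# single linear layer fuses both signals into one score per candidate. The
# training target is the beam index of the gold SQuAD answer.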
MAX_BEAM = 10

# Allow TF32 matmuls/convolutions on Ampere+ GPUs for a speedup at slightly
# reduced precision.
tf32 = True
torch.backends.cuda.matmul.allow_tf32 = tf32
torch.backends.cudnn.allow_tf32 = tf32

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ClassifierAdapter(nn.Module):
    def __init__(self, l1=3):
        super().__init__()
        # Three input features per candidate: the generator's beam logit plus
        # BERT's two span-classification logits.
        self.linear1 = nn.Linear(l1, 1)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertForSequenceClassification.from_pretrained("botcon/right_span_bert")
        self.relu = nn.ReLU()

    def forward(self, questions, answers, logits):
        beam_size = len(answers[0])
        samples = len(questions)
        # Flatten to (samples * beam_size) question/answer pairs in sample-major
        # order, repeating each question beam_size times so it lines up with
        # its own beam's candidates.
        questions = [question for question in questions for _ in range(beam_size)]
        answers = [answer for beam in answers for answer in beam]
        # Pad/truncate every pair to BERT's maximum sequence length.
        inputs = self.tokenizer(
            questions,
            answers,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)
        bert_logits = self.bert(**inputs).logits
        bert_logits = bert_logits.reshape(samples, beam_size, 2)
        # Concatenate the generator logit with BERT's two logits: three
        # features per candidate for the linear scoring head.
        logits = torch.FloatTensor(logits).to(device).unsqueeze(-1)
        logits = torch.cat((logits, bert_logits), dim=-1)
        logits = self.relu(logits)
        out = torch.squeeze(self.linear1(logits), dim=-1)
        return out
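
# Expected shapes for ClassifierAdapter.forward (assuming `generate` yields
# MAX_BEAM candidates per sample):
#   questions: list[str] of length N
#   answers:   list[list[str]], N x beam_size
#   logits:    list[list[float]], N x beam_size (generator beam logits)
#   returns:   FloatTensor of shape (N, beam_size), one score per candidate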


class HuggingWrapper(PreTrainedModel):
    # The config class itself (not an instance), as from_pretrained expects.
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierAdapter()

    def forward(self, **kwargs):
        labels = kwargs.pop("labels")
        output = self.model(**kwargs)
        # Samples whose gold answer is missing from the beam are labelled
        # MAX_BEAM, which lies outside the beam classes and is ignored here.
        loss_fn = CrossEntropyLoss(ignore_index=MAX_BEAM)
        loss = loss_fn(output, labels)
        return SequenceClassifierOutput(logits=output, loss=loss)
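
# Minimal usage sketch (hypothetical data, shapes as documented above):
#   wrapper = HuggingWrapper(PretrainedConfig())
#   out = wrapper(questions=qs, answers=beams, logits=beam_logits, labels=gold_idx)
#   out.loss.backward()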


accelerator = Accelerator(mixed_precision="fp16")
model = HuggingWrapper.from_pretrained("botcon/special_bert").to(device)
optimizer = AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)

batch_size = 2
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
# Round up so the final partial batch still counts as an update.
num_updates = (len(raw_train) + batch_size - 1) // batch_size
num_epoch = 2
num_training_steps = num_updates * num_epoch
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))
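
# Training loop: for each batch, run beam-search generation without gradients,
# build the target beam index by matching the gold answer against the
# candidates, then update the reranker with cross-entropy over the beam scores.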

for epoch in range(num_epoch):
    start = 0
    end = batch_size
    steps = 0
    cumu_loss = 0
    model.train()

    while start < len(raw_train):
        optimizer.zero_grad()
        batch_data = raw_train.select(range(start, min(end, len(raw_train))))
        # The beam-search generator stays frozen; only the reranker is trained.
        with torch.no_grad():
            res = generate(batch_data)
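        # Each element of res is expected to carry the beam candidates and
        # their scores: res[i]["prediction_text"] -> list[str] and
        # res[i]["logits"] -> list[float], one entry per beam.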

        prediction = []
        predicted_logit = []
        labels = []
        for i in range(len(res)):
            x = res[i]
            ground_answer = batch_data["answers"][i]["text"][0]
            predicted_text = x["prediction_text"]
            # The label is the beam index of the gold answer, or MAX_BEAM
            # (ignored by the loss) when no candidate matches exactly.
            found = False
            for k in range(len(predicted_text)):
                if predicted_text[k] == ground_answer:
                    labels.append(k)
                    found = True
                    break
            if not found:
                labels.append(MAX_BEAM)
            prediction.append(predicted_text)
            predicted_logit.append(x["logits"])
        labels = torch.LongTensor(labels).to(device)

        classifier_out = model(
            questions=batch_data["question"],
            answers=prediction,
            logits=predicted_logit,
            labels=labels,
        )
        loss = classifier_out.loss
        # Skip updates on NaN losses (possible when every label in the batch
        # is the ignored MAX_BEAM index) so they cannot corrupt the weights.
        if not torch.isnan(loss).item():
            cumu_loss += loss.item()
            steps += 1
            accelerator.backward(loss)
            optimizer.step()
            if steps % 100 == 0:
                print("Avg loss over last 100 updates: {}".format(cumu_loss / 100))
                cumu_loss = 0
        lr_scheduler.step()
        progress_bar.update(1)
        start += batch_size
        end += batch_size


# Hub repo ids cannot contain spaces; use a dash instead.
model.push_to_hub("Adapter-Bert")