QA / main.py
import subprocess

import torch
from torch.utils.data import DataLoader
from transformers import AdamW, DistilBertForQuestionAnswering

from preprocess import Model, SquadDataset
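# Arrange the SQuAD-style training and validation data; both splits are read
# from the same "livecheckcontainer" source here.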
data = Model()
train_contexts, train_questions, train_answers = data.ArrangeData("livecheckcontainer")
val_contexts, val_questions, val_answers = data.ArrangeData("livecheckcontainer")
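# Quick sanity check on the arranged answers.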
print(train_answers)
train_answers, train_contexts = data.add_end_idx(train_answers, train_contexts)
val_answers, val_contexts = data.add_end_idx(val_answers, val_contexts)
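# Tokenize contexts/questions and convert the character-level answer spans
# into token-level start/end positions used as training targets.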
train_encodings, val_encodings = data.Tokenizer(train_contexts, train_questions, val_contexts, val_questions)
train_encodings = data.add_token_positions(train_encodings, train_answers)
val_encodings = data.add_token_positions(val_encodings, val_answers)
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)
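# Note: the validation dataset is built above, but this script does not run an
# evaluation loop over it.

# Load the pretrained DistilBERT QA model, move it to the GPU if one is
# available, and switch to training mode.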
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.train()
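# Batch the training set and optimize with AdamW (imported from transformers
# here; newer transformers releases deprecate this class in favour of
# torch.optim.AdamW).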
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
optim = AdamW(model.parameters(), lr=5e-5)
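# Fine-tune for two epochs; supplying start/end positions makes the model
# compute the QA loss, returned as the first element of its output.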
for epoch in range(2):
    print(epoch)
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        loss = outputs[0]
        loss.backward()
        optim.step()
print("Done")
model.eval()
model.save_pretrained("./")
data.tokenizer.save_pretrained("./")
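# The saved weights can later be reloaded for inference; a minimal sketch,
# assuming the files stay in this directory (question/context are placeholders):
#   from transformers import pipeline
#   qa = pipeline("question-answering", model="./", tokenizer="./")
#   print(qa(question="...", context="..."))

# Stage, commit, and push the updated model and tokenizer back to the repo
# (assumes this directory is a git checkout with credentials configured).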
subprocess.call(["git", "add","--all"])
subprocess.call(["git", "status"])
subprocess.call(["git", "commit", "-m", "First version of the your-model-name model and tokenizer."])
subprocess.call(["git", "push"])