from preprocess import Model, SquadDataset
from transformers import DistilBertForQuestionAnswering
from torch.utils.data import DataLoader
from torch.optim import AdamW  # transformers.AdamW is deprecated; torch's optimizer is a drop-in replacement here
import torch
import subprocess

# Run the preprocessing + fine-tuning pipeline packaged inside preprocess.Model.
data = Model()
data.ModelExecution()

# The steps below spell out the same pipeline inline. They are left commented
# out because ModelExecution() already performs them end to end.
# train_contexts, train_questions, train_answers = data.ArrangeData("livecheckcontainer")
# val_contexts, val_questions, val_answers = data.ArrangeData("livecheckcontainer")
# print(train_answers)
# train_answers, train_contexts = data.add_end_idx(train_answers, train_contexts)
# val_answers, val_contexts = data.add_end_idx(val_answers, val_contexts)
# train_encodings, val_encodings = data.Tokenizer(train_contexts, train_questions, val_contexts, val_questions)
# train_encodings = data.add_token_positions(train_encodings, train_answers)
# val_encodings = data.add_token_positions(val_encodings, val_answers)
# train_dataset = SquadDataset(train_encodings)
# val_dataset = SquadDataset(val_encodings)

# Fine-tune DistilBERT for extractive question answering on the encoded data.
# model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# model.to(device)
# model.train()
# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# optim = AdamW(model.parameters(), lr=5e-5)
# for epoch in range(2):
#     print(epoch)
#     for batch in train_loader:
#         optim.zero_grad()
#         input_ids = batch['input_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
#         start_positions = batch['start_positions'].to(device)
#         end_positions = batch['end_positions'].to(device)
#         outputs = model(input_ids, attention_mask=attention_mask,
#                         start_positions=start_positions, end_positions=end_positions)
#         loss = outputs[0]  # first element of the output is the combined start/end loss
#         loss.backward()
#         optim.step()
# print("Done")

# Export the fine-tuned weights and tokenizer to the repository root.
# model.eval()
# model.save_pretrained("./")
# data.tokenizer.save_pretrained("./")

# Version the exported artifacts with git.
# subprocess.call(["git", "add", "--all"])
# subprocess.call(["git", "status"])
# subprocess.call(["git", "commit", "-m", "First version of the your-model-name model and tokenizer."])
# subprocess.call(["git", "push"])
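
# Minimal inference sketch for smoke-testing the export. Assumptions: the
# fine-tuned model and tokenizer were saved to ./ by the save_pretrained()
# calls above, and the context/question strings are hypothetical examples.
# Uncomment once the model has been exported.
# from transformers import AutoTokenizer
# qa_tokenizer = AutoTokenizer.from_pretrained("./")
# qa_model = DistilBertForQuestionAnswering.from_pretrained("./")
# qa_model.eval()
# context = "The model was fine-tuned on SQuAD-style question answering data."
# question = "What kind of data was the model fine-tuned on?"
# inputs = qa_tokenizer(question, context, return_tensors="pt", truncation=True)
# with torch.no_grad():
#     out = qa_model(**inputs)
# start = int(out.start_logits.argmax())  # most likely answer start token
# end = int(out.end_logits.argmax())      # most likely answer end token
# answer = qa_tokenizer.decode(inputs["input_ids"][0][start : end + 1])
# print(answer)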