QA / preprocess.py
Ateeb's picture
First version of the your-model-name model and tokenizer.
60f8cd4
raw
history blame
No virus
3.63 kB
import json
from os import close
from pathlib import Path
from azure.cosmos import CosmosClient, PartitionKey, exceptions
from transformers import DistilBertTokenizerFast
import torch
class Model:
    """Fetch SQuAD-style QA records from Azure Cosmos DB and prepare
    DistilBERT encodings with token-level answer span positions.

    Typical flow:
        contexts, questions, answers = m.ArrangeData(<container>)
        m.add_end_idx(answers, contexts)
        enc, _ = m.Tokenizer(contexts, questions, val_contexts, val_questions)
        m.add_token_positions(enc, answers)
    """

    def __init__(self) -> None:
        # NOTE(security): account endpoint and primary key are hard-coded.
        # Move them to environment variables or a secrets store before
        # sharing/deploying this file; a committed key must be rotated.
        self.endPoint = "https://productdevelopmentstorage.documents.azure.com:443/"
        self.primaryKey = "nVds9dPOkPuKu8RyWqigA1DIah4SVZtl1DIM0zDuRKd95an04QC0qv9TQIgrdtgluZo7Z0HXACFQgKgOQEAx1g=="
        self.client = CosmosClient(self.endPoint, self.primaryKey)
        # Set by Tokenizer(); add_token_positions reads it for
        # model_max_length, so Tokenizer() must be called first.
        self.tokenizer = None

    def GetData(self, type):
        """Read all items from the container named ``type`` in the
        'squadstorage' database.

        Note: ``max_item_count`` sets the per-page size of the iterator,
        not a cap on the total number of items returned.
        """
        database = self.client.get_database_client("squadstorage")
        container = database.get_container_client(type)
        item_list = list(container.read_all_items(max_item_count=10))
        return item_list

    def ArrangeData(self, type):
        """Split fetched items into parallel (contexts, questions, answers)
        lists. Each item is assumed to carry 'context', 'question' and
        'answers' keys (SQuAD layout) — TODO confirm against the container
        schema.
        """
        squad_dict = self.GetData(type)
        contexts = []
        questions = []
        answers = []
        for i in squad_dict:
            contexts.append(i["context"])
            questions.append(i["question"])
            answers.append(i["answers"])
        return contexts, questions, answers

    def add_end_idx(self, answers, contexts):
        """Compute the character-level 'answer_end' for each answer dict,
        tolerating gold labels that are off by one or two characters.

        Mutates the dicts in *answers* in place and returns
        (answers, contexts) for convenience.

        Bug fixes vs. the original:
        - the off-by-one/two branches used to replace the 'answer_start'
          *list* with a bare int, which later crashed
          add_token_positions' ``answers[i]['answer_start'][0]`` lookup;
          we now mutate the list element in place instead;
        - guards prevent ``start_idx - 1`` / ``start_idx - 2`` from going
          negative and silently wrapping around via negative slicing.

        If none of the three candidate windows matches, 'answer_end' is
        left unset (same as before) — downstream code will KeyError on
        such records, surfacing the bad label instead of hiding it.
        """
        for answer, context in zip(answers, contexts):
            gold_text = answer['text'][0]
            start_idx = answer['answer_start'][0]
            end_idx = start_idx + len(gold_text)
            if context[start_idx:end_idx] == gold_text:
                answer['answer_end'] = end_idx
            elif start_idx >= 1 and context[start_idx - 1:end_idx - 1] == gold_text:
                # Gold label is off by one character.
                answer['answer_start'][0] = start_idx - 1
                answer['answer_end'] = end_idx - 1
            elif start_idx >= 2 and context[start_idx - 2:end_idx - 2] == gold_text:
                # Gold label is off by two characters.
                answer['answer_start'][0] = start_idx - 2
                answer['answer_end'] = end_idx - 2
        return answers, contexts

    def Tokenizer(self, train_contexts, train_questions, val_contexts, val_questions):
        """Load the pretrained DistilBERT fast tokenizer (stored on self)
        and encode the train/validation (context, question) pairs with
        truncation and padding. Returns (train_encodings, val_encodings).
        """
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        train_encodings = self.tokenizer(train_contexts, train_questions, truncation=True, padding=True)
        val_encodings = self.tokenizer(val_contexts, val_questions, truncation=True, padding=True)
        return train_encodings, val_encodings

    def add_token_positions(self, encodings, answers):
        """Map character-level answer spans to token indices and store them
        on *encodings* as 'start_positions' / 'end_positions'.

        Requires add_end_idx() to have populated 'answer_end' and
        Tokenizer() to have set self.tokenizer.
        """
        start_positions = []
        end_positions = []
        for i in range(len(answers)):
            start_positions.append(encodings.char_to_token(i, answers[i]['answer_start'][0]))
            end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
            # char_to_token returns None when the answer span was lost to
            # truncation; fall back to model_max_length as a sentinel
            # position (same convention as the HF QA fine-tuning tutorial).
            if start_positions[-1] is None:
                start_positions[-1] = self.tokenizer.model_max_length
            if end_positions[-1] is None:
                end_positions[-1] = self.tokenizer.model_max_length
        encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
        return encodings
# Example usage from the original HF SQuAD tutorial this file is based on;
# the read_squad(json_path) helper was replaced here by Model.ArrangeData,
# which pulls the same (contexts, questions, answers) fields from Cosmos DB:
# train_contexts, train_questions, train_answers = read_squad('squad/train-v2.0.json')
# val_contexts, val_questions, val_answers = read_squad('squad/dev-v2.0.json')
class SquadDataset(torch.utils.data.Dataset):
    """Thin ``torch.utils.data.Dataset`` over tokenizer output.

    ``encodings`` is expected to behave like a transformers BatchEncoding:
    a mapping of field name -> per-example sequences that also exposes an
    ``input_ids`` attribute.
    """

    def __init__(self, encodings):
        # Store the encodings as-is; tensors are created lazily per item.
        self.encodings = encodings

    def __getitem__(self, idx):
        # Materialize example `idx` as a dict of tensors, one per field.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        return sample

    def __len__(self):
        # Number of encoded examples.
        return len(self.encodings.input_ids)