import torch
from azure.cosmos import CosmosClient
from transformers import DistilBertTokenizerFast


class Model:
    def __init__(self) -> None:
        # NOTE: credentials should normally come from configuration or an
        # environment variable rather than being hard-coded in source.
        self.endPoint = "https://productdevelopmentstorage.documents.azure.com:443/"
        self.primaryKey = "nVds9dPOkPuKu8RyWqigA1DIah4SVZtl1DIM0zDuRKd95an04QC0qv9TQIgrdtgluZo7Z0HXACFQgKgOQEAx1g=="
        self.client = CosmosClient(self.endPoint, self.primaryKey)
        self.tokenizer = None

    def GetData(self, container_name):
        """Read the items of one container from the squadstorage database."""
        database = self.client.get_database_client("squadstorage")
        container = database.get_container_client(container_name)
        item_list = list(container.read_all_items(max_item_count=10))
        return item_list

    def ArrangeData(self, container_name):
        """Split the raw items into parallel lists of contexts, questions and answers."""
        squad_dict = self.GetData(container_name)
        contexts = []
        questions = []
        answers = []
        for item in squad_dict:
            contexts.append(item["context"])
            questions.append(item["question"])
            answers.append(item["answers"])
        return contexts, questions, answers

    def add_end_idx(self, answers, contexts):
        """Derive character-level answer end indices, correcting gold labels
        that are off by one or two characters."""
        for answer, context in zip(answers, contexts):
            gold_text = answer['text'][0]
            start_idx = answer['answer_start'][0]
            end_idx = start_idx + len(gold_text)
            # Normalize 'answer_start' to a plain int in every branch so that
            # add_token_positions can index the answers consistently.
            if context[start_idx:end_idx] == gold_text:
                answer['answer_start'] = start_idx
                answer['answer_end'] = end_idx
            elif context[start_idx - 1:end_idx - 1] == gold_text:
                # The gold label is off by one character.
                answer['answer_start'] = start_idx - 1
                answer['answer_end'] = end_idx - 1
            elif context[start_idx - 2:end_idx - 2] == gold_text:
                # The gold label is off by two characters.
                answer['answer_start'] = start_idx - 2
                answer['answer_end'] = end_idx - 2
            else:
                # Fall back to the raw label so 'answer_end' is always set.
                answer['answer_start'] = start_idx
                answer['answer_end'] = end_idx
        return answers, contexts

    def Tokenizer(self, train_contexts, train_questions, val_contexts, val_questions):
        """Tokenize context/question pairs for the train and validation splits."""
        self.tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        train_encodings = self.tokenizer(train_contexts, train_questions, truncation=True, padding=True)
        val_encodings = self.tokenizer(val_contexts, val_questions, truncation=True, padding=True)
        return train_encodings, val_encodings

    def add_token_positions(self, encodings, answers):
        """Convert character-level answer spans into token-level start/end positions."""
        start_positions = []
        end_positions = []
        for i in range(len(answers)):
            start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
            end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
            # If the start position is None, the answer passage has been truncated.
            if start_positions[-1] is None:
                start_positions[-1] = self.tokenizer.model_max_length
            if end_positions[-1] is None:
                end_positions[-1] = self.tokenizer.model_max_length
        encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
        return encodings


# train_contexts, train_questions, train_answers = read_squad('squad/train-v2.0.json')
# val_contexts, val_questions, val_answers = read_squad('squad/dev-v2.0.json')


class SquadDataset(torch.utils.data.Dataset):
    """Wraps tokenizer encodings so a DataLoader can batch them as tensors."""

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)
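

# Usage sketch (illustrative, not part of the original module): a minimal
# wiring of the pipeline above into a DistilBERT QA fine-tuning loop. The
# container names "train" and "validation", the batch size, and the learning
# rate are assumptions; the training loop follows the standard Hugging Face
# SQuAD fine-tuning recipe rather than any training code defined here.
if __name__ == "__main__":
    from torch.optim import AdamW
    from torch.utils.data import DataLoader
    from transformers import DistilBertForQuestionAnswering

    model = Model()

    # Hypothetical container names; substitute whichever containers actually
    # hold the SQuAD items in Cosmos DB.
    train_contexts, train_questions, train_answers = model.ArrangeData("train")
    val_contexts, val_questions, val_answers = model.ArrangeData("validation")

    # Fix up character-level spans, tokenize, then attach token positions.
    train_answers, train_contexts = model.add_end_idx(train_answers, train_contexts)
    val_answers, val_contexts = model.add_end_idx(val_answers, val_contexts)
    train_encodings, val_encodings = model.Tokenizer(
        train_contexts, train_questions, val_contexts, val_questions)
    train_encodings = model.add_token_positions(train_encodings, train_answers)
    val_encodings = model.add_token_positions(val_encodings, val_answers)

    train_dataset = SquadDataset(train_encodings)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    # One illustrative training epoch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    qa_model = DistilBertForQuestionAnswering.from_pretrained(
        "distilbert-base-uncased").to(device)
    qa_model.train()
    optim = AdamW(qa_model.parameters(), lr=5e-5)
    for batch in train_loader:
        optim.zero_grad()
        outputs = qa_model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            start_positions=batch["start_positions"].to(device),
            end_positions=batch["end_positions"].to(device),
        )
        outputs.loss.backward()
        optim.step()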