import re

from transformers import TextDataset, DataCollatorForLanguageModeling


def get_length_param(text: str, tokenizer) -> str:
    """Maps text to 1 of 4 buckets based on length after encoding.

    Parameters
    ----------
    text: str
        The text to be given 1 of 4 length parameters.
    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html

    Returns
    -------
    len_param: str
        One of four buckets: '1' for short, '2' for medium, '3' for long texts
        and '-' for all others.
    """
    tokens_count = len(tokenizer.encode(text))
    if tokens_count <= 15:
        len_param = '1'
    elif tokens_count <= 50:
        len_param = '2'
    elif tokens_count <= 256:
        len_param = '3'
    else:
        len_param = '-'
    return len_param


def get_user_param(text: dict, machine_name_in_chat: str) -> str:
    """Maps a message to '1' if the machine sent it, '0' if a human did.

    Parameters
    ----------
    text: Dict[..., 'from', ...]
        Dict containing field 'from' with the name of the user who sent the message
    machine_name_in_chat: str
        Name of the machine in the chat - its replies are what the model will predict

    Returns
    -------
    str
        '1' for the machine, '0' for a human.
    """
    if text['from'] == machine_name_in_chat:
        return '1'  # machine
    else:
        return '0'  # human


def build_text_file(data_json: list, dest_path: str, tokenizer, machine_name_in_chat='Кирилл Гельван'):
    """Creates a text file for training in the special format for ruDialoGPT-3.

    Each output line has the form `|{user}|{length}|{message}{eos}`, where
    `user` marks who sent the message and `length` is the length bucket of the
    message that follows it in the dialogue.

    Parameters
    ----------
    data_json: list
        List of message dicts, each containing 'text' (the message) and
        'from' (the user who sent it)
    dest_path: str
        Path of the file to write the data to
    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html
    """
    new_data = ''
    for i in range(len(data_json) - 1):
        message, next_message = data_json[i], data_json[i + 1]
        # Skip empty and non-string messages (stickers, photos, service messages, ...)
        if message['text'] == '' or not isinstance(message['text'], str):
            continue
        if next_message['text'] == '' or not isinstance(next_message['text'], str):
            continue
        user = get_user_param(message, machine_name_in_chat=machine_name_in_chat)
        length = get_length_param(next_message['text'], tokenizer)
        # Newlines inside a message would break the one-message-per-line format
        message_text = re.sub(r"\n", ". ", message['text'])
        new_data += f"|{user}|{length}|{message_text}{tokenizer.eos_token}" + "\n"

    with open(dest_path, 'w') as f:
        f.write(new_data)


def load_dataset(train_path, test_path, tokenizer):
    """Creates train and test PyTorch datasets and a collate_fn using HuggingFace.

    Parameters
    ----------
    train_path: str
        Path to the train data
    test_path: str
        Path to the test data
    tokenizer: HuggingFace tokenizer
        Tokenizer used to encode the texts.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html
    """
    train_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=train_path,
        block_size=256)

    test_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=test_path,
        block_size=256)

    # mlm=False: plain causal language modeling, no masked-token objective
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )
    return train_dataset, test_dataset, data_collator
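
# Example usage: a minimal sketch of how the helpers above fit together, not
# part of the original module. It assumes `chat.json` is a Telegram Desktop
# chat export whose messages live under the 'messages' key; the checkpoint
# name, file paths, and 90/10 split are illustrative placeholders.
if __name__ == '__main__':
    import json

    from transformers import AutoTokenizer

    # Assumed public ruDialoGPT-3 checkpoint; substitute your own if needed
    tokenizer = AutoTokenizer.from_pretrained('Grossmend/rudialogpt3_medium_based_on_gpt2')

    with open('chat.json', encoding='utf-8') as f:
        messages = json.load(f)['messages']  # assumed Telegram export layout

    # Hold out the tail of the dialogue for validation
    split = int(len(messages) * 0.9)
    build_text_file(messages[:split], 'train.txt', tokenizer)
    build_text_file(messages[split:], 'test.txt', tokenizer)

    train_dataset, test_dataset, data_collator = load_dataset(
        'train.txt', 'test.txt', tokenizer)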