"""Helpers for turning an exported chat history into training data for ruDialoGPT-3."""
import re

from transformers import TextDataset, DataCollatorForLanguageModeling


def get_length_param(text: str, tokenizer) -> str:
"""Maps text to 1 of 4 buckets based on length after encoding.
Parameters
----------
text: str
The text to be given 1 of 4 length parameters.
tokenizer: HuggingFace tokenizer
Tokenizer that used to compute the length of the text after encoding.
For more info ee https://huggingface.co/transformers/main_classes/tokenizer.html
Returns
-------
len_param: str
One of four buckets:
'1' for short, '2' for medium, '3' for long texts and '-' for all others.
"""
tokens_count = len(tokenizer.encode(text))
if tokens_count <= 15:
len_param = '1'
elif tokens_count <= 50:
len_param = '2'
elif tokens_count <= 256:
len_param = '3'
else:
len_param = '-'
return len_param
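

# Illustrative usage (a sketch - any HuggingFace tokenizer works; the checkpoint
# name below is an example, not something this file depends on):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('Grossmend/rudialogpt3_medium_based_on_gpt2')
#   get_length_param('Привет! Как дела?', tokenizer)  # -> '1' (encodes to <= 15 tokens)

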
def get_user_param(text: dict, machine_name_in_chat: str) -> str:
"""Maps text by 1/0 for it to be the person or the machine in the dialogue
Parameters
----------
text: Dict[..., 'from', ...]
Dict containing field 'from' with the name of the user who sent the message
machine_name_in_chat: str
Str with the name of the machine - it will be predicted
"""
if text['from'] == machine_name_in_chat:
return '1' # machine
else:
return '0' # human
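

# Illustrative usage (hypothetical message dict; its shape mirrors one entry of
# an exported chat, with the default machine name used elsewhere in this file):
#
#   msg = {'from': 'Кирилл Гельван', 'text': 'Привет!'}
#   get_user_param(msg, machine_name_in_chat='Кирилл Гельван')  # -> '1' (machine)

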
def build_text_file(data_json: list, dest_path: str,
                    tokenizer, machine_name_in_chat='Кирилл Гельван'):
    """Creates a text file for training in the special format for ruDialoGPT-3.

    Parameters
    ----------
    data_json: list
        List of message dicts, each containing 'text' (the message) and
        'from' (the user who sent the message)
    dest_path: str
        Path of the file the data will be written to
    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html
    """
    new_data = ''
    for i in range(len(data_json) - 1):
        message, next_message = data_json[i], data_json[i + 1]
        # Skip pairs where either message is empty or not a plain string
        if message['text'] == '' or not isinstance(message['text'], str):
            continue
        if next_message['text'] == '' or not isinstance(next_message['text'], str):
            continue
        user = get_user_param(message, machine_name_in_chat=machine_name_in_chat)
        # The length bucket comes from the *following* message, i.e. the reply
        length = get_length_param(next_message['text'], tokenizer)
        # Keep each training sample on a single line
        message_text = re.sub(r"\n", ". ", message['text'])
        new_data += f"|{user}|{length}|{message_text}{tokenizer.eos_token}" + "\n"
    # Context manager guarantees the file is flushed and closed
    with open(dest_path, 'w', encoding='utf-8') as f:
        f.write(new_data)
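

# Sketch of the resulting file format (here '</s>' stands in for whatever
# tokenizer.eos_token actually is): one message per line as
# |speaker (0=human, 1=machine)|length bucket of the following reply|text, e.g.
#
#   |0|1|Привет! Как дела?</s>
#   |1|2|Привет. Все хорошо, спасибо!</s>

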
def load_dataset(train_path, test_path, tokenizer):
"""Creates train and test PyTorch datasets and collate_fn using HuggingFace.
Parameters
----------
train_path: str
String containing path to train data
test_path: str
String containing path to test data
tokenizer: HuggingFace tokenizer
Tokenizer that used to compute the length of the text after encoding.
For more info ee https://huggingface.co/transformers/main_classes/tokenizer.html
"""
    train_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=train_path,
        block_size=256)

    test_dataset = TextDataset(
        tokenizer=tokenizer,
        file_path=test_path,
        block_size=256)

    # mlm=False: plain causal language modeling, no masked-token objective
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )
    return train_dataset, test_dataset, data_collator
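

# Minimal fine-tuning sketch (an assumption, not part of this file): wire the
# returned datasets and collator into a HuggingFace Trainer. The checkpoint name
# and file paths are hypothetical placeholders.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
#
#   checkpoint = 'Grossmend/rudialogpt3_medium_based_on_gpt2'
#   tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#   model = AutoModelForCausalLM.from_pretrained(checkpoint)
#   train_dataset, test_dataset, data_collator = load_dataset('train.txt', 'test.txt', tokenizer)
#   trainer = Trainer(
#       model=model,
#       args=TrainingArguments(output_dir='checkpoints', num_train_epochs=1),
#       data_collator=data_collator,
#       train_dataset=train_dataset,
#       eval_dataset=test_dataset,
#   )
#   trainer.train()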