# File size: 3,592 Bytes
# 21a5dba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
def get_length_param(text: str, tokenizer) -> str:
    """Classify *text* into one of four buckets by its encoded length.

    Parameters
    ----------
    text: str
        The text whose token count determines the bucket.

    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html

    Returns
    -------
    len_param: str
        One of four buckets:
        '1' for short (<= 15 tokens), '2' for medium (<= 50),
        '3' for long (<= 256) and '-' for all others.
    """
    n_tokens = len(tokenizer.encode(text))
    # Thresholds are checked in ascending order; first match wins.
    for limit, bucket in ((15, '1'), (50, '2'), (256, '3')):
        if n_tokens <= limit:
            return bucket
    return '-'


def get_user_param(text: dict, machine_name_in_chat: str) -> str:
    """Return '1' if the message was sent by the machine, '0' for a human.

    Parameters
    ----------
    text: Dict[..., 'from', ...]
        Message dict whose 'from' field holds the sender's name.

    machine_name_in_chat: str
        Name under which the machine (the side to be predicted) appears.
    """
    sent_by_machine = text['from'] == machine_name_in_chat
    return '1' if sent_by_machine else '0'  # '1' = machine, '0' = human


def build_text_file(data_json: dict, dest_path: str,
                    tokenizer, machine_name_in_chat='Кирилл Гельван'):
    """Create a text file for training in special format for ruDialoGPT-3.

    Each output line has the form ``|user|length|text<eos>`` where *user*
    is '1'/'0' (machine/human sender of the current message), *length* is
    the length bucket of the NEXT message (the reply to be generated), and
    *text* is the current message with newlines flattened to ". ".

    Parameters
    ----------
    data_json: dict
        Sequence of message dicts, each containing 'text' (message body)
        and 'from' (user who sent the message).

    dest_path: str
        Path of the file to write the formatted data to.

    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html

    machine_name_in_chat: str
        Sender name that is treated as the machine side.
    """
    lines = []
    for i in range(len(data_json) - 1):
        message, next_message = data_json[i], data_json[i + 1]
        # Skip pairs where either side is empty or not plain text
        # (presumably media attachments arrive as non-str payloads — TODO confirm).
        if not isinstance(message['text'], str) or message['text'] == '':
            continue
        if not isinstance(next_message['text'], str) or next_message['text'] == '':
            continue

        user = get_user_param(message, machine_name_in_chat=machine_name_in_chat)
        # Length bucket is computed from the reply, not the current message.
        length = get_length_param(next_message['text'], tokenizer)
        # Plain replace: the original re.sub(r"\n", ...) matched only the
        # literal newline, so str.replace is equivalent and needs no regex.
        message_text = message['text'].replace("\n", ". ")
        lines.append(f"|{user}|{length}|{message_text}{tokenizer.eos_token}\n")

    # Context manager closes the handle (the original leaked it); one join
    # avoids quadratic string concatenation; utf-8 keeps Cyrillic portable.
    with open(dest_path, 'w', encoding='utf-8') as f:
        f.write(''.join(lines))


def load_dataset(train_path, test_path, tokenizer):
    """Creates train and test PyTorch datasets and collate_fn using HuggingFace.

    Parameters
    ----------
    train_path: str
        String containing path to train data

    test_path: str
        String containing path to test data

    tokenizer: HuggingFace tokenizer
        Tokenizer used to compute the length of the text after encoding.
        For more info see https://huggingface.co/transformers/main_classes/tokenizer.html

    Returns
    -------
    tuple
        (train_dataset, test_dataset, data_collator)
    """
    def make_dataset(file_path):
        # Both splits share the same fixed block size of 256 tokens.
        return TextDataset(tokenizer=tokenizer,
                           file_path=file_path,
                           block_size=256)

    train_dataset = make_dataset(train_path)
    test_dataset = make_dataset(test_path)

    # mlm=False -> plain causal language modeling, no token masking.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                    mlm=False)
    return train_dataset, test_dataset, data_collator