File size: 5,490 Bytes
05d6778
 
ddf7ac7
 
 
 
 
 
 
 
05d6778
ddf7ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05d6778
ddf7ac7
05d6778
 
 
 
 
ddf7ac7
 
 
 
 
 
 
 
 
05d6778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddf7ac7
 
 
 
 
 
 
 
 
05d6778
 
 
 
9045dc1
 
05d6778
 
 
 
ddf7ac7
05d6778
ddf7ac7
05d6778
 
 
 
 
 
 
 
ddf7ac7
 
05d6778
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Main library for WarBot

from transformers import AutoTokenizer ,AutoModelForCausalLM
import re
# Speller and punctuation:
import os
import yaml
import torch
from torch import package
# not very necessary
#import textwrap
from textwrap3 import wrap

# util function to get expected len after tokenizing
def get_length_param(text: str, tokenizer) -> str:
    """Map the token length of ``text`` to the model's length tag.

    The length is ``len(tokenizer.encode(text))``. The tag is '1' for at most
    15 tokens, '2' for at most 50, '3' for at most 256, and '-' for anything
    longer.
    """
    n_tokens = len(tokenizer.encode(text))
    for upper_bound, tag in ((15, '1'), (50, '2'), (256, '3')):
        if n_tokens <= upper_bound:
            return tag
    return '-'

def remove_duplicates(S):
    """Delete Latin-letter runs from ``S`` and keep only the first copy of each word.

    Every ``[a-zA-Z]+`` run is removed, the remainder is split on whitespace,
    and a word is kept only the first time it appears. Matching is exact and
    case-sensitive, so attached punctuation makes a different word. The kept
    words are joined with single spaces.
    """
    without_latin = re.sub(r'[a-zA-Z]+', '', S)  # Remove english
    seen = set()
    kept = []
    for word in without_latin.split():
        # Compare whole words. A substring test against the joined result
        # would also drop short words such as "и"/"в" found inside earlier words.
        if word not in seen:
            seen.add(word)
            kept.append(word)
    return " ".join(kept)

def removeSigns(S, terminators=".!?"):
    """Cut ``S`` right after its last sentence terminator.

    Text after the last character from ``terminators`` is treated as an
    unfinished sentence and dropped. By default the terminators are ``.``,
    ``!`` and ``?``; ``?`` was previously ignored, which cut off trailing
    questions. If ``S`` contains no terminator it is returned unchanged.
    """
    last_index = max((S.rfind(sign) for sign in terminators), default=-1)
    if last_index >= 0:
        S = S[:last_index + 1]
    return S

def prepare_punct():
    """Download (if needed) and load the Silero text-enhancement model.

    On every call this fetches the Silero ``models.yml`` catalogue into
    ``latest_silero_models.yml``. It reads the ``te_models -> latest -> package``
    URL from the catalogue and downloads that package into
    ``downloaded_model/`` unless a file with the same basename already exists.
    It then unpickles ``model`` from the ``te_model`` package.

    Returns:
        The loaded punctuation model; this file calls its ``enhance_text``.
    """
    catalogue_url = 'https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml'
    catalogue_file = 'latest_silero_models.yml'
    torch.hub.download_url_to_file(catalogue_url, catalogue_file, progress=False)

    with open(catalogue_file, 'r') as fh:
        catalogue = yaml.load(fh, Loader=yaml.SafeLoader)
    # Prepare punctuation fix
    package_url = catalogue.get('te_models').get('latest').get('package')

    cache_dir = "downloaded_model"
    os.makedirs(cache_dir, exist_ok=True)
    local_package = os.path.join(cache_dir, os.path.basename(package_url))
    if not os.path.isfile(local_package):
        torch.hub.download_url_to_file(package_url, local_package, progress=True)

    # NOTE(review): PackageImporter unpickles a package fetched over the
    # network; only point the catalogue at a trusted source.
    return package.PackageImporter(local_package).load_pickle("te_model", "model")

def initialize(fit_checkpoint="WarBot"):
    """Load the dialogue model, its tokenizer and the punctuation model.

    Args:
        fit_checkpoint: Name or path passed to ``from_pretrained`` for both
            the tokenizer and the causal LM. Defaults to ``"WarBot"``, the
            value that used to be hard-coded.

    Returns:
        ``(model, tokenizer, model_punct)``, in the argument order that
        ``get_response`` expects after ``quote``.
    """
    tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
    model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
    model_punct = prepare_punct()  # plain-ASCII name, same as elsewhere in the file
    return (model, tokenizer, model_punct)

def split_string(string,n=256):
    """Split ``string`` into consecutive slices of length ``n``.

    Only the last slice may be shorter. An empty input gives an empty list.
    """
    offsets = range(0, len(string), n)
    chunks = []
    for offset in offsets:
        chunks.append(string[offset:offset + n])
    return chunks

def get_response(quote:str,model,tokenizer,model_punct,temperature=0.2):
    """Generate a cleaned-up, punctuated reply to ``quote``.

    Steps:
    1. Build the prompt ``|0|<length tag>|`` + quote + EOS.
    2. Sample a continuation with ``model.generate``.
    3. Strip the echoed quote and unwanted characters.
    4. Restore punctuation with ``model_punct.enhance_text``.
    5. Apply final regex fixes.

    Args:
        quote: Text the bot replies to.
        model: Causal LM exposing ``generate``.
        tokenizer: Matching tokenizer exposing ``encode``/``decode``,
            ``eos_token``, ``eos_token_id`` and ``pad_token_id``.
        model_punct: Punctuation model exposing ``enhance_text(text, lan=...)``.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        The reply text, or ``""`` if tokenization, generation or punctuation
        raised an exception. The usage example under ``__main__`` retries
        until it gets a non-empty string.
    """
    # encode the input, add the eos_token and return a tensor in Pytorch
    try:
        user_input_ids = tokenizer.encode(
            f"|0|{get_length_param(quote, tokenizer)}|" + quote + tokenizer.eos_token,
            return_tensors="pt")
    except Exception:
        return ""  # Exception in tokenization

    chat_history_ids = user_input_ids  # To be changed

    # Short prompts (< 15 tokens) only forbid repeated bigrams; longer prompts
    # forbid repeating any single token.
    tokens_count = len(tokenizer.encode(quote))
    no_repeat_ngram_size = 2 if tokens_count < 15 else 1

    try:
        output_ids = model.generate(
            chat_history_ids,
            num_return_sequences=1,  # use for more variants, but have to print [i]
            max_length=200,
            no_repeat_ngram_size=no_repeat_ngram_size,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=temperature,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    except Exception:
        return ""  # Exception in generation

    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    response = _clean_generated_text(response, quote)

    try:
        response = _restore_punctuation(response, model_punct)
    except Exception:
        return ""  # Exception in punctuation (short and long texts alike)

    return _postprocess_response(response)


def _clean_generated_text(response, quote):
    """Strip the echoed prompt and unwanted characters from decoded model output."""
    response = removeSigns(response)  # drop the unfinished tail after the last sentence end
    response = response.split(quote)[-1]  # Remove the Quote
    # Keep only digits, Russian letters (incl. Ёё), Latin letters, spaces and
    # ; . , ! ( ) / - + : ?   (A-Z, not A-z, which would also admit [ \ ] ^ _ `)
    response = re.sub(r'[^0-9А-Яа-яЁёa-zA-Z;., !()/\-+:?]', '', response)
    # Remove numbers with 4 or more digits, then Latin words and repeated words
    response = remove_duplicates(re.sub(r"\d{4,}", "", response))
    response = re.sub(r'\.\.+', '', response)  # Remove the "....." thing
    return response


def _restore_punctuation(response, model_punct, chunk_len=200):
    """Run the punctuation model, in pieces of at most ``chunk_len`` chars for long text."""
    if len(response) <= chunk_len:
        return model_punct.enhance_text(response, lan='ru')
    # wrap() splits on whitespace and drops it at the cut points, so the
    # enhanced chunks must be re-joined with a space, not glued together.
    return ' '.join(model_punct.enhance_text(chunk, lan='ru')
                    for chunk in wrap(response, chunk_len))


def _postprocess_response(response):
    """Apply the final regex clean-ups to the punctuated reply."""
    response = re.sub(r'\[UNK\]', '', response)  # Remove the literal [UNK] token
    response = re.sub(r',+', ',', response)  # Replace multi-commas with single one
    response = re.sub(r'-+', '-', response)  # Replace multi-dashes with single one
    response = re.sub(r'\.\?', '?', response)  # Fix the .? issue
    response = re.sub(r'\.!', '!', response)  # Fix the .! issue
    response = re.sub(r'\.,', ',', response)  # Fix the ., issue
    response = re.sub(r'\.\)', '.', response)  # Fix the .) issue
    response = response.replace('[]', '')  # Fix the [] issue
    return response

if __name__ == '__main__':
    # NOTE: the usage example below is wrapped in a string literal, so running
    # this module as a script currently does nothing.
    # The example loads the models with initialize(), then calls get_response()
    # in a loop until it returns a non-empty reply. get_response returns ""
    # on failure.
    """
    quote = "Здравствуй, Жопа, Новый Год, выходи на ёлку!"
    model, tokenizer, model_punct = initialize()
    response = ""
    while not response:
        response = get_response(quote, model, tokenizer, model_punct,temperature=0.2)
    print(response)
    """