from transformers import AutoTokenizer, AutoModelForCausalLM
import re
# Speller and punctuation model dependencies:
import os
import yaml
import torch
from torch import package
# Text wrapping, used to split long responses into chunks:
from textwrap import wrap

# Util function: map the tokenized length of `text` onto the length-bucket
# parameter used in the model input prefix ('1' short, '2' medium, '3' long,
# '-' for anything above 256 tokens)
def get_length_param(text: str, tokenizer) -> str:
    tokens_count = len(tokenizer.encode(text))
    if tokens_count <= 15:
        len_param = '1'
    elif tokens_count <= 50:
        len_param = '2'
    elif tokens_count <= 256:
        len_param = '3'
    else:
        len_param = '-'
    return len_param
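# Example (assuming a loaded `tokenizer`): a short greeting of ~5 tokens maps
# to bucket '1', a two-sentence quote of ~40 tokens to bucket '2'. The digit
# is spliced into the "|0|<len_param>|" prefix built in get_response() below.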

def remove_duplicates(s):
    """Strip Latin words, then keep only the first occurrence of each remaining word."""
    s = re.sub(r'[a-zA-Z]+', '', s)  # Remove English words
    seen = []
    for word in s.split():
        if word not in seen:  # exact word match, not substring containment
            seen.append(word)
    return ' '.join(seen)
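# Example: remove_duplicates("так так и было было ok") -> "так и было"
# (Latin "ok" is stripped first, then repeated words collapse to one)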

def remove_signs(s):
    """Truncate the text after the last sentence-ending '.' or '!'."""
    last_index = max(s.rfind("."), s.rfind("!"))
    if last_index >= 0:
        s = s[:last_index + 1]
    return s
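# Example: remove_signs("Привет. Как дела") -> "Привет."
# (note that '?' is not treated as a sentence terminator here)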

def prepare_punct():
    """Download (unless already cached) and load the Silero punctuation/case model."""
    torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
                                   'latest_silero_models.yml',
                                   progress=False)

    with open('latest_silero_models.yml', 'r') as yaml_file:
        models = yaml.load(yaml_file, Loader=yaml.SafeLoader)
    model_conf = models.get('te_models').get('latest')

    # URL of the packaged punctuation model
    model_url = model_conf.get('package')

    model_dir = "downloaded_model"
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, os.path.basename(model_url))

    if not os.path.isfile(model_path):
        torch.hub.download_url_to_file(model_url,
                                       model_path,
                                       progress=True)

    imp = package.PackageImporter(model_path)
    model_punct = imp.load_pickle("te_model", "model")

    return model_punct
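# The returned object exposes enhance_text(text, lan='ru'), which restores
# punctuation and capitalization; see its use in get_response() below.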

def initialize():
    """Load the fine-tuned dialogue model, its tokenizer, and the punctuation model."""
    torch.backends.quantized.engine = 'qnnpack'  # quantized backend for this specific (ARM) machine architecture
    fit_checkpoint = "WarBot"
    tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
    model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
    model_punct = prepare_punct()
    return model, tokenizer, model_punct

def split_string(string, n=256):
    """Split `string` into consecutive chunks of at most `n` characters."""
    return [string[i:i + n] for i in range(0, len(string), n)]
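# Example: split_string("abcdef", n=4) -> ["abcd", "ef"]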

def get_response(quote: str, model, tokenizer, model_punct):
    # Encode the input with the "|0|<len_param>|" prefix, append the eos_token
    # and return a PyTorch tensor
    user_input_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|"
                                      + quote + tokenizer.eos_token, return_tensors="pt")

    chat_history_ids = user_input_ids  # no dialogue history yet

    # Longer prompts get a stricter no-repeat constraint
    tokens_count = len(tokenizer.encode(quote))
    if tokens_count < 15:
        no_repeat_ngram_size = 2
    else:
        no_repeat_ngram_size = 1

    output_id = model.generate(
        chat_history_ids,
        num_return_sequences=1,  # increase for more variants (then print output_id[i])
        max_length=200,
        no_repeat_ngram_size=no_repeat_ngram_size,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.4,  # was 0.6; lower is less random
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    response = tokenizer.decode(output_id[0], skip_special_tokens=True)
    response = remove_signs(response)
    response = response.split(quote)[-1]  # Drop the echoed quote
    # Keep only alphanumeric (Cyrillic/Latin) characters and basic punctuation
    response = re.sub(r'[^0-9А-Яа-яЁёa-zA-Z;., !()/\-+:?]', '', response)
    # Drop digit runs of 4 or more, then remove duplicated words
    response = remove_duplicates(re.sub(r"\d{4,}", "", response))
    response = re.sub(r'\.\.+', '', response)  # Remove "....." sequences

    max_len = 170

    try:
        if len(response) > max_len:
            # The punctuation model struggles with long strings, so enhance in chunks
            resps = wrap(response, max_len)
            for i in range(len(resps)):
                resps[i] = model_punct.enhance_text(resps[i], lan='ru')
            response = ' '.join(resps)  # wrap() strips the boundary whitespace
        else:
            response = model_punct.enhance_text(response, lan='ru')
    except Exception:
        pass  # sometimes the string still gets too long for the punctuation model

    response = re.sub(r'\[UNK\]', '', response)  # Remove literal "[UNK]" tokens
    return response

if __name__ == '__main__':
    model, tokenizer, model_punct = initialize()
    # Sample prompt (Russian, matching the model's language):
    quote = "Это хорошо, но глядя на ролик, когда ефиопские толпы в Израиле громят машины и нападают на улице на израильтян - задумаешься, куда все движется"
    print('please wait...')
    response = wrap(get_response(quote, model, tokenizer, model_punct), 60)
    for phrase in response:
        print(phrase)