# Main library for WarBot

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import re
# Dependencies for the speller / punctuation model:
import os
import yaml
import torch
from torch import package
from textwrap3 import wrap
import replicate  # image generation
import chatGPT  # optional GPT-based post-processing (a paid API)


# Utility: map the token count of a text to the length-bucket code used in the model prompt
def get_length_param(text: str, tokenizer) -> str:
    tokens_count = len(tokenizer.encode(text))
    if tokens_count <= 15:
        len_param = '1'
    elif tokens_count <= 50:
        len_param = '2'
    elif tokens_count <= 256:
        len_param = '3'
    else:
        len_param = '-'
    return len_param
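
# A quick illustration of the bucketing above (a sketch; the exact bucket for a
# given string depends on the tokenizer's vocabulary):
#
#   tok = AutoTokenizer.from_pretrained("WarBot")
#   get_length_param("Привет!", tok)       # short text -> '1'
#   get_length_param("слово " * 100, tok)  # long text  -> '3' or '-'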

def remove_duplicates(S):
    # Drop English words, then remove repeated words, keeping first occurrences in order
    S = re.sub(r'[a-zA-Z]+', '', S)
    seen = set()
    result = []
    for word in S.split():
        if word not in seen:
            seen.add(word)
            result.append(word)
    return " ".join(result)
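
# Example behaviour (hypothetical input): English tokens are stripped first,
# then repeated words are dropped:
#
#   remove_duplicates("да да hello нет нет")  # -> "да нет"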

def removeSigns(S):
    # Truncate after the last '.' or '!' so the reply ends on a complete sentence
    last_index = max(S.rfind("."), S.rfind("!"))
    if last_index >= 0:
        S = S[:last_index+1]
    return S

def prepare_punct():
    # Prepare the punctuation model
    # Important on Unix builds of PyTorch: select the quantized engine explicitly
    torch.backends.quantized.engine = 'qnnpack'

    torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
                                   'latest_silero_models.yml',
                                   progress=False)

    with open('latest_silero_models.yml', 'r') as yaml_file:
        models = yaml.load(yaml_file, Loader=yaml.SafeLoader)
    model_conf = models.get('te_models').get('latest')

    # Prepare punctuation fix
    model_url = model_conf.get('package')

    model_dir = "downloaded_model"
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, os.path.basename(model_url))

    if not os.path.isfile(model_path):
        torch.hub.download_url_to_file(model_url,
                                       model_path,
                                       progress=True)

    imp = package.PackageImporter(model_path)
    model_punct = imp.load_pickle("te_model", "model")

    return model_punct
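
# Minimal usage sketch for the text-enhancement model loaded above
# (enhance_text is the same call used in get_response below):
#
#   model_punct = prepare_punct()
#   model_punct.enhance_text("привет как дела", lan='ru')  # -> punctuated, capitalized text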

def initialize():
    # Initializes all the models and settings
    """ Loading the dialogue model """
    fit_checkpoint = "WarBot"
    tokenizer = AutoTokenizer.from_pretrained(fit_checkpoint)
    model = AutoModelForCausalLM.from_pretrained(fit_checkpoint)
    model_punct = prepare_punct()

    """ Initialize the translation model """
    # The Replicate API token must be supplied via the REPLICATE_API_TOKEN
    # environment variable; do not hard-code credentials in source.
    translation_model_name = "Helsinki-NLP/opus-mt-ru-en"
    translation_tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
    translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)

    """ Initialize the image model """
    imageModel = replicate.models.get("stability-ai/stable-diffusion")
    imgModel_version = imageModel.versions.get("27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478")

    return (model, tokenizer, model_punct, translation_model, translation_tokenizer, imgModel_version)
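
# Typical startup sequence (a sketch; assumes the "WarBot" checkpoint directory
# is available locally and REPLICATE_API_TOKEN is set in the environment):
#
#   (model, tokenizer, model_punct,
#    translation_model, translation_tokenizer, imgModel_version) = initialize()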

def translate(text: str, translation_model, translation_tokenizer):
    # Translates from Russian to English (the opus-mt-ru-en model is ru -> en only)
    try:
        batch = translation_tokenizer([text], return_tensors="pt")
        generated_ids = translation_model.generate(**batch)
        translated = translation_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception:
        translated = ""
    return translated
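
# Usage sketch (models come from initialize(); the exact wording of the output
# depends on the opus-mt-ru-en checkpoint):
#
#   translate("Привет, мир!", translation_model, translation_tokenizer)  # -> "Hello, world!"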

def generate_image(prompt: str, imgModel_version):
    # Generates an image from a prompt and returns its URL ("" on failure)
    prompt = prompt.replace("?", "")
    try:
        output_url = imgModel_version.predict(prompt=prompt)[0]
    except Exception:
        output_url = ""

    return output_url
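
# Usage sketch: Stable Diffusion expects an English prompt, so a Russian quote
# would typically go through translate() first (an assumption based on the
# models wired together in initialize()):
#
#   prompt_en = translate(quote, translation_model, translation_tokenizer)
#   url = generate_image(prompt_en, imgModel_version)  # "" if the API call fails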

def split_string(string, n=256):
    # Split a string into chunks of at most n characters
    return [string[i:i + n] for i in range(0, len(string), n)]

def get_response(quote: str, model, tokenizer, model_punct, temperature=0.2):
    # Encode the input with the length-bucket prefix, append the eos_token,
    # and return a PyTorch tensor
    try:
        user_input_ids = tokenizer.encode(f"|0|{get_length_param(quote, tokenizer)}|"
                                          + quote + tokenizer.eos_token, return_tensors="pt")
        # Note: forcing the length parameter to '2' may give more stable replies
    except Exception:
        return "Exception in tokenization"

    chat_history_ids = user_input_ids  # placeholder: no dialogue history is kept yet

    tokens_count = len(tokenizer.encode(quote))
    if tokens_count < 15:
        no_repeat_ngram_size = 2
    else:
        no_repeat_ngram_size = 1

    try:
        output_id = model.generate(
            chat_history_ids,
            num_return_sequences=1,  # increase for more variants (then pick index [i])
            max_length=200,
            no_repeat_ngram_size=no_repeat_ngram_size,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=temperature,  # lower values are more deterministic
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    except Exception:
        return "Exception"  # Exception in generation

    response = tokenizer.decode(output_id[0], skip_special_tokens=True)
    response = removeSigns(response)
    response = response.split(quote)[-1]  # Drop the echoed quote
    response = re.sub(r'[^0-9А-Яа-яЁёa-zA-Z;., !()/\-+:?]', '',
                      response)  # Keep only alphanumeric characters and basic punctuation
    response = remove_duplicates(re.sub(r"\d{4,}", "", response))  # Drop runs of 4+ digits, then duplicate words
    response = re.sub(r'\.\.+', '', response)  # Remove "....." runs

    if len(response) > 200:
        # Enhance long replies in 200-character chunks
        resps = wrap(response, 200)
        for i in range(len(resps)):
            try:
                resps[i] = model_punct.enhance_text(resps[i], lan='ru')
            except Exception:
                return ""  # Exception in punctuation
        response = ''.join(resps)
    else:
        response = model_punct.enhance_text(response, lan='ru')

    # Final post-processing of the response
    response = re.sub(r'\[UNK\]', '', response)  # Remove literal "[UNK]" tokens
    response = re.sub(r',+', ',', response)    # Collapse repeated commas
    response = re.sub(r'-+', '-', response)    # Collapse repeated dashes
    response = re.sub(r'\.\?', '?', response)  # Fix the ".?" issue
    response = re.sub(r',\?', '?', response)   # Fix the ",?" issue
    response = re.sub(r'\.!', '!', response)   # Fix the ".!" issue
    response = re.sub(r'\.,', ',', response)   # Fix the ".," issue
    response = re.sub(r'\.\)', '.', response)  # Fix the ".)" issue
    response = response.replace('[]', '')      # Remove empty brackets

    # Experimental: refine the reply with chatGPT (a paid API)
    response = chatGPT.uGPT(response, quote)
    return response


if __name__ == '__main__':
    # Smoke test (kept disabled; requires the WarBot checkpoint and network access):
    """
    quote = "Здравствуй, Жопа, Новый Год, выходи на ёлку!"
    (model, tokenizer, model_punct,
     translation_model, translation_tokenizer, imgModel_version) = initialize()
    response = ""
    while not response:
        response = get_response(quote, model, tokenizer, model_punct, temperature=0.2)
    print(response)
    """