WarBot / WarBot.py
kertser's picture
Upload WarBot.py
9045dc1
raw
history blame
5.19 kB
from transformers import AutoTokenizer ,AutoModelForCausalLM
import re
# Speller and punctuation:
import os
import yaml
import torch
from torch import package
# not very necessary
import textwrap
from textwrap3 import wrap
# util function to get expected len after tokenizing
def get_length_param(text: str, tokenizer) -> str:
    """Return the length-bucket label for *text* after tokenizing it.

    The label is '1' for at most 15 tokens, '2' for at most 50, '3' for at
    most 256, and '-' for anything longer.
    """
    n_tokens = len(tokenizer.encode(text))
    # Buckets are listed in ascending order; the first bound that fits wins.
    for upper_bound, label in ((15, '1'), (50, '2'), (256, '3')):
        if n_tokens <= upper_bound:
            return label
    return '-'
def remove_duplicates(S):
    """Strip Latin letters from *S* and drop repeated words.

    Runs of ASCII letters are removed first; the remaining text is split on
    whitespace and every word is kept only the first time it occurs, in the
    original order.

    Args:
        S: Input text.

    Returns:
        The surviving words joined by single spaces (an empty string when
        nothing is left).
    """
    words = re.sub(r'[a-zA-Z]+', '', S).split()  # Remove english
    # Compare whole words, not substrings of the text accumulated so far:
    # a substring test would drop short words such as "и" whenever any
    # earlier word merely contains them. dict.fromkeys keeps first-seen order.
    return " ".join(dict.fromkeys(words))
def removeSigns(S):
    """Cut *S* right after its last '.' or '!'.

    If neither character occurs, *S* is returned unchanged.
    """
    cut = -1
    for mark in (".", "!"):
        cut = max(cut, S.rfind(mark))
    # rfind yields -1 when the mark is absent, so a negative cut means
    # there is no sentence terminator to trim to.
    return S if cut < 0 else S[:cut + 1]
def prepare_punct():
    """Download (when missing) and load the punctuation model.

    Fetches the Silero model index, takes the package URL listed under
    ``te_models -> latest``, stores the package in ``downloaded_model/``
    (skipping the download if the file is already there) and loads it via
    ``torch.package``.

    Returns:
        The object stored as resource ``model`` of package ``te_model``
        inside the downloaded archive.
    """
    torch.hub.download_url_to_file(
        'https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
        'latest_silero_models.yml',
        progress=False)

    # Decode the index explicitly as UTF-8 so parsing does not depend on the
    # platform's default text encoding.
    with open('latest_silero_models.yml', 'r', encoding='utf-8') as yaml_file:
        models = yaml.load(yaml_file, Loader=yaml.SafeLoader)
    model_conf = models.get('te_models').get('latest')

    # Prepare punctuation fix
    model_url = model_conf.get('package')
    model_dir = "downloaded_model"
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, os.path.basename(model_url))

    # The package is only downloaded once; later calls reuse the local file.
    if not os.path.isfile(model_path):
        torch.hub.download_url_to_file(model_url,
                                       model_path,
                                       progress=True)

    # NOTE(security): load_pickle unpickles content of a downloaded archive;
    # only use this with a source you trust.
    imp = package.PackageImporter(model_path)
    model_punct = imp.load_pickle("te_model", "model")
    return model_punct
def initialize():
    """Load the dialogue model, its tokenizer and the punctuation model.

    Returns:
        A ``(model, tokenizer, punctuation_model)`` tuple.
    """
    # Just for the specific machine architecture
    torch.backends.quantized.engine = 'qnnpack'

    checkpoint_name = "WarBot"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
    model = AutoModelForCausalLM.from_pretrained(checkpoint_name)
    punct_model = prepare_punct()
    return model, tokenizer, punct_model
def split_string(string, n=256):
    """Split *string* into consecutive pieces of at most *n* characters.

    The last piece may be shorter; an empty string yields an empty list.
    """
    pieces = []
    for start in range(0, len(string), n):
        pieces.append(string[start:start + n])
    return pieces
def get_response(quote: str, model, tokenizer, model_punct):
    """Generate a filtered, punctuated reply to *quote*.

    Args:
        quote: The message to answer.
        model: Language model exposing ``generate``.
        tokenizer: Tokenizer used for ``encode``/``decode``; its EOS and PAD
            tokens are passed to generation.
        model_punct: Object exposing ``enhance_text(text, lan=...)``.

    Returns:
        The reply text after character filtering, duplicate-word removal and
        best-effort punctuation restoration.
    """
    # Encode "|0|<length bucket>|<quote><eos>" and return a PyTorch tensor.
    user_input_ids = tokenizer.encode(
        f"|0|{get_length_param(quote, tokenizer)}|" + quote + tokenizer.eos_token,
        return_tensors="pt")
    chat_history_ids = user_input_ids  # To be changed

    # Shorter quotes only forbid repeated bigrams; longer ones forbid any
    # repeated token.
    tokens_count = len(tokenizer.encode(quote))
    no_repeat_ngram_size = 2 if tokens_count < 15 else 1

    output_id = model.generate(
        chat_history_ids,
        num_return_sequences=1,  # use for more variants, but have to print [i]
        max_length=200,  # 512
        no_repeat_ngram_size=no_repeat_ngram_size,  # 3
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.4,  # was 0.6, 0 for greedy
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    response = tokenizer.decode(output_id[0], skip_special_tokens=True)
    response = removeSigns(response)
    # Keep only what follows the quote. str.split raises ValueError on an
    # empty separator, so an empty quote is skipped.
    if quote:
        response = response.split(quote)[-1]

    # Keep digits, Cyrillic/Latin letters and basic punctuation only. The
    # Latin range must be A-Z: "A-z" would also admit [ \ ] ^ _ and `.
    response = re.sub(r'[^0-9А-Яа-яЁёa-zA-Z;., !()/\-+:?]', '', response)
    # Remove runs of 4 or more digits, then repeated words.
    response = remove_duplicates(re.sub(r"\d{4,}", "", response))
    response = re.sub(r'\.\.+', '', response)  # Remove the "....." thing

    max_len = 170
    try:
        if len(response) > max_len:
            # Punctuate chunk by chunk. wrap() drops the whitespace at each
            # break, so the chunks are re-joined with a space to avoid
            # gluing neighbouring words together.
            chunks = [model_punct.enhance_text(chunk, lan='ru')
                      for chunk in wrap(response, max_len)]
            response = ' '.join(chunks)
        else:
            response = model_punct.enhance_text(response, lan='ru')
    except Exception:
        # Best effort: sometimes the string is getting too long for the
        # punctuation model; fall back to the unpunctuated text.
        pass

    # Remove the literal "[UNK]" marker; the brackets are escaped so the
    # pattern is not read as a character class matching U, N or K.
    response = re.sub(r'\[UNK\]', '', response)
    return response
#if __name__ == '__main__':
#model,tokenizer,model_punct = initialize()
#quote = "Это хорошо, но глядя на ролик, когда ефиопские толпы в Израиле громят машины и нападают на улице на израильтян - задумаешься, куда все движется"
#print('please wait...')
#response = wrap(get_response(quote,model,tokenizer,model_punct),60)
#for phrase in response:
# print(phrase)