File size: 2,353 Bytes
7d6f77f
 
 
 
518485c
da9ee26
7d6f77f
 
47d0a4a
7d6f77f
 
 
 
 
 
 
 
a16dba0
bb72c45
 
7d6f77f
b86439f
7d6f77f
 
 
 
 
 
b86439f
7d6f77f
 
 
 
 
19d554b
7d6f77f
da9ee26
 
 
 
 
7d6f77f
 
 
b86439f
7d6f77f
 
 
ded5ed0
da9ee26
 
4bd4566
 
 
 
 
 
 
 
 
ded5ed0
da9ee26
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import transformers
import torch
import tokenizers
import streamlit as st
import re
from PIL import Image


@st.cache(
    hash_funcs={
        # Each of these types is mapped to a hasher that always returns None.
        unhashable: (lambda _: None)
        for unhashable in (tokenizers.Tokenizer, tokenizers.AddedToken, re.Pattern)
    },
    allow_output_mutation=True,
    suppress_st_warning=True,
)
def get_model(model_name, model_path):
    """Build the GPT-2 tokenizer and LM-head model, then load local weights.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        model_path: Path of the checkpoint handed to ``torch.load``; its
            contents are applied with ``load_state_dict``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to
        evaluation mode.
    """
    # NOTE(review): st.cache may be deprecated in newer Streamlit releases in
    # favour of st.cache_resource - confirm against the pinned version.
    tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
    model = transformers.GPT2LMHeadModel.from_pretrained(model_name)

    # Map every tensor in the checkpoint onto the CPU while loading.
    cpu = torch.device('cpu')
    weights = torch.load(model_path, map_location=cpu)
    model.load_state_dict(weights)

    model.eval()
    return model, tokenizer


#@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=200):
    """Generate a continuation for ``text`` and return it as a decoded string.

    A newline is appended to ``text`` before it is encoded.  The encoded
    prompt is passed to ``model.generate`` with sampling enabled, and only the
    first returned sequence is decoded.

    Args:
        text: Prompt string.
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(token_ids)``.
        n_beams: Forwarded to ``generate`` as ``num_beams``.
        temperature: Forwarded to ``generate`` unchanged.
        top_p: Forwarded to ``generate`` unchanged.
        max_length: Added to the prompt's token count to form the
            ``max_length`` passed to ``generate``.

    Returns:
        The result of ``tokenizer.decode`` applied to the first generated
        sequence.
    """
    prompt = text + '\n'
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    length_of_prompt = len(input_ids[0])

    # Gradients are not needed for generation.
    with torch.no_grad():
        out = model.generate(
            input_ids,
            do_sample=True,
            num_beams=n_beams,
            temperature=temperature,
            top_p=top_p,
            # The limit passed on is the prompt length plus the requested
            # max_length, presumably because generate() counts prompt tokens
            # against it - confirm against the transformers docs.
            max_length=max_length + length_of_prompt,
        )

    # Only the first sequence is returned, so decode just that one instead of
    # decoding every sequence and discarding the rest.
    return tokenizer.decode(out[0])


# Load the base model and the fine-tuned checkpoint (get_model is wrapped in st.cache).
model, tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_30epochs_1bs.bin')

# Header picture, read from a file next to the app's working directory.
st.image(Image.open('korzh.jpg'), caption='NeuroKorzh')

# Empty markdown block between the picture and the input field.
st.markdown("\n")

text = st.text_input(label='Starting point for text generation', value='Что делать, Макс?')

# Generation only runs on the script pass in which the button reports a click.
if st.button('Go'):
    # NOTE: exceptions raised by predict() are not caught here.
    with st.spinner("Generation in progress"):
        result = predict(text, model, tokenizer)

    # Every newline is doubled before display, presumably so that st.write
    # renders each generated line as a separate paragraph.
    st.write(result.replace('\n', '\n\n'))