ChatStormAI / app.py
import streamlit as st
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
st.title('Text generation by GPT model')
st.subheader('This application shows the difference in text generation between a rugpt3small model trained on general documents and the same model trained on jokes')
# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the model tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# The baseline model is loaded as-is, without any additional fine-tuning
model_init = GPT2LMHeadModel.from_pretrained(
    'sberbank-ai/rugpt3small_based_on_gpt2',
    output_attentions=False,
    output_hidden_states=False,
)
model_init.to(device)
# The fine-tuned model uses the same architecture; its weights come from a checkpoint
model = GPT2LMHeadModel.from_pretrained(
    'sberbank-ai/rugpt3small_based_on_gpt2',
    output_attentions=False,
    output_hidden_states=False,
)
# Load the saved weights of the model fine-tuned on jokes
state_dict = torch.load('model.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.to(device)
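# Note: 'model.pt' is assumed to hold a state_dict checkpoint saved during fine-tuning
# on the jokes dataset, e.g. via torch.save(model.state_dict(), 'model.pt')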
user_text = st.text_input('Enter 1-4 words of the beginning of the text, and wait a minute', '')
# Base model without additional training:
# the entered text is tokenized and used as the prompt that the model will continue
prompt = tokenizer.encode(user_text, return_tensors='pt').to(device)
# out1 will contain the generation results, one row of token ids per sequence
with torch.inference_mode():
    out1 = model_init.generate(
        # tokenized input string
        input_ids=prompt,
        # maximum length of the generated sequence
        max_length=150,
        # beam search width
        num_beams=5,
        # sample from the distribution instead of taking the most probable token
        do_sample=True,
        # softmax temperature
        temperature=1.,
        # keep only the top-k most probable tokens at each step
        top_k=50,
        # keep only tokens whose cumulative probability is within top_p
        top_p=0.6,
        # try not to repeat n-grams of this size
        no_repeat_ngram_size=3,
        # how many sequences to return
        num_return_sequences=1,
    ).cpu().numpy()
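# generate() returns token ids; moving them to the CPU and converting to NumPy here
# keeps the decoding loops below independent of the device used for generation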
st.write('\n------------------\n')
st.subheader('Texts on the model trained by documents of all subjects:')
# out1 contains the results: decode each sequence and show it
for out_ in out1:
    # cut the text off at the last complete sentence
    st.write(tokenizer.decode(out_).rpartition('.')[0], '.')
    st.write('\n------------------\n')
# The model retrained on jokes, generated with its own sampling settings
with torch.inference_mode():
    out2 = model.generate(
        input_ids=prompt,
        max_length=150,
        # no beam search: sample token by token
        num_beams=1,
        do_sample=True,
        temperature=1.,
        # a smaller top_k keeps sampling close to the most probable continuations
        top_k=5,
        top_p=0.6,
        no_repeat_ngram_size=2,
        num_return_sequences=3,
    ).cpu().numpy()