"""Streamlit front end for text generation with a fine-tuned ruGPT-3 small model.

The sidebar exposes the sampling parameters passed to ``model.generate``;
the generated text is rendered as HTML, one paragraph per output line.
"""

import base64
import html

import streamlit as st
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from model.funcs import execution_time

# Local directory with the fine-tuned weights, and the hub model whose
# tokenizer is loaded alongside them.
MODEL_PATH = "17/"
TOKENIZER_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"

# Stylesheet injected by ``set_background``; ``%s`` receives the base64 PNG.
# NOTE(review): the original stylesheet was lost from this file (the template
# was empty, which made the ``%`` formatting raise TypeError). This is a
# minimal reconstruction that uses the image as the page background.
BACKGROUND_CSS_TEMPLATE = """
<style>
.stApp {
    background-image: url("data:image/png;base64,%s");
    background-size: cover;
}
</style>
"""


def get_base64(file_path):
    """Return the raw bytes of *file_path* encoded as a base64 ``str``."""
    with open(file_path, "rb") as file:
        return base64.b64encode(file.read()).decode("utf-8")


def set_background(png_file):
    """Use the PNG image at *png_file* as the page background."""
    bin_str = get_base64(png_file)
    # The template must contain exactly one ``%s`` for the base64 payload.
    page_bg_img = BACKGROUND_CSS_TEMPLATE % bin_str
    st.markdown(page_bg_img, unsafe_allow_html=True)


set_background("text_generation.png")


@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned model once per server process.

    ``st.cache_resource`` (rather than ``st.cache_data``) is Streamlit's cache
    for unserializable global resources such as ML models: every rerun gets
    the same objects back instead of a pickled copy.

    Returns:
        tuple: ``(tokenizer, model)``, with the model already in eval mode.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(TOKENIZER_NAME)
    model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
    # Nothing in this app switches back to training mode, so once is enough.
    model.eval()
    return tokenizer, model


tokenizer, model = load_model()


@execution_time
def generate_text(
    prompt, num_beams=2, temperature=1.5, top_p=0.9, top_k=3, max_length=150
):
    """Continue *prompt* with the cached model and return the decoded text.

    Sampling is always enabled (``do_sample=True``); combined with
    ``num_beams > 1`` this is beam-search multinomial sampling.
    ``max_length`` is forwarded unchanged to ``model.generate``.

    Returns:
        str: the first returned sequence, decoded with the tokenizer's
        default ``decode`` settings.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            do_sample=True,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_length=max_length,
        )
    # Only the first sequence is displayed, so decode just that one.
    return tokenizer.decode(output_ids[0])


def render_paragraphs(text):
    """Return *text* as HTML with one ``<p>`` element per line.

    The model output is arbitrary text rendered with ``unsafe_allow_html``,
    so each line is HTML-escaped to keep it from injecting markup.
    """
    return " ".join(f"<p>{html.escape(line)}</p>" for line in text.split("\n"))


with st.sidebar:
    num_beams = st.slider("Number of Beams", min_value=1, max_value=5, value=2)
    temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.5)
    top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9)
    top_k = st.slider("Top-k", min_value=1, max_value=10, value=3)
    max_length = st.slider("Maximum Length", min_value=20, max_value=300, value=150)

# NOTE(review): this stylesheet's content was lost from the file; restore the
# page CSS here if the layout depended on it.
styled_text = """ """
st.markdown(styled_text, unsafe_allow_html=True)

prompt = st.text_input(
    "Ask a question",
    key="question_input",
    placeholder="Type here...",
    type="default",
    value="",
)
generate = st.button("Generate", key="generate_button")

if generate:
    if not prompt:
        # An empty prompt gets the app's stock reply instead of a model call.
        st.write("42")
    else:
        generated_text = generate_text(
            prompt, num_beams, temperature, top_p, top_k, max_length
        )
        st.markdown(render_paragraphs(generated_text), unsafe_allow_html=True)