nlp_project / pages /task3.py
jeydipak's picture
num file
c9d6cdd verified
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
import torch
def generate_text(model, tokenizer, prompt, max_length, num_generations, temperature):
generated_texts = []
for _ in range(num_generations):
input_ids = tokenizer.encode(prompt, return_tensors="pt")
output = model.generate(
input_ids,
max_length=max_length,
temperature=temperature,
num_return_sequences=1,
do_sample=True)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
generated_texts.append(generated_text)
return generated_texts
button_style = """
<style>
.center-align {
display: flex;
justify-content: center;
</style>
"""
DEVICE = 'cpu'
tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2"
model = torch.load('srcs/gpt_weights.pth', map_location=torch.device('cpu'))
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
st.markdown("""
<style>
section[data-testid="stSidebar"][aria-expanded="true"]{
display: none;
}
</style>
""", unsafe_allow_html=True)
st.write("## Text generator")
st.page_link("app.py", label="Home", icon='🏠')
st.markdown(
"""
This streamlit-app can generate text using your prompt
"""
)
# Ввод пользовательского prompt
prompt = st.text_area("Enter your prompt:")
# Параметры генерации
max_length = st.slider("Max length of generated text:", min_value=10, max_value=500, value=100, step=10)
num_generations = st.slider("Number of generations:", min_value=1, max_value=10, value=3, step=1)
temperature = st.slider("Temperature:", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
try:
if st.button("Generate text"):
start_time = time.time()
generated_texts = generate_text(model, tokenizer, prompt, max_length, num_generations, temperature)
end_time = time.time()
st.subheader("Сгенерированный текст:")
for i, text in enumerate(generated_texts, start=1):
st.write(f"Генерация {i}:\n{text}")
generation_time = end_time - start_time
st.write(f"\nВремя генерации: {generation_time:.2f} секунд")
st.markdown(button_style, unsafe_allow_html=True) # Применяем стиль к кнопке
st.markdown(
"""
<style>
div[data-baseweb="textarea"] {
border: 2px solid #3498db; /* Цвет границы */
border-radius: 5px; /* Закругленные углы */
background-color: #ecf0f1; /* Цвет фона */
padding: 10px; /* Поля вокруг текстового поля */
}
</style>
""",
unsafe_allow_html=True,
)
except:
st.write('Модель в разработке ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ')