Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
import time | |
def generate_text(model, tokenizer, prompt, max_length, num_generations, temperature): | |
generated_texts = [] | |
for _ in range(num_generations): | |
input_ids = tokenizer.encode(prompt, return_tensors="pt") | |
output = model.generate( | |
input_ids, | |
max_length=max_length, | |
temperature=temperature, | |
num_return_sequences=1 | |
) | |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True) | |
generated_texts.append(generated_text) | |
return generated_texts | |
button_style = """ | |
<style> | |
.center-align { | |
display: flex; | |
justify-content: center; | |
</style> | |
""" | |
DEVICE = 'cpu' | |
tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2" | |
model = torch.load('srcs/gpt_weights.pth').to(DEVICE) | |
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path) | |
st.markdown(""" | |
<style> | |
section[data-testid="stSidebar"][aria-expanded="true"]{ | |
display: none; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
st.write("## Text generator") | |
st.page_link("app.py", label="Home", icon='🏠') | |
st.markdown( | |
""" | |
This streamlit-app can generate text using your prompt | |
""" | |
) | |
# Ввод пользовательского prompt | |
prompt = st.text_area("Enter your prompt:") | |
# Параметры генерации | |
max_length = st.slider("Max length of generated text:", min_value=10, max_value=500, value=100, step=10) | |
num_generations = st.slider("Number of generations:", min_value=1, max_value=10, value=3, step=1) | |
temperature = st.slider("Temperature:", min_value=0.1, max_value=2.0, value=1.0, step=0.1) | |
try: | |
if st.button("Generate text"): | |
start_time = time.time() | |
generated_texts = generate_text(model, tokenizer, prompt, max_length, num_generations, temperature) | |
end_time = time.time() | |
st.subheader("Сгенерированный текст:") | |
for i, text in enumerate(generated_texts, start=1): | |
st.write(f"Генерация {i}:\n{text}") | |
generation_time = end_time - start_time | |
st.write(f"\nВремя генерации: {generation_time:.2f} секунд") | |
st.markdown(button_style, unsafe_allow_html=True) # Применяем стиль к кнопке | |
st.markdown( | |
""" | |
<style> | |
div[data-baseweb="textarea"] { | |
border: 2px solid #3498db; /* Цвет границы */ | |
border-radius: 5px; /* Закругленные углы */ | |
background-color: #ecf0f1; /* Цвет фона */ | |
padding: 10px; /* Поля вокруг текстового поля */ | |
} | |
</style> | |
""", | |
unsafe_allow_html=True, | |
) | |
except: | |
st.write('Модель в разработке ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ') |