Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 3,054 Bytes
			
			| bbc9b75 4235e44 bbc9b75 c9d6cdd bbc9b75 dfcdd52 db07551 dfcdd52 bbc9b75 6c6aac8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
import torch
def generate_text(model, tokenizer, prompt, max_length, num_generations, temperature):
    generated_texts = []
    for _ in range(num_generations):
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        output = model.generate(
            input_ids,
            max_length=max_length,
            temperature=temperature,
            num_return_sequences=1,
            do_sample=True)
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        generated_texts.append(generated_text)
    return generated_texts
button_style = """
    <style>
    .center-align {
        display: flex;
        justify-content: center;
    
    </style>
"""
DEVICE = 'cpu' 
tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2"
model = torch.load('srcs/gpt_weights.pth', map_location=torch.device('cpu'))
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
st.markdown("""
    <style>
        section[data-testid="stSidebar"][aria-expanded="true"]{
            display: none;
        }
    </style>
    """, unsafe_allow_html=True)
st.write("## Text generator")
st.page_link("app.py", label="Home", icon='🏠')
st.markdown(
        """
        This streamlit-app can generate text using your prompt 
    """
)
# Ввод пользовательского prompt
prompt = st.text_area("Enter your prompt:")
# Параметры генерации
max_length = st.slider("Max length of generated text:", min_value=10, max_value=500, value=100, step=10)
num_generations = st.slider("Number of generations:", min_value=1, max_value=10, value=3, step=1)
temperature = st.slider("Temperature:", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
try:
    if st.button("Generate text"):
        start_time = time.time()
        generated_texts = generate_text(model, tokenizer, prompt, max_length, num_generations, temperature)
        end_time = time.time()
        st.subheader("Сгенерированный текст:")
        for i, text in enumerate(generated_texts, start=1):
            st.write(f"Генерация {i}:\n{text}")
        generation_time = end_time - start_time
        st.write(f"\nВремя генерации: {generation_time:.2f} секунд")
    st.markdown(button_style, unsafe_allow_html=True)  # Применяем стиль к кнопке
    st.markdown(
        """
        <style>
            div[data-baseweb="textarea"] {
                border: 2px solid #3498db;  /* Цвет границы */
                border-radius: 5px;  /* Закругленные углы */
                background-color: #ecf0f1;  /* Цвет фона */
                padding: 10px;  /* Поля вокруг текстового поля */
            }
        </style>
        """,
        unsafe_allow_html=True,
    )
except:
    st.write('Модель в разработке ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ')   |