File size: 3,054 Bytes
bbc9b75
 
 
4235e44
bbc9b75
 
 
 
 
 
 
 
 
 
c9d6cdd
 
bbc9b75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfcdd52
 
db07551
dfcdd52
bbc9b75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c6aac8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
import torch

def generate_text(model, tokenizer, prompt, max_length, num_generations, temperature):
    generated_texts = []

    for _ in range(num_generations):
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        output = model.generate(
            input_ids,
            max_length=max_length,
            temperature=temperature,
            num_return_sequences=1,
            do_sample=True)
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        generated_texts.append(generated_text)

    return generated_texts

button_style = """
    <style>
    .center-align {
        display: flex;
        justify-content: center;
    
    </style>
"""

DEVICE = 'cpu' 


tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2"

model = torch.load('srcs/gpt_weights.pth', map_location=torch.device('cpu'))
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

st.markdown("""
    <style>
        section[data-testid="stSidebar"][aria-expanded="true"]{
            display: none;
        }
    </style>
    """, unsafe_allow_html=True)

st.write("## Text generator")
st.page_link("app.py", label="Home", icon='🏠')
st.markdown(
        """
        This streamlit-app can generate text using your prompt 
    """
)
# Ввод пользовательского prompt
prompt = st.text_area("Enter your prompt:")

# Параметры генерации
max_length = st.slider("Max length of generated text:", min_value=10, max_value=500, value=100, step=10)
num_generations = st.slider("Number of generations:", min_value=1, max_value=10, value=3, step=1)
temperature = st.slider("Temperature:", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
try:
    if st.button("Generate text"):
        start_time = time.time()
        generated_texts = generate_text(model, tokenizer, prompt, max_length, num_generations, temperature)
        end_time = time.time()

        st.subheader("Сгенерированный текст:")
        for i, text in enumerate(generated_texts, start=1):
            st.write(f"Генерация {i}:\n{text}")

        generation_time = end_time - start_time
        st.write(f"\nВремя генерации: {generation_time:.2f} секунд")

    st.markdown(button_style, unsafe_allow_html=True)  # Применяем стиль к кнопке
    st.markdown(
        """
        <style>
            div[data-baseweb="textarea"] {
                border: 2px solid #3498db;  /* Цвет границы */
                border-radius: 5px;  /* Закругленные углы */
                background-color: #ecf0f1;  /* Цвет фона */
                padding: 10px;  /* Поля вокруг текстового поля */
            }
        </style>
        """,
        unsafe_allow_html=True,
    )
except:
    st.write('Модель в разработке ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ')