File size: 3,308 Bytes
6c6aac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time

def generate_text(model, tokenizer, prompt, max_length, num_generations, temperature):
    generated_texts = []

    for _ in range(num_generations):
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        output = model.generate(
            input_ids,
            max_length=max_length,
            temperature=temperature,
            num_return_sequences=1
        )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        generated_texts.append(generated_text)

    return generated_texts

button_style = """

    <style>

    .center-align {

        display: flex;

        justify-content: center;

    

    </style>

"""

DEVICE = 'cpu' 

# Загрузка пользовательской модели и токенизатора (замените на свои пути и модель)
# model_path = "sberbank-ai/rugpt3small_based_on_gpt2"
# tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2"

# model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE)
# tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

st.markdown("""

    <style>

        section[data-testid="stSidebar"][aria-expanded="true"]{

            display: none;

        }

    </style>

    """, unsafe_allow_html=True)

st.write("## Text generator")
st.page_link("main.py", label="Home", icon='🏠')
st.markdown(
        """

        This streamlit-app can generate text using your prompt 

    """
)
# Ввод пользовательского prompt
prompt = st.text_area("Enter your prompt:")

# Параметры генерации
max_length = st.slider("Max length of generated text:", min_value=10, max_value=500, value=100, step=10)
num_generations = st.slider("Number of generations:", min_value=1, max_value=10, value=3, step=1)
temperature = st.slider("Temperature:", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
try:
    if st.button("Generate text"):
        start_time = time.time()
        generated_texts = generate_text(model, tokenizer, prompt, max_length, num_generations, temperature)
        end_time = time.time()

        st.subheader("Сгенерированный текст:")
        for i, text in enumerate(generated_texts, start=1):
            st.write(f"Генерация {i}:\n{text}")

        generation_time = end_time - start_time
        st.write(f"\nВремя генерации: {generation_time:.2f} секунд")

    st.markdown(button_style, unsafe_allow_html=True)  # Применяем стиль к кнопке
    st.markdown(
        """

        <style>

            div[data-baseweb="textarea"] {

                border: 2px solid #3498db;  /* Цвет границы */

                border-radius: 5px;  /* Закругленные углы */

                background-color: #ecf0f1;  /* Цвет фона */

                padding: 10px;  /* Поля вокруг текстового поля */

            }

        </style>

        """,
        unsafe_allow_html=True,
    )
except:
    st.write('Модель в разработке ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ( ノ ゚ー゚)ノ')