cpu / pages /Generator.py
DmitriyMineev's picture
dd
5ba7acb
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
@st.cache_resource
def load_model_and_tokenizer():
model_path = "models/model.pt"
tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2"
# Загрузка токенизатора
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
# Загрузка архитектуры модели
model = GPT2LMHeadModel.from_pretrained(tokenizer_path)
# Загрузка весов из файла .pt
device = 'cpu'
model.load_state_dict(torch.load(model_path, map_location=device))
# Перемещение модели на устройство
model.to(device)
model.eval() # Перевод модели в режим оценки
return model, tokenizer
model, tokenizer = load_model_and_tokenizer()
# Функция для генерации текста
def generate_conspiracy_theory(prompt, model, tokenizer, max_length=100, temperature=1.0, top_k=50, top_p=0.9):
device = next(model.parameters()).device # Определяем устройство (CPU или GPU)
# Токенизация входного запроса
inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
# Генерация текста
outputs = model.generate(
inputs,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=True
)
# Декодирование результатов
theory = tokenizer.decode(outputs[0], skip_special_tokens=True)
return theory
st.title("GPT")
user_input = st.text_area("Введите промт", height=100)
col1, col2, col3 = st.columns(3)
with col1:
length = st.slider("Длина последовательности", 50, 500, 100)
with col2:
num_sequences = st.slider("Число генераций", 1, 5, 1)
with col3:
temperature = st.slider("Температура", 0.1, 1.0, 0.7)
if st.button("Сгенерировать"):
if user_input.strip() == "":
st.error("Пожалуйста, введите текст.")
else:
with st.spinner("Генерация текста..."):
generated_text = generate_conspiracy_theory(
prompt=user_input,
model=model,
tokenizer=tokenizer,
max_length=length,
temperature=temperature
)
st.success("Готово!")
st.write(generated_text)
st.write(next(model.parameters()).device)