Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
import torch | |
def load_model_and_tokenizer(): | |
model_path = "models/model.pt" | |
tokenizer_path = "sberbank-ai/rugpt3small_based_on_gpt2" | |
# Загрузка токенизатора | |
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path) | |
# Загрузка архитектуры модели | |
model = GPT2LMHeadModel.from_pretrained(tokenizer_path) | |
# Загрузка весов из файла .pt | |
device = 'cpu' | |
model.load_state_dict(torch.load(model_path, map_location=device)) | |
# Перемещение модели на устройство | |
model.to(device) | |
model.eval() # Перевод модели в режим оценки | |
return model, tokenizer | |
model, tokenizer = load_model_and_tokenizer() | |
# Функция для генерации текста | |
def generate_conspiracy_theory(prompt, model, tokenizer, max_length=100, temperature=1.0, top_k=50, top_p=0.9): | |
device = next(model.parameters()).device # Определяем устройство (CPU или GPU) | |
# Токенизация входного запроса | |
inputs = tokenizer.encode(prompt, return_tensors="pt").to(device) | |
# Генерация текста | |
outputs = model.generate( | |
inputs, | |
max_length=max_length, | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
do_sample=True | |
) | |
# Декодирование результатов | |
theory = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return theory | |
st.title("GPT") | |
user_input = st.text_area("Введите промт", height=100) | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
length = st.slider("Длина последовательности", 50, 500, 100) | |
with col2: | |
num_sequences = st.slider("Число генераций", 1, 5, 1) | |
with col3: | |
temperature = st.slider("Температура", 0.1, 1.0, 0.7) | |
if st.button("Сгенерировать"): | |
if user_input.strip() == "": | |
st.error("Пожалуйста, введите текст.") | |
else: | |
with st.spinner("Генерация текста..."): | |
generated_text = generate_conspiracy_theory( | |
prompt=user_input, | |
model=model, | |
tokenizer=tokenizer, | |
max_length=length, | |
temperature=temperature | |
) | |
st.success("Готово!") | |
st.write(generated_text) | |
st.write(next(model.parameters()).device) |