NLP_LSTM_team / models /generate_text.py
MossyHead's picture
Upload 2 files (#1)
23e5486 verified
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from safetensors import safe_open
# Функция для загрузки весов модели из файла safetensors
def load_model_weights(model, safetensors_path):
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
for key in f.keys():
if key in model.state_dict():
try:
model.state_dict()[key].copy_(f.get_tensor(key))
except RuntimeError as e:
print(f"Error copying key {key}: {e}")
return model
# Загрузка токенизатора GPT-2
tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
# Добавление специального токена для заполнения
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Загрузка модели GPT-2
model = GPT2LMHeadModel.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")
# Изменение размера токенов в модели после добавления специального токена
model.resize_token_embeddings(len(tokenizer))
# Загрузка весов из safetensors
model = load_model_weights(model, "models/model_lenin_zametki.safetensors")
# Streamlit приложение
def generate_text(prompt, length, num_generations, temperature, top_k, top_p):
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = []
for _ in range(num_generations):
output = model.generate(
inputs,
max_length=length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
num_return_sequences=1
)
text = tokenizer.decode(output[0], skip_special_tokens=True)
outputs.append(text)
return outputs