import transformers import streamlit as st from transformers import GPT2LMHeadModel, GPT2Tokenizer import numpy as np from PIL import Image import torch import base64 import plotly.express as px df = px.data.iris() @st.cache_data def get_img_as_base64(file): with open(file, "rb") as f: data = f.read() return base64.b64encode(data).decode() page_bg_img = f""" """ st.markdown(page_bg_img, unsafe_allow_html=True) st.title(""" History Mystery """) # Добавление слайдера temp = st.slider("Градус дичи", 1.0, 20.0, 5.0) sen_quan = st.slider(" Длина сгенерированного отрывка", 20, 100, 5) # Загрузка модели и токенизатора # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2') # #Задаем класс модели (уже в streamlit/tg_bot) #@st.cache_resource(allow_output_mutation=True) def load_gpt(): model_GPT = GPT2LMHeadModel.from_pretrained( 'sberbank-ai/rugpt3small_based_on_gpt2', output_attentions=False, output_hidden_states=False, ) tokenizer_GPT = GPT2Tokenizer.from_pretrained( 'sberbank-ai/rugpt3small_based_on_gpt2', output_attentions=False, output_hidden_states=False, ) model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu'))) return model_GPT, tokenizer_GPT #model, tokenizer = load_gpt() # # Вешаем сохраненные веса на нашу модель # Функция для генерации текста def generate_text(model_GPT, tokenizer_GPT, prompt): # Преобразование входной строки в токены input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt') # Генерация текста output = model_GPT.generate(input_ids=input_ids, max_length=100, num_beams=5, do_sample=True, temperature=temp, top_k=60, top_p=0.6, no_repeat_ngram_size=4, num_return_sequences=sen_quan) # Декодирование сгенерированного текста generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True) return generated_text # Streamlit приложение def main(): model_GPT, tokenizer_GPT = load_gpt() st.write(""" # GPT-3 генерация текста """) # Ввод строки пользователем prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на Руси") # # Генерация текста по введенной строке # generated_text = generate_text(prompt) # Создание кнопки "Сгенерировать" generate_button = st.button("За работу!") # Обработка события нажатия кнопки if generate_button: # Вывод сгенерированного текста generated_text = generate_text(model_GPT, tokenizer_GPT, prompt) st.subheader("Продолжение:") st.write(generated_text) if __name__ == "__main__": main()