SaviAnna commited on
Commit
ac6cd97
1 Parent(s): 3712cee

Delete pages/History.py

Browse files
Files changed (1) hide show
  1. pages/History.py +0 -72
pages/History.py DELETED
@@ -1,72 +0,0 @@
1
- import transformers
2
- import streamlit as st
3
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
- import numpy as np
5
- from PIL import Image
6
- import torch
7
-
8
- st.title("""
9
- History Mistery
10
- """)
11
- # image = Image.open('data-scins.jpeg')
12
-
13
- # st.image(image, caption='Current mood')
14
- # Добавление слайдера
15
- temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
16
- max_length = st.slider("Длина сгенерированного отрывка",40, 120, 40)
17
- # Загрузка модели и токенизатора
18
- # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
19
- # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
20
- # #Задаем класс модели (уже в streamlit/tg_bot)
21
- model = GPT2LMHeadModel.from_pretrained(
22
- 'sberbank-ai/rugpt3small_based_on_gpt2',
23
- output_attentions = False,
24
- output_hidden_states = False,
25
- )
26
- tokenizer = GPT2Tokenizer.from_pretrained(
27
- 'sberbank-ai/rugpt3small_based_on_gpt2',
28
- output_attentions = False,
29
- output_hidden_states = False,
30
- )
31
-
32
- # # Вешаем сохраненные веса на нашу модель
33
- model.load_state_dict(torch.load('model_history.pt',map_location=torch.device('cpu')))
34
- # Функция для генерации текста
35
- def generate_text(prompt):
36
- # Преобразование входной строки в токены
37
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
38
-
39
- # Генерация текста
40
- output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
41
- temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
42
- num_return_sequences=3)
43
-
44
- # Декодирование сгенерированного текста
45
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
46
-
47
- return generated_text
48
-
49
- # Streamlit приложение
50
- def main():
51
- st.write("""
52
- # GPT-3 генерация текста
53
- """)
54
-
55
- # Ввод строки пользователем
56
- prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")
57
-
58
- # # Генерация текста по введенной строке
59
- # generated_text = generate_text(prompt)
60
- # Создание кнопки "Сгенерировать"
61
- generate_button = st.button("За работу!")
62
- # Обработка события нажатия кнопки
63
- if generate_button:
64
- # Вывод сгенерированного текста
65
- generated_text = generate_text(prompt)
66
- st.subheader("Продолжение:")
67
- st.write(generated_text)
68
-
69
-
70
-
71
- if __name__ == "__main__":
72
- main()