A1ex1 commited on
Commit
da65de1
1 Parent(s): 1b83f09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -34
app.py CHANGED
@@ -4,6 +4,9 @@ import torch
4
 
5
  st.title('Генерация текста GPT-моделью')
6
  st.subheader('Это приложение показывает разницу в генерации текста моделью rugpt3small, обученной на документах общей тематики и этой же моделью, дообученной на анекдотах')
 
 
 
7
  # Загружаем токенайзер модели
8
  from transformers import GPT2Tokenizer
9
  tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
@@ -16,17 +19,17 @@ model_init = GPT2LMHeadModel.from_pretrained(
16
  output_attentions = False,
17
  output_hidden_states = False,
18
  )
 
 
 
 
 
 
 
19
 
20
- # Это обученная модель, в нее загружаем веса
21
- model = GPT2LMHeadModel.from_pretrained(
22
- 'sberbank-ai/rugpt3small_based_on_gpt2',
23
- output_attentions = False,
24
- output_hidden_states = False,
25
- )
26
-
27
- m = torch.load('model.pt')
28
- model.load_state_dict(m)
29
-
30
 
31
  str = st.text_input('Введите 1-4 слова начала текста, и подождите минутку', 'Мужик спрашивает у официанта')
32
 
@@ -34,7 +37,7 @@ str = st.text_input('Введите 1-4 слова начала текста, и
34
  # prompt – строка, которую примет на вход и продолжит модель
35
 
36
  # токенизируем строку
37
- prompt = tokenizer.encode(str, return_tensors='pt')
38
 
39
  # out будет содержать результаты генерации в виде списка
40
  out1 = model_init.generate(
@@ -56,7 +59,7 @@ out1 = model_init.generate(
56
  no_repeat_ngram_size=3,
57
  # сколько вернуть генераций
58
  num_return_sequences=3,
59
- ).numpy() #).cpu().numpy()
60
 
61
  st.write('\n------------------\n')
62
  st.subheader('Тексты на модели, обученной документами всех тематик:')
@@ -70,26 +73,26 @@ for out_ in out1:
70
  # print(tokenizer.decode(out_))
71
 
72
 
73
- # дообученная модель
74
- with torch.inference_mode():
75
- # prompt = 'Мужик спрашивает официанта'
76
- # prompt = tokenizer.encode(str, return_tensors='pt')
77
- out2 = model.generate(
78
- input_ids=prompt,
79
- max_length=150,
80
- num_beams=1,
81
- do_sample=True,
82
- temperature=1.,
83
- top_k=5,
84
- top_p=0.6,
85
- no_repeat_ngram_size=2,
86
- num_return_sequences=3,
87
- ).numpy() #).cpu().numpy()
88
 
89
- st.subheader('Тексты на модели, обученной документами всех тематик и дообученной анекдотами:')
90
- n = 0
91
- for out_ in out2:
92
- n += 1
93
- st.write(tokenizer.decode(out_).rpartition('.')[0],'.')
94
- # print(textwrap.fill(tokenizer.decode(out_), 100), end='\n------------------\n')
95
- st.write('\n------------------\n')
 
4
 
5
  st.title('Генерация текста GPT-моделью')
6
  st.subheader('Это приложение показывает разницу в генерации текста моделью rugpt3small, обученной на документах общей тематики и этой же моделью, дообученной на анекдотах')
7
+
8
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
+
10
  # Загружаем токенайзер модели
11
  from transformers import GPT2Tokenizer
12
  tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
 
19
  output_attentions = False,
20
  output_hidden_states = False,
21
  )
22
+ model_init.to(device);
23
+ # # Это обученная модель, в нее загружаем веса
24
+ # model = GPT2LMHeadModel.from_pretrained(
25
+ # 'sberbank-ai/rugpt3small_based_on_gpt2',
26
+ # output_attentions = False,
27
+ # output_hidden_states = False,
28
+ # )
29
 
30
+ # m = torch.load('model.pt')
31
+ # model.load_state_dict(m)
32
+ # model.to(device);
 
 
 
 
 
 
 
33
 
34
  str = st.text_input('Введите 1-4 слова начала текста, и подождите минутку', 'Мужик спрашивает у официанта')
35
 
 
37
  # prompt – строка, которую примет на вход и продолжит модель
38
 
39
  # токенизируем строку
40
+ prompt = tokenizer.encode(str, return_tensors='pt').to(device)
41
 
42
  # out будет содержать результаты генерации в виде списка
43
  out1 = model_init.generate(
 
59
  no_repeat_ngram_size=3,
60
  # сколько вернуть генераций
61
  num_return_sequences=3,
62
+ ).cpu().numpy() #).numpy()
63
 
64
  st.write('\n------------------\n')
65
  st.subheader('Тексты на модели, обученной документами всех тематик:')
 
73
  # print(tokenizer.decode(out_))
74
 
75
 
76
+ # # дообученная модель
77
+ # with torch.inference_mode():
78
+ # # prompt = 'Мужик спрашивает официанта'
79
+ # # prompt = tokenizer.encode(str, return_tensors='pt')
80
+ # out2 = model.generate(
81
+ # input_ids=prompt,
82
+ # max_length=150,
83
+ # num_beams=1,
84
+ # do_sample=True,
85
+ # temperature=1.,
86
+ # top_k=5,
87
+ # top_p=0.6,
88
+ # no_repeat_ngram_size=2,
89
+ # num_return_sequences=3,
90
+ # ).numpy() #).cpu().numpy()
91
 
92
+ # st.subheader('Тексты на модели, обученной документами всех тематик и дообученной анекдотами:')
93
+ # n = 0
94
+ # for out_ in out2:
95
+ # n += 1
96
+ # st.write(tokenizer.decode(out_).rpartition('.')[0],'.')
97
+ # # print(textwrap.fill(tokenizer.decode(out_), 100), end='\n------------------\n')
98
+ # st.write('\n------------------\n')