A1ex1 commited on
Commit
1b83f09
·
1 Parent(s): 80ec54a

Add application file

Browse files
Files changed (5) hide show
  1. anekdoty.txt +0 -0
  2. app.py +95 -0
  3. lerning.py +230 -0
  4. model.pt +3 -0
  5. requirements.txt +58 -0
anekdoty.txt ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import streamlit as st
3
+ import torch
4
+
5
+ st.title('Генерация текста GPT-моделью')
6
+ st.subheader('Это приложение показывает разницу в генерации текста моделью rugpt3small, обученной на документах общей тематики и этой же моделью, дообученной на анекдотах')
7
+ # Загружаем токенайзер модели
8
+ from transformers import GPT2Tokenizer
9
+ tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
10
+
11
+ from transformers import GPT2LMHeadModel
12
+
13
+ # Эту модель просто подгружаем
14
+ model_init = GPT2LMHeadModel.from_pretrained(
15
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
16
+ output_attentions = False,
17
+ output_hidden_states = False,
18
+ )
19
+
20
+ # Это обученная модель, в нее загружаем веса
21
+ model = GPT2LMHeadModel.from_pretrained(
22
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
23
+ output_attentions = False,
24
+ output_hidden_states = False,
25
+ )
26
+
27
+ m = torch.load('model.pt')
28
+ model.load_state_dict(m)
29
+
30
+
31
+ str = st.text_input('Введите 1-4 слова начала текста, и подождите минутку', 'Мужик спрашивает у официанта')
32
+
33
+ # модель без дообучения
34
+ # prompt – строка, которую примет на вход и продолжит модель
35
+
36
+ # токенизируем строку
37
+ prompt = tokenizer.encode(str, return_tensors='pt')
38
+
39
+ # out будет содержать результаты генерации в виде списка
40
+ out1 = model_init.generate(
41
+ # входная строка
42
+ input_ids=prompt,
43
+ # максимальная длина генерируемой последовательности
44
+ max_length=150,
45
+ # num_beams
46
+ num_beams=5,
47
+ # применяем сэмплирование
48
+ do_sample=True,
49
+ # применяем температуру
50
+ temperature=1.,
51
+ # топ слов по вероятности
52
+ top_k=50,
53
+ # топ слов по суммарной вероятности
54
+ top_p=0.6,
55
+ # сколько (постараться) не повторять n_gram подряд
56
+ no_repeat_ngram_size=3,
57
+ # сколько вернуть генераций
58
+ num_return_sequences=3,
59
+ ).numpy() #).cpu().numpy()
60
+
61
+ st.write('\n------------------\n')
62
+ st.subheader('Тексты на модели, обученной документами всех тематик:')
63
+ # out содержит результаты
64
+ # декодируем и печатаем
65
+ n = 0
66
+ for out_ in out1:
67
+ n += 1
68
+ st.write(tokenizer.decode(out_).rpartition('.')[0],'.')
69
+ st.write('\n------------------\n')
70
+ # print(tokenizer.decode(out_))
71
+
72
+
73
+ # дообученная модель
74
+ with torch.inference_mode():
75
+ # prompt = 'Мужик спрашивает официанта'
76
+ # prompt = tokenizer.encode(str, return_tensors='pt')
77
+ out2 = model.generate(
78
+ input_ids=prompt,
79
+ max_length=150,
80
+ num_beams=1,
81
+ do_sample=True,
82
+ temperature=1.,
83
+ top_k=5,
84
+ top_p=0.6,
85
+ no_repeat_ngram_size=2,
86
+ num_return_sequences=3,
87
+ ).numpy() #).cpu().numpy()
88
+
89
+ st.subheader('Тексты на модели, обученной документами всех тематик и дообученной анекдотами:')
90
+ n = 0
91
+ for out_ in out2:
92
+ n += 1
93
+ st.write(tokenizer.decode(out_).rpartition('.')[0],'.')
94
+ # print(textwrap.fill(tokenizer.decode(out_), 100), end='\n------------------\n')
95
+ st.write('\n------------------\n')
lerning.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ # !pip install -q transformers
4
+
5
+ import numpy as np
6
+ # import pandas as pd
7
+ import re
8
+ # import random
9
+
10
+ import torch
11
+ # from tqdm.notebook import tqdm
12
+ import transformers
13
+ # from torch.optim import AdamW
14
+
15
+ import textwrap
16
+
17
+ # Загружаем токенайзер модели
18
+ from transformers import GPT2Tokenizer
19
+ tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
20
+
21
+ # import re
22
+ with open('anekdoty.txt', encoding='utf8') as f:
23
+ text = f.read()
24
+
25
+ text = re.sub('\n{2,}', '\n', text)
26
+ print(text[:1000])
27
+
28
+
29
+ # токенизируем текст
30
+ tokens = tokenizer.encode(text, add_special_tokens=True)
31
+ tokens = np.array(tokens)
32
+ print(len(tokens))
33
+ tokens[:10]
34
+
35
+
36
+ # разбиваем на train и test
37
+
38
+ l = len(tokens)//15
39
+ train = []
40
+ test = []
41
+ for i in range(15):
42
+ if i%5 > 0:
43
+ train.extend(tokens[i*l: (i+1)*l])
44
+ else:
45
+ test.extend(tokens[i*l: (i+1)*l])
46
+ train = np.array(train)
47
+ test = np.array(test)
48
+
49
+ print(len(tokens), len(train), len(test))
50
+
51
+
52
+
53
+ from transformers import GPT2LMHeadModel
54
+
55
+ # Эту модель просто подгружаем и не будем дообучать
56
+ model_init = GPT2LMHeadModel.from_pretrained(
57
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
58
+ output_attentions = False,
59
+ output_hidden_states = False,
60
+ )
61
+
62
+
63
+ # Эту модель подгрузим и далее обучим
64
+ model = GPT2LMHeadModel.from_pretrained(
65
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
66
+ output_attentions = False,
67
+ output_hidden_states = False,
68
+ )
69
+
70
+ model.to(device);
71
+ model_init.to(device);
72
+
73
+
74
+ batch_size = 8
75
+ max_len = 256
76
+ epochs = 5
77
+
78
+ n_train = len(train)//(batch_size*max_len)
79
+ n_test = len(test)//(batch_size*max_len)
80
+ print(n_train, n_test)
81
+
82
+ # устанавливаем оптимизатор
83
+ optimizer = AdamW(model.parameters(), lr = 1e-5, eps = 1e-8)
84
+
85
+ # трансформеры с трудом обучаются, для них нужны разные способы повышения
86
+ # эффективности градиентного спуска
87
+ total_steps = n_train * epochs
88
+ scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
89
+ num_warmup_steps = 0,
90
+ num_training_steps = total_steps)
91
+
92
+
93
+ # зададим точность, хотя ориентироваться будем на качество генерации
94
+ def accuracy(y_true, logits):
95
+ return torch.mean((y_true[1:] == torch.argmax(logits, dim=2)[:-1]).float()).detach().cpu().numpy()
96
+
97
+
98
+
99
+ # готовим тензоры для обучения размера [batch_size, max_len]
100
+
101
+ def prep_tensors(x, i, batch_size=batch_size, max_len=max_len):
102
+ batch_ids = x[i*batch_size*max_len: (i+1)*batch_size*max_len]
103
+ batch_ids = batch_ids.reshape(batch_size, max_len)
104
+ batch_ids = torch.tensor(batch_ids).to(device)
105
+ return batch_ids
106
+
107
+
108
+ # обучающий цикл
109
+ for epoch in range(1, epochs+1):
110
+ print(f'epoch {epoch}/{epochs} : training')
111
+
112
+ train_loss = []
113
+ train_acc = []
114
+ model.train()
115
+ pbar = range(n_train)
116
+ # pbar = tqdm(range(n_train))
117
+ for i in pbar:
118
+ batch_ids = prep_tensors(train, i)
119
+
120
+ model.zero_grad()
121
+ loss, logits, _ = model(batch_ids,
122
+ token_type_ids=None,
123
+ labels=batch_ids
124
+ ).values()
125
+
126
+ loss.backward()
127
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
128
+ optimizer.step()
129
+ scheduler.step()
130
+
131
+ train_loss.append(loss.item())
132
+ train_acc.append(accuracy(batch_ids, logits))
133
+ print(f'acc {np.mean(train_acc):.4f} loss {np.mean(train_loss):.4f}')
134
+ # pbar.set_description(f'acc {np.mean(train_acc):.4f} loss {np.mean(train_loss):.4f}', refresh=True)
135
+
136
+
137
+ print(f'epoch {epoch}/{epochs} : validation')
138
+ model.eval()
139
+ val_acc = []
140
+ val_loss = []
141
+ pbar = range(n_test)
142
+ # pbar = tqdm(range(n_test))
143
+ for i in pbar:
144
+ batch_ids = prep_tensors(test, i)
145
+ with torch.no_grad():
146
+ loss, logits, _ = model(batch_ids,
147
+ token_type_ids=None,
148
+ labels=batch_ids
149
+ ).values()
150
+
151
+ val_loss.append(loss.item())
152
+ val_acc.append(accuracy(batch_ids, logits))
153
+ print(f'acc {np.mean(val_acc):.4f} loss {np.mean(val_loss):.4f}')
154
+ # pbar.set_description(f'acc {np.mean(val_acc):.4f} loss {np.mean(val_loss):.4f}', refresh=True)
155
+
156
+
157
+ # Применим модель, которую мы не дообучали: просто для понимания разницы между дообученной на собственных данных моделью и предобученной.
158
+ # https://huggingface.co/transformers/main_classes/model.html#transformers.generation_utils.GenerationMixin.generate
159
+ # модель без дообучения
160
+
161
+ # prompt – строка, которую модель примет на вход и продолжит
162
+ prompt = 'Мужик спрашивает официанта'
163
+
164
+ # токенизируем строку
165
+ prompt = tokenizer.encode(prompt, return_tensors='pt').to(device)
166
+
167
+ # out будет содержать результаты генерации в виде списка
168
+ out = model_init.generate(
169
+ # входная строка
170
+ input_ids=prompt,
171
+ # максимальная длина генерируемой последовательности
172
+ max_length=250,
173
+ # num_beams
174
+ num_beams=5,
175
+ # применяем сэмплирование
176
+ do_sample=True,
177
+ # применяем температуру
178
+ temperature=55.,
179
+ # топ слов по вероятности
180
+ top_k=50,
181
+ # топ слов по суммарной вероятности
182
+ top_p=0.6,
183
+ # сколько (постараться) не повторять n_gram подряд
184
+ no_repeat_ngram_size=3,
185
+ # сколько вернуть генераций
186
+ num_return_sequences=7,
187
+ ).cpu().numpy()
188
+
189
+ # out содержит результаты
190
+
191
+
192
+ # декодируем и печатаем
193
+ for out_ in out:
194
+ print(tokenizer.decode(out_))
195
+
196
+
197
+ # дообученная модель
198
+ with torch.inference_mode():
199
+ prompt = 'Мужик спрашивает официанта'
200
+ prompt = tokenizer.encode(prompt, return_tensors='pt').to(device)
201
+ out = model.generate(
202
+ input_ids=prompt,
203
+ max_length=150,
204
+ num_beams=1,
205
+ do_sample=True,
206
+ temperature=1.,
207
+ top_k=5,
208
+ top_p=0.6,
209
+ no_repeat_ngram_size=2,
210
+ num_return_sequences=7,
211
+ ).cpu().numpy()
212
+ for out_ in out:
213
+ print(textwrap.fill(tokenizer.decode(out_), 100), end='\n------------------\n')
214
+
215
+
216
+
217
+ # Сохраняем веса обученной модели
218
+ torch.save(model.state_dict(), 'model.pt')
219
+
220
+ # Задаем класс модели (уже в streamlit/tg_bot)
221
+ model_finetuned = GPT2LMHeadModel.from_pretrained(
222
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
223
+ output_attentions = False,
224
+ output_hidden_states = False,
225
+ )
226
+
227
+ # Вешаем сохраненные веса на нашу модель
228
+ model = model_finetuned.load_state_dict(torch.load('model.pt'))
229
+
230
+ # -> <All keys matched successfully>
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1d617290b6cd70a70e637b9478be1f1c47b6c9ca361f59eb1e68382c206d4fc
3
+ size 551310221
requirements.txt ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==4.2.0
2
+ attrs==22.1.0
3
+ blinker==1.5
4
+ cachetools==5.2.0
5
+ certifi==2022.12.7
6
+ charset-normalizer==2.1.1
7
+ click==8.1.3
8
+ commonmark==0.9.1
9
+ decorator==5.1.1
10
+ entrypoints==0.4
11
+ filelock==3.8.2
12
+ gitdb==4.0.10
13
+ GitPython==3.1.29
14
+ huggingface-hub==0.11.1
15
+ idna==3.4
16
+ importlib-metadata==5.1.0
17
+ Jinja2==3.1.2
18
+ jsonschema==4.17.3
19
+ MarkupSafe==2.1.1
20
+ numpy==1.23.5
21
+ nvidia-cublas-cu11==11.10.3.66
22
+ nvidia-cuda-nvrtc-cu11==11.7.99
23
+ nvidia-cuda-runtime-cu11==11.7.99
24
+ nvidia-cudnn-cu11==8.5.0.96
25
+ packaging==22.0
26
+ pandas==1.5.2
27
+ Pillow==9.3.0
28
+ protobuf==3.20.3
29
+ pyarrow==10.0.1
30
+ pydeck==0.8.0
31
+ Pygments==2.13.0
32
+ Pympler==1.0.1
33
+ pyrsistent==0.19.2
34
+ python-dateutil==2.8.2
35
+ pytz==2022.6
36
+ pytz-deprecation-shim==0.1.0.post0
37
+ PyYAML==6.0
38
+ regex==2022.10.31
39
+ requests==2.28.1
40
+ rich==12.6.0
41
+ semver==2.13.0
42
+ six==1.16.0
43
+ smmap==5.0.0
44
+ streamlit==1.16.0
45
+ tokenizers==0.13.2
46
+ toml==0.10.2
47
+ toolz==0.12.0
48
+ torch==1.13.1
49
+ tornado==6.2
50
+ tqdm==4.64.1
51
+ transformers==4.25.1
52
+ typing_extensions==4.4.0
53
+ tzdata==2022.7
54
+ tzlocal==4.2
55
+ urllib3==1.26.13
56
+ validators==0.20.0
57
+ watchdog==2.2.0
58
+ zipp==3.11.0