reproduceability
Browse files- enviroment.yml +0 -0
- gpt2ia_v1.ipynb +260 -0
enviroment.yml
ADDED
Binary file (18.9 kB). View file
|
|
gpt2ia_v1.ipynb
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import pandas as pd\n",
|
10 |
+
"from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
|
11 |
+
"import numpy as np\n",
|
12 |
+
"import random\n",
|
13 |
+
"import torch\n",
|
14 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
15 |
+
"from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup\n",
|
16 |
+
"from tqdm import tqdm, trange\n",
|
17 |
+
"import torch.nn.functional as F\n",
|
18 |
+
"import csv\n",
|
19 |
+
"from transformers import TextDataset, DataCollatorForLanguageModeling\n",
|
20 |
+
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
|
21 |
+
"from transformers import Trainer, TrainingArguments"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": 2,
|
27 |
+
"metadata": {},
|
28 |
+
"outputs": [
|
29 |
+
{
|
30 |
+
"name": "stdout",
|
31 |
+
"output_type": "stream",
|
32 |
+
"text": [
|
33 |
+
"100000\n"
|
34 |
+
]
|
35 |
+
}
|
36 |
+
],
|
37 |
+
"source": [
|
38 |
+
"with open(\"./titulos.txt\") as file:\n",
|
39 |
+
" manchetes = [line.rstrip() for line in file]\n",
|
40 |
+
"print(len(manchetes))"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": 3,
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"def load_dataset(file_path, tokenizer, block_size = 128):\n",
|
50 |
+
" dataset = TextDataset(\n",
|
51 |
+
" tokenizer = tokenizer,\n",
|
52 |
+
" file_path = file_path,\n",
|
53 |
+
" block_size = block_size,\n",
|
54 |
+
" )\n",
|
55 |
+
" return dataset\n",
|
56 |
+
"\n",
|
57 |
+
"\n",
|
58 |
+
"def load_data_collator(tokenizer, mlm = False):\n",
|
59 |
+
" data_collator = DataCollatorForLanguageModeling(\n",
|
60 |
+
" tokenizer=tokenizer,\n",
|
61 |
+
" mlm=mlm,\n",
|
62 |
+
" )\n",
|
63 |
+
" return data_collator\n",
|
64 |
+
"\n",
|
65 |
+
"\n",
|
66 |
+
"def train(train_file_path,model_name,\n",
|
67 |
+
" output_dir,\n",
|
68 |
+
" overwrite_output_dir,\n",
|
69 |
+
" per_device_train_batch_size,\n",
|
70 |
+
" num_train_epochs,\n",
|
71 |
+
" save_steps):\n",
|
72 |
+
" tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
|
73 |
+
" train_dataset = load_dataset(train_file_path, tokenizer)\n",
|
74 |
+
" data_collator = load_data_collator(tokenizer)\n",
|
75 |
+
"\n",
|
76 |
+
" tokenizer.save_pretrained(output_dir)\n",
|
77 |
+
"\n",
|
78 |
+
" model = GPT2LMHeadModel.from_pretrained(model_name)\n",
|
79 |
+
"\n",
|
80 |
+
" model.save_pretrained(output_dir)\n",
|
81 |
+
"\n",
|
82 |
+
" training_args = TrainingArguments(\n",
|
83 |
+
" output_dir=output_dir,\n",
|
84 |
+
" overwrite_output_dir=overwrite_output_dir,\n",
|
85 |
+
" per_device_train_batch_size=per_device_train_batch_size,\n",
|
86 |
+
" num_train_epochs=num_train_epochs,\n",
|
87 |
+
" )\n",
|
88 |
+
"\n",
|
89 |
+
" trainer = Trainer(\n",
|
90 |
+
" model=model,\n",
|
91 |
+
" args=training_args,\n",
|
92 |
+
" data_collator=data_collator,\n",
|
93 |
+
" train_dataset=train_dataset,\n",
|
94 |
+
" )\n",
|
95 |
+
"\n",
|
96 |
+
" trainer.train()\n",
|
97 |
+
" trainer.save_model()"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": 4,
|
103 |
+
"metadata": {},
|
104 |
+
"outputs": [],
|
105 |
+
"source": [
|
106 |
+
"train_file_path = \"./titulos.txt\"\n",
|
107 |
+
"model_name = 'gpt2'\n",
|
108 |
+
"output_dir = './result'\n",
|
109 |
+
"overwrite_output_dir = False\n",
|
110 |
+
"per_device_train_batch_size = 8\n",
|
111 |
+
"num_train_epochs = 5.0\n",
|
112 |
+
"save_steps = 500"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": 5,
|
118 |
+
"metadata": {},
|
119 |
+
"outputs": [
|
120 |
+
{
|
121 |
+
"name": "stderr",
|
122 |
+
"output_type": "stream",
|
123 |
+
"text": [
|
124 |
+
"c:\\Users\\yamgc\\miniconda3\\envs\\choqueianotebook1\\lib\\site-packages\\transformers\\data\\datasets\\language_modeling.py:53: FutureWarning: This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py\n",
|
125 |
+
" warnings.warn(\n"
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"data": {
|
130 |
+
"application/vnd.jupyter.widget-view+json": {
|
131 |
+
"model_id": "f791b828d56f4844a369b75f66ec071a",
|
132 |
+
"version_major": 2,
|
133 |
+
"version_minor": 0
|
134 |
+
},
|
135 |
+
"text/plain": [
|
136 |
+
" 0%| | 0/11455 [00:00<?, ?it/s]"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
"metadata": {},
|
140 |
+
"output_type": "display_data"
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"name": "stdout",
|
144 |
+
"output_type": "stream",
|
145 |
+
"text": [
|
146 |
+
"{'loss': 4.0109, 'learning_rate': 4.7817546922741165e-05, 'epoch': 0.22}\n",
|
147 |
+
"{'loss': 3.3335, 'learning_rate': 4.563509384548232e-05, 'epoch': 0.44}\n",
|
148 |
+
"{'loss': 3.0906, 'learning_rate': 4.3452640768223485e-05, 'epoch': 0.65}\n",
|
149 |
+
"{'loss': 2.9653, 'learning_rate': 4.127018769096465e-05, 'epoch': 0.87}\n",
|
150 |
+
"{'loss': 2.855, 'learning_rate': 3.9087734613705804e-05, 'epoch': 1.09}\n",
|
151 |
+
"{'loss': 2.764, 'learning_rate': 3.690528153644697e-05, 'epoch': 1.31}\n",
|
152 |
+
"{'loss': 2.7204, 'learning_rate': 3.472282845918813e-05, 'epoch': 1.53}\n",
|
153 |
+
"{'loss': 2.6883, 'learning_rate': 3.2540375381929286e-05, 'epoch': 1.75}\n",
|
154 |
+
"{'loss': 2.6503, 'learning_rate': 3.035792230467045e-05, 'epoch': 1.96}\n",
|
155 |
+
"{'loss': 2.5871, 'learning_rate': 2.817546922741161e-05, 'epoch': 2.18}\n",
|
156 |
+
"{'loss': 2.5675, 'learning_rate': 2.5993016150152772e-05, 'epoch': 2.4}\n",
|
157 |
+
"{'loss': 2.5442, 'learning_rate': 2.3810563072893935e-05, 'epoch': 2.62}\n",
|
158 |
+
"{'loss': 2.5271, 'learning_rate': 2.1628109995635095e-05, 'epoch': 2.84}\n",
|
159 |
+
"{'loss': 2.4952, 'learning_rate': 1.9445656918376258e-05, 'epoch': 3.06}\n",
|
160 |
+
"{'loss': 2.4653, 'learning_rate': 1.7263203841117417e-05, 'epoch': 3.27}\n",
|
161 |
+
"{'loss': 2.4566, 'learning_rate': 1.5080750763858579e-05, 'epoch': 3.49}\n",
|
162 |
+
"{'loss': 2.4551, 'learning_rate': 1.2898297686599738e-05, 'epoch': 3.71}\n",
|
163 |
+
"{'loss': 2.4445, 'learning_rate': 1.07158446093409e-05, 'epoch': 3.93}\n",
|
164 |
+
"{'loss': 2.4247, 'learning_rate': 8.533391532082061e-06, 'epoch': 4.15}\n",
|
165 |
+
"{'loss': 2.4052, 'learning_rate': 6.3509384548232216e-06, 'epoch': 4.36}\n",
|
166 |
+
"{'loss': 2.4033, 'learning_rate': 4.168485377564382e-06, 'epoch': 4.58}\n",
|
167 |
+
"{'loss': 2.3983, 'learning_rate': 1.9860323003055434e-06, 'epoch': 4.8}\n",
|
168 |
+
"{'train_runtime': 2081.9794, 'train_samples_per_second': 44.004, 'train_steps_per_second': 5.502, 'train_loss': 2.6816321317621945, 'epoch': 5.0}\n"
|
169 |
+
]
|
170 |
+
}
|
171 |
+
],
|
172 |
+
"source": [
|
173 |
+
"train(\n",
|
174 |
+
" train_file_path=train_file_path,\n",
|
175 |
+
" model_name=model_name,\n",
|
176 |
+
" output_dir=output_dir,\n",
|
177 |
+
" overwrite_output_dir=overwrite_output_dir,\n",
|
178 |
+
" per_device_train_batch_size=per_device_train_batch_size,\n",
|
179 |
+
" num_train_epochs=num_train_epochs,\n",
|
180 |
+
" save_steps=save_steps\n",
|
181 |
+
")"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "code",
|
186 |
+
"execution_count": 8,
|
187 |
+
"metadata": {},
|
188 |
+
"outputs": [],
|
189 |
+
"source": [
|
190 |
+
"from transformers import PreTrainedTokenizerFast, GPT2LMHeadModel, GPT2TokenizerFast, GPT2Tokenizer\n",
|
191 |
+
"\n",
|
192 |
+
"def load_model(model_path):\n",
|
193 |
+
" model = GPT2LMHeadModel.from_pretrained(model_path)\n",
|
194 |
+
" return model\n",
|
195 |
+
"\n",
|
196 |
+
"\n",
|
197 |
+
"def load_tokenizer(tokenizer_path):\n",
|
198 |
+
" tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)\n",
|
199 |
+
" return tokenizer\n",
|
200 |
+
"\n",
|
201 |
+
"\n",
|
202 |
+
"def generate_text(sequence, max_length):\n",
|
203 |
+
" model_path = \"./result\"\n",
|
204 |
+
" model = load_model(model_path)\n",
|
205 |
+
" tokenizer = load_tokenizer(model_path)\n",
|
206 |
+
" ids = tokenizer.encode(f'{sequence}', return_tensors='pt')\n",
|
207 |
+
" final_outputs = model.generate(\n",
|
208 |
+
" ids,\n",
|
209 |
+
" do_sample=True,\n",
|
210 |
+
" max_length=max_length,\n",
|
211 |
+
" pad_token_id=model.config.eos_token_id,\n",
|
212 |
+
" top_k=50,\n",
|
213 |
+
" top_p=0.95,\n",
|
214 |
+
" )\n",
|
215 |
+
" print(tokenizer.decode(final_outputs[0], skip_special_tokens=True))"
|
216 |
+
]
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"cell_type": "code",
|
220 |
+
"execution_count": null,
|
221 |
+
"metadata": {},
|
222 |
+
"outputs": [
|
223 |
+
{
|
224 |
+
"name": "stdout",
|
225 |
+
"output_type": "stream",
|
226 |
+
"text": [
|
227 |
+
"Juliette rebate críticas à mostra em clique sexy em evento\n",
|
228 |
+
"\n"
|
229 |
+
]
|
230 |
+
}
|
231 |
+
],
|
232 |
+
"source": [
|
233 |
+
"sequence = \"Juliette\"\n",
|
234 |
+
"max_len = 19\n",
|
235 |
+
"generate_text(sequence, max_len)"
|
236 |
+
]
|
237 |
+
}
|
238 |
+
],
|
239 |
+
"metadata": {
|
240 |
+
"kernelspec": {
|
241 |
+
"display_name": "choqueianotebook1",
|
242 |
+
"language": "python",
|
243 |
+
"name": "python3"
|
244 |
+
},
|
245 |
+
"language_info": {
|
246 |
+
"codemirror_mode": {
|
247 |
+
"name": "ipython",
|
248 |
+
"version": 3
|
249 |
+
},
|
250 |
+
"file_extension": ".py",
|
251 |
+
"mimetype": "text/x-python",
|
252 |
+
"name": "python",
|
253 |
+
"nbconvert_exporter": "python",
|
254 |
+
"pygments_lexer": "ipython3",
|
255 |
+
"version": "3.9.18"
|
256 |
+
}
|
257 |
+
},
|
258 |
+
"nbformat": 4,
|
259 |
+
"nbformat_minor": 2
|
260 |
+
}
|