Upload Untitled0.ipynb
Browse files- Untitled0.ipynb +872 -0
Untitled0.ipynb
ADDED
@@ -0,0 +1,872 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"name": "Untitled0.ipynb",
|
7 |
+
"private_outputs": true,
|
8 |
+
"provenance": []
|
9 |
+
},
|
10 |
+
"kernelspec": {
|
11 |
+
"name": "python3",
|
12 |
+
"display_name": "Python 3"
|
13 |
+
},
|
14 |
+
"language_info": {
|
15 |
+
"name": "python"
|
16 |
+
},
|
17 |
+
"accelerator": "GPU"
|
18 |
+
},
|
19 |
+
"cells": [
|
20 |
+
{
|
21 |
+
"cell_type": "code",
|
22 |
+
"metadata": {
|
23 |
+
"id": "-ExZYuS4whSi"
|
24 |
+
},
|
25 |
+
"source": [
|
26 |
+
"from google.colab import drive\n",
|
27 |
+
"drive.mount('/content/drive/')"
|
28 |
+
],
|
29 |
+
"execution_count": null,
|
30 |
+
"outputs": []
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"metadata": {
|
35 |
+
"id": "iU-h1ApL0bue"
|
36 |
+
},
|
37 |
+
"source": [
|
38 |
+
"!pip -q install transformers"
|
39 |
+
],
|
40 |
+
"execution_count": null,
|
41 |
+
"outputs": []
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"metadata": {
|
46 |
+
"id": "a0P1MPcH1IaF"
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"import os\n",
|
50 |
+
"os.chdir(\"/content/drive/My Drive/Colab Notebooks\")"
|
51 |
+
],
|
52 |
+
"execution_count": null,
|
53 |
+
"outputs": []
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"metadata": {
|
58 |
+
"id": "FY23HIap1K3w"
|
59 |
+
},
|
60 |
+
"source": [
|
61 |
+
"import glob\n",
|
62 |
+
"import logging\n",
|
63 |
+
"import os\n",
|
64 |
+
"import pickle\n",
|
65 |
+
"import random\n",
|
66 |
+
"import re\n",
|
67 |
+
"import shutil\n",
|
68 |
+
"from typing import Dict, List, Tuple\n",
|
69 |
+
"\n",
|
70 |
+
"import numpy as np\n",
|
71 |
+
"import pandas as pd\n",
|
72 |
+
"\n",
|
73 |
+
"from sklearn.model_selection import train_test_split\n",
|
74 |
+
"\n",
|
75 |
+
"from torch.nn.utils.rnn import pad_sequence\n",
|
76 |
+
"from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n",
|
77 |
+
"from torch.utils.data.distributed import DistributedSampler\n",
|
78 |
+
"from tqdm.notebook import tqdm, trange\n",
|
79 |
+
"\n",
|
80 |
+
"from pathlib import Path\n",
|
81 |
+
"\n",
|
82 |
+
"from transformers import (\n",
|
83 |
+
" MODEL_WITH_LM_HEAD_MAPPING,\n",
|
84 |
+
" WEIGHTS_NAME,\n",
|
85 |
+
" AdamW,\n",
|
86 |
+
" AutoConfig,\n",
|
87 |
+
" PreTrainedModel,\n",
|
88 |
+
" PreTrainedTokenizer,\n",
|
89 |
+
" get_linear_schedule_with_warmup,\n",
|
90 |
+
")\n",
|
91 |
+
"\n",
|
92 |
+
"\n",
|
93 |
+
"try:\n",
|
94 |
+
" from torch.utils.tensorboard import SummaryWriter\n",
|
95 |
+
"except ImportError:\n",
|
96 |
+
" from tensorboardX import SummaryWriter"
|
97 |
+
],
|
98 |
+
"execution_count": null,
|
99 |
+
"outputs": []
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "code",
|
103 |
+
"metadata": {
|
104 |
+
"id": "4pldsjBT1QqG"
|
105 |
+
},
|
106 |
+
"source": [
|
107 |
+
"data= pd.read_csv('/sukuna.csv')"
|
108 |
+
],
|
109 |
+
"execution_count": null,
|
110 |
+
"outputs": []
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"metadata": {
|
115 |
+
"id": "hESxj45g2PKd"
|
116 |
+
},
|
117 |
+
"source": [
|
118 |
+
"data.sample(6)"
|
119 |
+
],
|
120 |
+
"execution_count": null,
|
121 |
+
"outputs": []
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"metadata": {
|
126 |
+
"id": "UaHgCPjf2Ryg"
|
127 |
+
},
|
128 |
+
"source": [
|
129 |
+
"CHARACTER_NAME = 'Sukuna'"
|
130 |
+
],
|
131 |
+
"execution_count": null,
|
132 |
+
"outputs": []
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"metadata": {
|
137 |
+
"id": "7MU-Ocxw2W9X"
|
138 |
+
},
|
139 |
+
"source": [
|
140 |
+
"contexted = []\n",
|
141 |
+
"\n",
|
142 |
+
"# context window of size 7\n",
|
143 |
+
"n = 7\n",
|
144 |
+
"\n",
|
145 |
+
"for i in data[data.name == CHARACTER_NAME].index:\n",
|
146 |
+
" if i < n:\n",
|
147 |
+
" continue\n",
|
148 |
+
" row = []\n",
|
149 |
+
" prev = i - 1 - n # we additionally substract 1, so row will contain current responce and 7 previous responces \n",
|
150 |
+
" for j in range(i, prev, -1):\n",
|
151 |
+
" row.append(data.line[j])\n",
|
152 |
+
" contexted.append(row)\n",
|
153 |
+
"\n",
|
154 |
+
"columns = ['response', 'context'] \n",
|
155 |
+
"columns = columns + ['context/' + str(i) for i in range(n - 1)]\n",
|
156 |
+
"\n",
|
157 |
+
"df = pd.DataFrame.from_records(contexted, columns=columns)"
|
158 |
+
],
|
159 |
+
"execution_count": null,
|
160 |
+
"outputs": []
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"metadata": {
|
165 |
+
"id": "HVl-w-TE2ZQl"
|
166 |
+
},
|
167 |
+
"source": [
|
168 |
+
"df.sample(6)"
|
169 |
+
],
|
170 |
+
"execution_count": null,
|
171 |
+
"outputs": []
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"cell_type": "code",
|
175 |
+
"metadata": {
|
176 |
+
"id": "zzPDyi9p2bRf"
|
177 |
+
},
|
178 |
+
"source": [
|
179 |
+
"trn_df, val_df = train_test_split(df, test_size=0.1)\n",
|
180 |
+
"trn_df.head()"
|
181 |
+
],
|
182 |
+
"execution_count": null,
|
183 |
+
"outputs": []
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"cell_type": "code",
|
187 |
+
"metadata": {
|
188 |
+
"id": "7Mj6dxsk2dlU"
|
189 |
+
},
|
190 |
+
"source": [
|
191 |
+
"def construct_conv(row, tokenizer, eos = True):\n",
|
192 |
+
" flatten = lambda l: [item for sublist in l for item in sublist]\n",
|
193 |
+
" conv = list(reversed([tokenizer.encode(x) + [tokenizer.eos_token_id] for x in row]))\n",
|
194 |
+
" conv = flatten(conv)\n",
|
195 |
+
" return conv\n",
|
196 |
+
"\n",
|
197 |
+
"class ConversationDataset(Dataset):\n",
|
198 |
+
" def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):\n",
|
199 |
+
"\n",
|
200 |
+
" block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)\n",
|
201 |
+
"\n",
|
202 |
+
" directory = args.cache_dir\n",
|
203 |
+
" cached_features_file = os.path.join(\n",
|
204 |
+
" directory, args.model_type + \"_cached_lm_\" + str(block_size)\n",
|
205 |
+
" )\n",
|
206 |
+
"\n",
|
207 |
+
" if os.path.exists(cached_features_file) and not args.overwrite_cache:\n",
|
208 |
+
" logger.info(\"Loading features from cached file %s\", cached_features_file)\n",
|
209 |
+
" with open(cached_features_file, \"rb\") as handle:\n",
|
210 |
+
" self.examples = pickle.load(handle)\n",
|
211 |
+
" else:\n",
|
212 |
+
" logger.info(\"Creating features from dataset file at %s\", directory)\n",
|
213 |
+
"\n",
|
214 |
+
" self.examples = []\n",
|
215 |
+
" for _, row in df.iterrows():\n",
|
216 |
+
" conv = construct_conv(row, tokenizer)\n",
|
217 |
+
" self.examples.append(conv)\n",
|
218 |
+
"\n",
|
219 |
+
" logger.info(\"Saving features into cached file %s\", cached_features_file)\n",
|
220 |
+
" with open(cached_features_file, \"wb\") as handle:\n",
|
221 |
+
" pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
|
222 |
+
"\n",
|
223 |
+
" def __len__(self):\n",
|
224 |
+
" return len(self.examples)\n",
|
225 |
+
"\n",
|
226 |
+
" def __getitem__(self, item):\n",
|
227 |
+
" return torch.tensor(self.examples[item], dtype=torch.long)"
|
228 |
+
],
|
229 |
+
"execution_count": null,
|
230 |
+
"outputs": []
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"cell_type": "code",
|
234 |
+
"metadata": {
|
235 |
+
"id": "HDSfFNhZ2j5U"
|
236 |
+
},
|
237 |
+
"source": [
|
238 |
+
"def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):\n",
|
239 |
+
" return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)\n",
|
240 |
+
"\n",
|
241 |
+
"\n",
|
242 |
+
"def set_seed(args):\n",
|
243 |
+
" random.seed(args.seed)\n",
|
244 |
+
" np.random.seed(args.seed)\n",
|
245 |
+
" torch.manual_seed(args.seed)\n",
|
246 |
+
" if args.n_gpu > 0:\n",
|
247 |
+
" torch.cuda.manual_seed_all(args.seed)\n",
|
248 |
+
"\n",
|
249 |
+
"\n",
|
250 |
+
"def _sorted_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n",
|
251 |
+
" ordering_and_checkpoint_path = []\n",
|
252 |
+
"\n",
|
253 |
+
" glob_checkpoints = glob.glob(os.path.join(args.output_dir, \"{}-*\".format(checkpoint_prefix)))\n",
|
254 |
+
"\n",
|
255 |
+
" for path in glob_checkpoints:\n",
|
256 |
+
" if use_mtime:\n",
|
257 |
+
" ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n",
|
258 |
+
" else:\n",
|
259 |
+
" regex_match = re.match(\".*{}-([0-9]+)\".format(checkpoint_prefix), path)\n",
|
260 |
+
" if regex_match and regex_match.groups():\n",
|
261 |
+
" ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n",
|
262 |
+
"\n",
|
263 |
+
" checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n",
|
264 |
+
" checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n",
|
265 |
+
" return checkpoints_sorted\n",
|
266 |
+
"\n",
|
267 |
+
"\n",
|
268 |
+
"def _rotate_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n",
|
269 |
+
" if not args.save_total_limit:\n",
|
270 |
+
" return\n",
|
271 |
+
" if args.save_total_limit <= 0:\n",
|
272 |
+
" return\n",
|
273 |
+
"\n",
|
274 |
+
" # Check if we should delete older checkpoint(s)\n",
|
275 |
+
" checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)\n",
|
276 |
+
" if len(checkpoints_sorted) <= args.save_total_limit:\n",
|
277 |
+
" return\n",
|
278 |
+
"\n",
|
279 |
+
" number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)\n",
|
280 |
+
" checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n",
|
281 |
+
" for checkpoint in checkpoints_to_be_deleted:\n",
|
282 |
+
" logger.info(\"Deleting older checkpoint [{}] due to args.save_total_limit\".format(checkpoint))\n",
|
283 |
+
" shutil.rmtree(checkpoint)"
|
284 |
+
],
|
285 |
+
"execution_count": null,
|
286 |
+
"outputs": []
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "code",
|
290 |
+
"metadata": {
|
291 |
+
"id": "5VII0lGJ2nmz"
|
292 |
+
},
|
293 |
+
"source": [
|
294 |
+
"from transformers import AutoModelWithLMHead, AutoModelForCausalLM, AutoTokenizer\n",
|
295 |
+
"import torch\n",
|
296 |
+
"\n",
|
297 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-small\")\n",
|
298 |
+
"model = AutoModelWithLMHead.from_pretrained(\"microsoft/DialoGPT-small\")"
|
299 |
+
],
|
300 |
+
"execution_count": null,
|
301 |
+
"outputs": []
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "code",
|
305 |
+
"metadata": {
|
306 |
+
"id": "2U7RIR4L2yNa"
|
307 |
+
},
|
308 |
+
"source": [
|
309 |
+
"\"\"\"\n",
|
310 |
+
"Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).\n",
|
311 |
+
"GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned\n",
|
312 |
+
"using a masked language modeling (MLM) loss.\n",
|
313 |
+
"\"\"\"\n",
|
314 |
+
"\n",
|
315 |
+
"# Configs\n",
|
316 |
+
"logger = logging.getLogger(__name__)\n",
|
317 |
+
"\n",
|
318 |
+
"MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())\n",
|
319 |
+
"MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)"
|
320 |
+
],
|
321 |
+
"execution_count": null,
|
322 |
+
"outputs": []
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"cell_type": "code",
|
326 |
+
"metadata": {
|
327 |
+
"id": "r3VZpWtb3hFU"
|
328 |
+
},
|
329 |
+
"source": [
|
330 |
+
"class Args():\n",
|
331 |
+
" def __init__(self):\n",
|
332 |
+
" self.output_dir = 'output-small'\n",
|
333 |
+
" self.model_type = 'gpt2'\n",
|
334 |
+
" self.model_name_or_path = 'microsoft/DialoGPT-small'\n",
|
335 |
+
" self.config_name = 'microsoft/DialoGPT-small'\n",
|
336 |
+
" self.tokenizer_name = 'microsoft/DialoGPT-small'\n",
|
337 |
+
" self.cache_dir = 'cached'\n",
|
338 |
+
" self.block_size = 512\n",
|
339 |
+
" self.do_train = True\n",
|
340 |
+
" self.do_eval = True\n",
|
341 |
+
" self.evaluate_during_training = False\n",
|
342 |
+
" self.per_gpu_train_batch_size = 4\n",
|
343 |
+
" self.per_gpu_eval_batch_size = 4\n",
|
344 |
+
" self.gradient_accumulation_steps = 1\n",
|
345 |
+
" self.learning_rate = 5e-5\n",
|
346 |
+
" self.weight_decay = 0.0\n",
|
347 |
+
" self.adam_epsilon = 1e-8\n",
|
348 |
+
" self.max_grad_norm = 1.0\n",
|
349 |
+
" self.num_train_epochs = 4\n",
|
350 |
+
" self.max_steps = -1\n",
|
351 |
+
" self.warmup_steps = 0\n",
|
352 |
+
" self.logging_steps = 1000\n",
|
353 |
+
" self.save_steps = 3500\n",
|
354 |
+
" self.save_total_limit = None\n",
|
355 |
+
" self.eval_all_checkpoints = False\n",
|
356 |
+
" self.no_cuda = False\n",
|
357 |
+
" self.overwrite_output_dir = True\n",
|
358 |
+
" self.overwrite_cache = True\n",
|
359 |
+
" self.should_continue = False\n",
|
360 |
+
" self.seed = 42\n",
|
361 |
+
" self.local_rank = -1\n",
|
362 |
+
" self.fp16 = False\n",
|
363 |
+
" self.fp16_opt_level = 'O1'\n",
|
364 |
+
"\n",
|
365 |
+
"args = Args()"
|
366 |
+
],
|
367 |
+
"execution_count": null,
|
368 |
+
"outputs": []
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "code",
|
372 |
+
"metadata": {
|
373 |
+
"id": "8m7XgQqW3l_N"
|
374 |
+
},
|
375 |
+
"source": [
|
376 |
+
"def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:\n",
|
377 |
+
" \"\"\" Train the model \"\"\"\n",
|
378 |
+
" if args.local_rank in [-1, 0]:\n",
|
379 |
+
" tb_writer = SummaryWriter()\n",
|
380 |
+
"\n",
|
381 |
+
" args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n",
|
382 |
+
"\n",
|
383 |
+
" def collate(examples: List[torch.Tensor]):\n",
|
384 |
+
" if tokenizer._pad_token is None:\n",
|
385 |
+
" return pad_sequence(examples, batch_first=True)\n",
|
386 |
+
" return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
|
387 |
+
"\n",
|
388 |
+
" train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)\n",
|
389 |
+
" train_dataloader = DataLoader(\n",
|
390 |
+
" train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last = True\n",
|
391 |
+
" )\n",
|
392 |
+
"\n",
|
393 |
+
" if args.max_steps > 0:\n",
|
394 |
+
" t_total = args.max_steps\n",
|
395 |
+
" args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1\n",
|
396 |
+
" else:\n",
|
397 |
+
" t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs\n",
|
398 |
+
"\n",
|
399 |
+
" model = model.module if hasattr(model, \"module\") else model # Take care of distributed/parallel training\n",
|
400 |
+
" model.resize_token_embeddings(len(tokenizer))\n",
|
401 |
+
" # add_special_tokens_(model, tokenizer)\n",
|
402 |
+
"\n",
|
403 |
+
"\n",
|
404 |
+
" # Prepare optimizer and schedule (linear warmup and decay)\n",
|
405 |
+
" no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
|
406 |
+
" optimizer_grouped_parameters = [\n",
|
407 |
+
" {\n",
|
408 |
+
" \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
|
409 |
+
" \"weight_decay\": args.weight_decay,\n",
|
410 |
+
" },\n",
|
411 |
+
" {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n",
|
412 |
+
" ]\n",
|
413 |
+
" optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n",
|
414 |
+
" scheduler = get_linear_schedule_with_warmup(\n",
|
415 |
+
" optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total\n",
|
416 |
+
" )\n",
|
417 |
+
"\n",
|
418 |
+
" # Check if saved optimizer or scheduler states exist\n",
|
419 |
+
" if (\n",
|
420 |
+
" args.model_name_or_path\n",
|
421 |
+
" and os.path.isfile(os.path.join(args.model_name_or_path, \"optimizer.pt\"))\n",
|
422 |
+
" and os.path.isfile(os.path.join(args.model_name_or_path, \"scheduler.pt\"))\n",
|
423 |
+
" ):\n",
|
424 |
+
" # Load in optimizer and scheduler states\n",
|
425 |
+
" optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"optimizer.pt\")))\n",
|
426 |
+
" scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"scheduler.pt\")))\n",
|
427 |
+
"\n",
|
428 |
+
" if args.fp16:\n",
|
429 |
+
" try:\n",
|
430 |
+
" from apex import amp\n",
|
431 |
+
" except ImportError:\n",
|
432 |
+
" raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
|
433 |
+
" model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)\n",
|
434 |
+
"\n",
|
435 |
+
" # multi-gpu training (should be after apex fp16 initialization)\n",
|
436 |
+
" if args.n_gpu > 1:\n",
|
437 |
+
" model = torch.nn.DataParallel(model)\n",
|
438 |
+
"\n",
|
439 |
+
" # Distributed training (should be after apex fp16 initialization)\n",
|
440 |
+
" if args.local_rank != -1:\n",
|
441 |
+
" model = torch.nn.parallel.DistributedDataParallel(\n",
|
442 |
+
" model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True\n",
|
443 |
+
" )\n",
|
444 |
+
"\n",
|
445 |
+
" # Train!\n",
|
446 |
+
" logger.info(\"***** Running training *****\")\n",
|
447 |
+
" logger.info(\" Num examples = %d\", len(train_dataset))\n",
|
448 |
+
" logger.info(\" Num Epochs = %d\", args.num_train_epochs)\n",
|
449 |
+
" logger.info(\" Instantaneous batch size per GPU = %d\", args.per_gpu_train_batch_size)\n",
|
450 |
+
" logger.info(\n",
|
451 |
+
" \" Total train batch size (w. parallel, distributed & accumulation) = %d\",\n",
|
452 |
+
" args.train_batch_size\n",
|
453 |
+
" * args.gradient_accumulation_steps\n",
|
454 |
+
" * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),\n",
|
455 |
+
" )\n",
|
456 |
+
" logger.info(\" Gradient Accumulation steps = %d\", args.gradient_accumulation_steps)\n",
|
457 |
+
" logger.info(\" Total optimization steps = %d\", t_total)\n",
|
458 |
+
"\n",
|
459 |
+
" global_step = 0\n",
|
460 |
+
" epochs_trained = 0\n",
|
461 |
+
" steps_trained_in_current_epoch = 0\n",
|
462 |
+
" # Check if continuing training from a checkpoint\n",
|
463 |
+
" if args.model_name_or_path and os.path.exists(args.model_name_or_path):\n",
|
464 |
+
" try:\n",
|
465 |
+
" # set global_step to gobal_step of last saved checkpoint from model path\n",
|
466 |
+
" checkpoint_suffix = args.model_name_or_path.split(\"-\")[-1].split(\"/\")[0]\n",
|
467 |
+
" global_step = int(checkpoint_suffix)\n",
|
468 |
+
" epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)\n",
|
469 |
+
" steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)\n",
|
470 |
+
"\n",
|
471 |
+
" logger.info(\" Continuing training from checkpoint, will skip to saved global_step\")\n",
|
472 |
+
" logger.info(\" Continuing training from epoch %d\", epochs_trained)\n",
|
473 |
+
" logger.info(\" Continuing training from global step %d\", global_step)\n",
|
474 |
+
" logger.info(\" Will skip the first %d steps in the first epoch\", steps_trained_in_current_epoch)\n",
|
475 |
+
" except ValueError:\n",
|
476 |
+
" logger.info(\" Starting fine-tuning.\")\n",
|
477 |
+
"\n",
|
478 |
+
" tr_loss, logging_loss = 0.0, 0.0\n",
|
479 |
+
"\n",
|
480 |
+
" model.zero_grad()\n",
|
481 |
+
" train_iterator = trange(\n",
|
482 |
+
" epochs_trained, int(args.num_train_epochs), desc=\"Epoch\", disable=args.local_rank not in [-1, 0]\n",
|
483 |
+
" )\n",
|
484 |
+
" set_seed(args) # Added here for reproducibility\n",
|
485 |
+
" for _ in train_iterator:\n",
|
486 |
+
" epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=args.local_rank not in [-1, 0])\n",
|
487 |
+
" for step, batch in enumerate(epoch_iterator):\n",
|
488 |
+
"\n",
|
489 |
+
" # Skip past any already trained steps if resuming training\n",
|
490 |
+
" if steps_trained_in_current_epoch > 0:\n",
|
491 |
+
" steps_trained_in_current_epoch -= 1\n",
|
492 |
+
" continue\n",
|
493 |
+
"\n",
|
494 |
+
" inputs, labels = (batch, batch)\n",
|
495 |
+
" if inputs.shape[1] > 1024: continue\n",
|
496 |
+
" inputs = inputs.to(args.device)\n",
|
497 |
+
" labels = labels.to(args.device)\n",
|
498 |
+
" model.train()\n",
|
499 |
+
" outputs = model(inputs, labels=labels)\n",
|
500 |
+
" loss = outputs[0] # model outputs are always tuple in transformers (see doc)\n",
|
501 |
+
"\n",
|
502 |
+
" if args.n_gpu > 1:\n",
|
503 |
+
" loss = loss.mean() # mean() to average on multi-gpu parallel training\n",
|
504 |
+
" if args.gradient_accumulation_steps > 1:\n",
|
505 |
+
" loss = loss / args.gradient_accumulation_steps\n",
|
506 |
+
"\n",
|
507 |
+
" if args.fp16:\n",
|
508 |
+
" with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
|
509 |
+
" scaled_loss.backward()\n",
|
510 |
+
" else:\n",
|
511 |
+
" loss.backward()\n",
|
512 |
+
"\n",
|
513 |
+
" tr_loss += loss.item()\n",
|
514 |
+
" if (step + 1) % args.gradient_accumulation_steps == 0:\n",
|
515 |
+
" if args.fp16:\n",
|
516 |
+
" torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n",
|
517 |
+
" else:\n",
|
518 |
+
" torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n",
|
519 |
+
" optimizer.step()\n",
|
520 |
+
" scheduler.step() # Update learning rate schedule\n",
|
521 |
+
" model.zero_grad()\n",
|
522 |
+
" global_step += 1\n",
|
523 |
+
"\n",
|
524 |
+
" if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:\n",
|
525 |
+
" # Log metrics\n",
|
526 |
+
" if (\n",
|
527 |
+
" args.local_rank == -1 and args.evaluate_during_training\n",
|
528 |
+
" ): # Only evaluate when single GPU otherwise metrics may not average well\n",
|
529 |
+
" results = evaluate(args, model, tokenizer)\n",
|
530 |
+
" for key, value in results.items():\n",
|
531 |
+
" tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n",
|
532 |
+
" tb_writer.add_scalar(\"lr\", scheduler.get_lr()[0], global_step)\n",
|
533 |
+
" tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / args.logging_steps, global_step)\n",
|
534 |
+
" logging_loss = tr_loss\n",
|
535 |
+
"\n",
|
536 |
+
" if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:\n",
|
537 |
+
" checkpoint_prefix = \"checkpoint\"\n",
|
538 |
+
" # Save model checkpoint\n",
|
539 |
+
" output_dir = os.path.join(args.output_dir, \"{}-{}\".format(checkpoint_prefix, global_step))\n",
|
540 |
+
" os.makedirs(output_dir, exist_ok=True)\n",
|
541 |
+
" model_to_save = (\n",
|
542 |
+
" model.module if hasattr(model, \"module\") else model\n",
|
543 |
+
" ) # Take care of distributed/parallel training\n",
|
544 |
+
" model_to_save.save_pretrained(output_dir)\n",
|
545 |
+
" tokenizer.save_pretrained(output_dir)\n",
|
546 |
+
"\n",
|
547 |
+
" torch.save(args, os.path.join(output_dir, \"training_args.bin\"))\n",
|
548 |
+
" logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
|
549 |
+
"\n",
|
550 |
+
" _rotate_checkpoints(args, checkpoint_prefix)\n",
|
551 |
+
"\n",
|
552 |
+
" torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n",
|
553 |
+
" torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n",
|
554 |
+
" logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)\n",
|
555 |
+
"\n",
|
556 |
+
" if args.max_steps > 0 and global_step > args.max_steps:\n",
|
557 |
+
" epoch_iterator.close()\n",
|
558 |
+
" break\n",
|
559 |
+
" if args.max_steps > 0 and global_step > args.max_steps:\n",
|
560 |
+
" train_iterator.close()\n",
|
561 |
+
" break\n",
|
562 |
+
"\n",
|
563 |
+
" if args.local_rank in [-1, 0]:\n",
|
564 |
+
" tb_writer.close()\n",
|
565 |
+
"\n",
|
566 |
+
" return global_step, tr_loss / global_step\n",
|
567 |
+
"\n",
|
568 |
+
"# Evaluation of some model\n",
|
569 |
+
"\n",
|
570 |
+
"def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix=\"\") -> Dict:\n",
|
571 |
+
" # Loop to handle MNLI double evaluation (matched, mis-matched)\n",
|
572 |
+
" eval_output_dir = args.output_dir\n",
|
573 |
+
"\n",
|
574 |
+
" eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)\n",
|
575 |
+
" os.makedirs(eval_output_dir, exist_ok=True)\n",
|
576 |
+
" args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)\n",
|
577 |
+
" # Note that DistributedSampler samples randomly\n",
|
578 |
+
"\n",
|
579 |
+
" def collate(examples: List[torch.Tensor]):\n",
|
580 |
+
" if tokenizer._pad_token is None:\n",
|
581 |
+
" return pad_sequence(examples, batch_first=True)\n",
|
582 |
+
" return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
|
583 |
+
"\n",
|
584 |
+
" eval_sampler = SequentialSampler(eval_dataset)\n",
|
585 |
+
" eval_dataloader = DataLoader(\n",
|
586 |
+
" eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last = True\n",
|
587 |
+
" )\n",
|
588 |
+
"\n",
|
589 |
+
" # multi-gpu evaluate\n",
|
590 |
+
" if args.n_gpu > 1:\n",
|
591 |
+
" model = torch.nn.DataParallel(model)\n",
|
592 |
+
"\n",
|
593 |
+
" # Eval!\n",
|
594 |
+
" logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
|
595 |
+
" logger.info(\" Num examples = %d\", len(eval_dataset))\n",
|
596 |
+
" logger.info(\" Batch size = %d\", args.eval_batch_size)\n",
|
597 |
+
" eval_loss = 0.0\n",
|
598 |
+
" nb_eval_steps = 0\n",
|
599 |
+
" model.eval()\n",
|
600 |
+
"\n",
|
601 |
+
" for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
|
602 |
+
" inputs, labels = (batch, batch)\n",
|
603 |
+
" inputs = inputs.to(args.device)\n",
|
604 |
+
" labels = labels.to(args.device)\n",
|
605 |
+
"\n",
|
606 |
+
" with torch.no_grad():\n",
|
607 |
+
" outputs = model(inputs, labels=labels)\n",
|
608 |
+
" lm_loss = outputs[0]\n",
|
609 |
+
" eval_loss += lm_loss.mean().item()\n",
|
610 |
+
" nb_eval_steps += 1\n",
|
611 |
+
"\n",
|
612 |
+
" eval_loss = eval_loss / nb_eval_steps\n",
|
613 |
+
" perplexity = torch.exp(torch.tensor(eval_loss))\n",
|
614 |
+
"\n",
|
615 |
+
" result = {\"perplexity\": perplexity}\n",
|
616 |
+
"\n",
|
617 |
+
" output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n",
|
618 |
+
" with open(output_eval_file, \"w\") as writer:\n",
|
619 |
+
" logger.info(\"***** Eval results {} *****\".format(prefix))\n",
|
620 |
+
" for key in sorted(result.keys()):\n",
|
621 |
+
" logger.info(\" %s = %s\", key, str(result[key]))\n",
|
622 |
+
" writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
|
623 |
+
"\n",
|
624 |
+
" return result"
|
625 |
+
],
|
626 |
+
"execution_count": null,
|
627 |
+
"outputs": []
|
628 |
+
},
|
629 |
+
{
|
630 |
+
"cell_type": "code",
|
631 |
+
"metadata": {
|
632 |
+
"id": "nL2b4rsF3uAn"
|
633 |
+
},
|
634 |
+
"source": [
|
635 |
+
"def main(df_trn, df_val):\n",
|
636 |
+
" args = Args()\n",
|
637 |
+
" \n",
|
638 |
+
" if args.should_continue:\n",
|
639 |
+
" sorted_checkpoints = _sorted_checkpoints(args)\n",
|
640 |
+
" if len(sorted_checkpoints) == 0:\n",
|
641 |
+
" raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n",
|
642 |
+
" else:\n",
|
643 |
+
" args.model_name_or_path = sorted_checkpoints[-1]\n",
|
644 |
+
"\n",
|
645 |
+
" if (\n",
|
646 |
+
" os.path.exists(args.output_dir)\n",
|
647 |
+
" and os.listdir(args.output_dir)\n",
|
648 |
+
" and args.do_train\n",
|
649 |
+
" and not args.overwrite_output_dir\n",
|
650 |
+
" and not args.should_continue\n",
|
651 |
+
" ):\n",
|
652 |
+
" raise ValueError(\n",
|
653 |
+
" \"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.\".format(\n",
|
654 |
+
" args.output_dir\n",
|
655 |
+
" )\n",
|
656 |
+
" )\n",
|
657 |
+
"\n",
|
658 |
+
" # Setup CUDA, GPU & distributed training\n",
|
659 |
+
" device = torch.device(\"cuda\")\n",
|
660 |
+
" args.n_gpu = torch.cuda.device_count()\n",
|
661 |
+
" args.device = device\n",
|
662 |
+
"\n",
|
663 |
+
" # Setup logging\n",
|
664 |
+
" logging.basicConfig(\n",
|
665 |
+
" format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
|
666 |
+
" datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
|
667 |
+
" level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,\n",
|
668 |
+
" )\n",
|
669 |
+
" logger.warning(\n",
|
670 |
+
" \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n",
|
671 |
+
" args.local_rank,\n",
|
672 |
+
" device,\n",
|
673 |
+
" args.n_gpu,\n",
|
674 |
+
" bool(args.local_rank != -1),\n",
|
675 |
+
" args.fp16,\n",
|
676 |
+
" )\n",
|
677 |
+
"\n",
|
678 |
+
" # Set seed\n",
|
679 |
+
" set_seed(args)\n",
|
680 |
+
"\n",
|
681 |
+
" config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)\n",
|
682 |
+
" tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)\n",
|
683 |
+
" model = AutoModelWithLMHead.from_pretrained(\n",
|
684 |
+
" args.model_name_or_path,\n",
|
685 |
+
" from_tf=False,\n",
|
686 |
+
" config=config,\n",
|
687 |
+
" cache_dir=args.cache_dir,\n",
|
688 |
+
" )\n",
|
689 |
+
" model.to(args.device)\n",
|
690 |
+
" \n",
|
691 |
+
" logger.info(\"Training/evaluation parameters %s\", args)\n",
|
692 |
+
"\n",
|
693 |
+
" # Training\n",
|
694 |
+
" if args.do_train:\n",
|
695 |
+
" train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)\n",
|
696 |
+
"\n",
|
697 |
+
" global_step, tr_loss = train(args, train_dataset, model, tokenizer)\n",
|
698 |
+
" logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n",
|
699 |
+
"\n",
|
700 |
+
" # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()\n",
|
701 |
+
" if args.do_train:\n",
|
702 |
+
" # Create output directory if needed\n",
|
703 |
+
" os.makedirs(args.output_dir, exist_ok=True)\n",
|
704 |
+
"\n",
|
705 |
+
" logger.info(\"Saving model checkpoint to %s\", args.output_dir)\n",
|
706 |
+
" # Save a trained model, configuration and tokenizer using `save_pretrained()`.\n",
|
707 |
+
" # They can then be reloaded using `from_pretrained()`\n",
|
708 |
+
" model_to_save = (\n",
|
709 |
+
" model.module if hasattr(model, \"module\") else model\n",
|
710 |
+
" ) # Take care of distributed/parallel training\n",
|
711 |
+
" model_to_save.save_pretrained(args.output_dir)\n",
|
712 |
+
" tokenizer.save_pretrained(args.output_dir)\n",
|
713 |
+
"\n",
|
714 |
+
" # Good practice: save your training arguments together with the trained model\n",
|
715 |
+
" torch.save(args, os.path.join(args.output_dir, \"training_args.bin\"))\n",
|
716 |
+
"\n",
|
717 |
+
" # Load a trained model and vocabulary that you have fine-tuned\n",
|
718 |
+
" model = AutoModelWithLMHead.from_pretrained(args.output_dir)\n",
|
719 |
+
" tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n",
|
720 |
+
" model.to(args.device)\n",
|
721 |
+
"\n",
|
722 |
+
" # Evaluation\n",
|
723 |
+
" results = {}\n",
|
724 |
+
" if args.do_eval and args.local_rank in [-1, 0]:\n",
|
725 |
+
" checkpoints = [args.output_dir]\n",
|
726 |
+
" if args.eval_all_checkpoints:\n",
|
727 |
+
" checkpoints = list(\n",
|
728 |
+
" os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n",
|
729 |
+
" )\n",
|
730 |
+
" logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN) # Reduce logging\n",
|
731 |
+
" logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
|
732 |
+
" for checkpoint in checkpoints:\n",
|
733 |
+
" global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n",
|
734 |
+
" prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n",
|
735 |
+
"\n",
|
736 |
+
" model = AutoModelWithLMHead.from_pretrained(checkpoint)\n",
|
737 |
+
" model.to(args.device)\n",
|
738 |
+
" result = evaluate(args, model, tokenizer, df_trn, df_val, prefix=prefix)\n",
|
739 |
+
" result = dict((k + \"_{}\".format(global_step), v) for k, v in result.items())\n",
|
740 |
+
" results.update(result)\n",
|
741 |
+
"\n",
|
742 |
+
" return results"
|
743 |
+
],
|
744 |
+
"execution_count": null,
|
745 |
+
"outputs": []
|
746 |
+
},
|
747 |
+
{
|
748 |
+
"cell_type": "code",
|
749 |
+
"metadata": {
|
750 |
+
"id": "bx4BrSPQ3yW8"
|
751 |
+
},
|
752 |
+
"source": [
|
753 |
+
"main(trn_df, val_df)"
|
754 |
+
],
|
755 |
+
"execution_count": null,
|
756 |
+
"outputs": []
|
757 |
+
},
|
758 |
+
{
|
759 |
+
"cell_type": "code",
|
760 |
+
"metadata": {
|
761 |
+
"id": "vXESR3FP4Poe"
|
762 |
+
},
|
763 |
+
"source": [
|
764 |
+
"tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')\n",
|
765 |
+
"model = AutoModelWithLMHead.from_pretrained('output-small')"
|
766 |
+
],
|
767 |
+
"execution_count": null,
|
768 |
+
"outputs": []
|
769 |
+
},
|
770 |
+
{
|
771 |
+
"cell_type": "code",
|
772 |
+
"metadata": {
|
773 |
+
"id": "RXL71M_W4UOp"
|
774 |
+
},
|
775 |
+
"source": [
|
776 |
+
"for step in range(4):\n",
|
777 |
+
" # encode the new user input, add the eos_token and return a tensor in Pytorch\n",
|
778 |
+
" new_user_input_ids = tokenizer.encode(input(\">> User:\") + tokenizer.eos_token, return_tensors='pt')\n",
|
779 |
+
" # print(new_user_input_ids)\n",
|
780 |
+
"\n",
|
781 |
+
" # append the new user input tokens to the chat history\n",
|
782 |
+
" bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n",
|
783 |
+
"\n",
|
784 |
+
" # generated a response while limiting the total chat history to 1000 tokens, \n",
|
785 |
+
" chat_history_ids = model.generate(\n",
|
786 |
+
" bot_input_ids, max_length=200,\n",
|
787 |
+
" pad_token_id=tokenizer.eos_token_id, \n",
|
788 |
+
" no_repeat_ngram_size=3, \n",
|
789 |
+
" do_sample=True, \n",
|
790 |
+
" top_k=100, \n",
|
791 |
+
" top_p=0.7,\n",
|
792 |
+
" temperature=0.8\n",
|
793 |
+
" )\n",
|
794 |
+
" \n",
|
795 |
+
" # pretty print last ouput tokens from bot\n",
|
796 |
+
" print(\"Azomekern: {}\".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))"
|
797 |
+
],
|
798 |
+
"execution_count": null,
|
799 |
+
"outputs": []
|
800 |
+
},
|
801 |
+
{
|
802 |
+
"cell_type": "code",
|
803 |
+
"metadata": {
|
804 |
+
"id": "LbRL51QLD7Lf"
|
805 |
+
},
|
806 |
+
"source": [
|
807 |
+
"!sudo apt-get install git-lfs"
|
808 |
+
],
|
809 |
+
"execution_count": null,
|
810 |
+
"outputs": []
|
811 |
+
},
|
812 |
+
{
|
813 |
+
"cell_type": "code",
|
814 |
+
"metadata": {
|
815 |
+
"id": "GaFNUk5QFn6H"
|
816 |
+
},
|
817 |
+
"source": [
|
818 |
+
"!pip install huggingface_hub"
|
819 |
+
],
|
820 |
+
"execution_count": null,
|
821 |
+
"outputs": []
|
822 |
+
},
|
823 |
+
{
|
824 |
+
"cell_type": "code",
|
825 |
+
"metadata": {
|
826 |
+
"id": "EHv1wuvCGMeJ"
|
827 |
+
},
|
828 |
+
"source": [
|
829 |
+
"!huggingface-cli login\n"
|
830 |
+
],
|
831 |
+
"execution_count": null,
|
832 |
+
"outputs": []
|
833 |
+
},
|
834 |
+
{
|
835 |
+
"cell_type": "code",
|
836 |
+
"metadata": {
|
837 |
+
"id": "_Od3UgHDG72I"
|
838 |
+
},
|
839 |
+
"source": [
|
840 |
+
"!sudo apt-get install git-lfs"
|
841 |
+
],
|
842 |
+
"execution_count": null,
|
843 |
+
"outputs": []
|
844 |
+
},
|
845 |
+
{
|
846 |
+
"cell_type": "code",
|
847 |
+
"metadata": {
|
848 |
+
"id": "HeRBcWdoJYYM"
|
849 |
+
},
|
850 |
+
"source": [
|
851 |
+
"!git config --global user.email \"emanuelmaximum40@gmail.com\"\n",
|
852 |
+
"# Tip: using the same email as your huggingface.co account will link your commits to your profile\n",
|
853 |
+
"!git config --global user.name \"Random0-0\""
|
854 |
+
],
|
855 |
+
"execution_count": null,
|
856 |
+
"outputs": []
|
857 |
+
},
|
858 |
+
{
|
859 |
+
"cell_type": "code",
|
860 |
+
"metadata": {
|
861 |
+
"id": "pBZI6c2HLXZl"
|
862 |
+
},
|
863 |
+
"source": [
|
864 |
+
"MY_MODEL_NAME = 'Random0-0/DialoGPT-small-Azomekern'\n",
|
865 |
+
"with open('HuggingFace-API-key.txt', 'rt') as f:\n",
|
866 |
+
" HUGGINGFACE_API_KEY = f.read().strip()"
|
867 |
+
],
|
868 |
+
"execution_count": null,
|
869 |
+
"outputs": []
|
870 |
+
}
|
871 |
+
]
|
872 |
+
}
|