Jhenderson112 committed on
Commit 19f61e3 · 1 Parent(s): fe64f77

Upload 6 files

Files changed (5)
  1. .DS_Store +0 -0
  2. .gitignore +3 -0
  3. Text_Summarization_T5.ipynb +791 -0
  4. app.py +38 -0
  5. requirements.txt +155 -0
.DS_Store ADDED
Binary file (8.2 kB).
 
.gitignore ADDED
@@ -0,0 +1,3 @@
+ 
+ *.bin
+ *.pt
Text_Summarization_T5.ipynb ADDED
@@ -0,0 +1,791 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "c08e675e-437e-4e7d-baee-bd55dda74611",
+ "metadata": {},
+ "source": [
+ "# Abstractive Text Summarization with T5\n",
+ "\n",
+ "This implementation uses Hugging Face Transformers, in particular `AutoModelForSeq2SeqLM` and `AutoTokenizer`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a910e4b5-040d-4499-b5c2-32f3e1ac1c34",
+ "metadata": {},
+ "source": [
+ "## Importing libraries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "d22ee5a9-1981-4883-a926-db37905ec8b6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Setup done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Installs\n",
+ "!pip install -q evaluate py7zr rouge_score absl-py\n",
+ "\n",
+ "# Imports here\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "import nltk\n",
+ "from nltk.tokenize import sent_tokenize\n",
+ "nltk.download(\"punkt\")\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "import datasets\n",
+ "import transformers\n",
+ "from transformers import (\n",
+ " AutoModelForSeq2SeqLM,\n",
+ " Seq2SeqTrainingArguments,\n",
+ " Seq2SeqTrainer,\n",
+ " AutoTokenizer\n",
+ ")\n",
+ "import evaluate\n",
+ "\n",
+ "# Quality of life fixes\n",
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "from pprint import pprint\n",
+ "\n",
+ "import os\n",
+ "os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
+ "\n",
+ "from IPython.display import clear_output\n",
+ "\n",
+ "print(f\"PyTorch version: {torch.__version__}\")\n",
+ "print(f\"Transformers version: {transformers.__version__}\")\n",
+ "print(f\"Datasets version: {datasets.__version__}\")\n",
+ "print(f\"Evaluate version: {evaluate.__version__}\")\n",
+ "\n",
+ "# Get the samsum dataset\n",
+ "samsum = datasets.load_dataset('samsum')\n",
+ "clear_output()\n",
+ "print(\"Setup done!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "bafa753c-0746-4ece-b5eb-4511c9138b09",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'4.27.4'"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Verify transformers version\n",
+ "transformers.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f15204cc-0f21-4dc9-a8e4-429c57b227a9",
+ "metadata": {},
+ "source": [
+ "## Playing around with the dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ba5c1425-a776-4201-97e2-bd420ec112fe",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['id', 'dialogue', 'summary'],\n",
+ " num_rows: 14732\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['id', 'dialogue', 'summary'],\n",
+ " num_rows: 819\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['id', 'dialogue', 'summary'],\n",
+ " num_rows: 818\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# The samsum dataset shape\n",
+ "samsum"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "5d53736c-a8c7-4fe3-b8f1-566c1d99162b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dialogue:\n",
+ "Ollie: How is your Hebrew?\r\n",
+ "Gabi: Not great. \r\n",
+ "Ollie: Could you translate a letter?\r\n",
+ "Gabi: From Hebrew to English maybe, the opposite I don’t think so\r\n",
+ "Gabi: My writing sucks\r\n",
+ "Ollie: Please help me. I don’t have anyone else to ask\r\n",
+ "Gabi: Send it to me. I’ll try. \n",
+ "\n",
+ " -------------------------------------------------- \n",
+ "\n",
+ "Summary:\n",
+ "Gabi knows a bit of Hebrew, though her writing isn't great. She will try to help Ollie translate a letter.\n"
+ ]
+ }
+ ],
+ "source": [
+ "rand_idx = np.random.randint(0, len(samsum['train']))\n",
+ "\n",
+ "print(f\"Dialogue:\\n{samsum['train'][rand_idx]['dialogue']}\")\n",
+ "print('\\n', '-'*50, '\\n')\n",
+ "print(f\"Summary:\\n{samsum['train'][rand_idx]['summary']}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8f95359e-c9c4-4ed5-9130-5e2b4a0a83ad",
+ "metadata": {},
+ "source": [
+ "## Preprocessing data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "50b572e6-b37a-4688-94c9-9c45a2c67c51",
+ "metadata": {},
+ "source": [
+ "I'm using the T5 model (Text-to-Text Transfer Transformer)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "13634dfe-5b1a-4515-9476-8ac0637d0362",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_ckpt = 't5-small'\n",
+ "\n",
+ "# Create the tokenizer from the pretrained checkpoint\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "6b0be9fc-029b-4057-9d08-29235e5b4573",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-78c13bd5dd6a016a.arrow\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Max source length: 512\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/15551 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Max target length: 95\n"
+ ]
+ }
+ ],
+ "source": [
+ "from datasets import concatenate_datasets\n",
+ "# Find the max lengths of the source and target samples\n",
+ "# The maximum total input sequence length after tokenization. \n",
+ "# Sequences that are longer than this will be truncated, sequences shorter will be padded.\n",
+ "tokenized_inputs = concatenate_datasets([samsum[\"train\"], samsum[\"test\"]]).map(lambda x: tokenizer(x[\"dialogue\"], truncation=True), batched=True, remove_columns=[\"dialogue\", \"summary\"])\n",
+ "max_source_length = max([len(x) for x in tokenized_inputs[\"input_ids\"]])\n",
+ "print(f\"Max source length: {max_source_length}\")\n",
+ "\n",
+ "# The maximum total sequence length for target text after tokenization. \n",
+ "# Sequences that are longer than this will be truncated, sequences shorter will be padded.\n",
+ "tokenized_targets = concatenate_datasets([samsum[\"train\"], samsum[\"test\"]]).map(lambda x: tokenizer(x[\"summary\"], truncation=True), batched=True, remove_columns=[\"dialogue\", \"summary\"])\n",
+ "max_target_length = max([len(x) for x in tokenized_targets[\"input_ids\"]])\n",
+ "print(f\"Max target length: {max_target_length}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "c43b0864-8b92-4cb9-b159-bc8ec15bcc2d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-073bbcc8f496f07c.arrow\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/819 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-a43b31cabc78c9c3.arrow\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Keys of tokenized dataset: ['input_ids', 'attention_mask', 'labels']\n"
+ ]
+ }
+ ],
+ "source": [
+ "def preprocess_function(\n",
+ " sample, \n",
+ " padding=\"max_length\", \n",
+ " max_source_length=max_source_length,\n",
+ " max_target_length=max_target_length\n",
+ "):\n",
+ " '''\n",
+ " A preprocessing function that will be applied across the dataset.\n",
+ " The inputs and targets will be tokenized and padded/truncated to the max lengths.\n",
+ "\n",
+ " Args:\n",
+ " sample: A dictionary containing the source and target texts (keys are \"dialogue\" and \"summary\") in a list.\n",
+ " padding: Whether to pad the inputs and targets to the max lengths.\n",
+ " max_source_length: The maximum length of the source text.\n",
+ " max_target_length: The maximum length of the target text.\n",
+ " '''\n",
+ " # Add prefix to the input for t5\n",
+ " inputs = ['summarize: ' + s for s in sample['dialogue']]\n",
+ " \n",
+ " # Tokenize inputs, specifying the padding, truncation and max_length\n",
+ " model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)\n",
+ "\n",
+ " # Tokenize targets with the `text_target` keyword argument\n",
+ " labels = tokenizer(text_target=sample['summary'], max_length=max_target_length, padding=padding, truncation=True)\n",
+ "\n",
+ " # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore padding in the loss\n",
+ " if padding == \"max_length\":\n",
+ " labels[\"input_ids\"] = [\n",
+ " [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels[\"input_ids\"]\n",
+ " ]\n",
+ "\n",
+ " # Format and return\n",
+ " model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
+ " return model_inputs\n",
+ "\n",
+ "# Map this preprocessing function to our datasets using .map on the samsum variable\n",
+ "tokenized_dataset = samsum.map(preprocess_function, batched=True, remove_columns=[\"dialogue\", \"summary\", \"id\"])\n",
+ "print(f\"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "3becd236-0097-4ae5-9bd6-a91ed332e748",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
+ " num_rows: 14732\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
+ " num_rows: 819\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
+ " num_rows: 818\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenized_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "20110839-bb02-4d64-8de7-53253e3f7fe0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "metric = evaluate.load(\"rouge\")\n",
+ "clear_output()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "ca00f91d-8453-4496-a064-525ef437198f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def postprocess_text(preds, labels):\n",
+ " '''\n",
+ " A simple post-processing function to clean up the predictions and labels\n",
+ "\n",
+ " Args:\n",
+ " preds: List[str] of predictions\n",
+ " labels: List[str] of labels\n",
+ " '''\n",
+ " \n",
+ " # strip whitespace on all sentences in preds and labels\n",
+ " preds = [p.strip(' ') for p in preds]\n",
+ " labels = [l.strip(' ') for l in labels]\n",
+ " \n",
+ " # rougeLSum expects newline after each sentence\n",
+ " preds = [\"\\n\".join(sent_tokenize(pred)) for pred in preds]\n",
+ " labels = [\"\\n\".join(sent_tokenize(label)) for label in labels]\n",
+ "\n",
+ " return preds, labels\n",
+ "\n",
+ "def compute_metrics(eval_preds):\n",
+ " \n",
+ " # Fetch the predictions and labels\n",
+ " preds, labels = eval_preds\n",
+ " if isinstance(preds, tuple):\n",
+ " preds = preds[0]\n",
+ " \n",
+ " # Decode the predictions back to text\n",
+ " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
+ " \n",
+ " # Replace -100 in the labels as we can't decode them.\n",
+ " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
+ " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
+ "\n",
+ " # Some simple post-processing for ROUGE\n",
+ " decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n",
+ "\n",
+ " # Compute ROUGE on the decoded predictions and the decoded labels\n",
+ " result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)\n",
+ " \n",
+ " result = {k: round(v * 100, 4) for k, v in result.items()}\n",
+ " prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]\n",
+ " result[\"gen_len\"] = np.mean(prediction_lens)\n",
+ " return result"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7b244846-2ebf-4019-a577-3ef07e350f7c",
+ "metadata": {},
+ "source": [
+ "## Creating the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "49c1ac7c-6400-4a67-b32b-5bdc7330d790",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load the pretrained model with AutoModelForSeq2SeqLM, using the model_ckpt variable\n",
+ "model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)\n",
+ "\n",
+ "clear_output()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "e027b290-c04f-4241-b238-41787f32abe0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# we want to ignore tokenizer pad token in the loss\n",
+ "label_pad_token_id = -100\n",
+ "\n",
+ "# Data Collator, specifying the tokenizer, model, and label_pad_token_id\n",
+ "# pad_to_multiple_of=8 to speed up training\n",
+ "data_collator = transformers.DataCollatorForSeq2Seq(\n",
+ " tokenizer,\n",
+ " model=model,\n",
+ " label_pad_token_id=label_pad_token_id,\n",
+ " pad_to_multiple_of=8\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "0d20ee86-ac8c-4ae7-9e7c-92283e879e00",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n"
+ ]
+ }
+ ],
+ "source": [
+ "import logging\n",
+ "logging.getLogger(\"transformers\").setLevel(logging.WARNING)\n",
+ "\n",
+ "\n",
+ "# Define training hyperparameters in Seq2SeqTrainingArguments\n",
+ "training_args = Seq2SeqTrainingArguments(\n",
+ " output_dir=\"./t5_samsum\", # the output directory\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_strategy=\"epoch\",\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " learning_rate=2e-5,\n",
+ " num_train_epochs=5,\n",
+ " predict_with_generate=True,\n",
+ " per_device_train_batch_size=8,\n",
+ " per_device_eval_batch_size=8,\n",
+ " weight_decay=0.01,\n",
+ " load_best_model_at_end=True,\n",
+ " logging_steps=50,\n",
+ " logging_first_step=False,\n",
+ " fp16=False\n",
+ ")\n",
+ "\n",
+ "# index into the tokenized_dataset variable to get the training and validation data\n",
+ "training_data = tokenized_dataset['train']\n",
+ "eval_data = tokenized_dataset['validation']\n",
+ "\n",
+ "# Create the Trainer for the model\n",
+ "trainer = Seq2SeqTrainer(\n",
+ " model=model, # the model to be trained\n",
+ " args=training_args, # training arguments\n",
+ " train_dataset=training_data, # the training dataset\n",
+ " eval_dataset=eval_data, # the validation dataset\n",
+ " tokenizer=tokenizer, # the tokenizer we used to tokenize our data\n",
+ " compute_metrics=compute_metrics, # the function we defined above to compute metrics\n",
+ " data_collator=data_collator # the data collator we defined above\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "a3b5f21d-b4cb-4f8b-a7fc-cf132ef43c65",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "TrainOutput(global_step=9210, training_loss=1.9861197174436753, metrics={'train_runtime': 3551.1547, 'train_samples_per_second': 20.743, 'train_steps_per_second': 2.594, 'total_flos': 9969277096427520.0, 'train_loss': 1.9861197174436753, 'epoch': 5.0})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Train the model (this will take a while!)\n",
+ "results = trainer.train()\n",
+ "clear_output()\n",
+ "pprint(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ddf8c308",
+ "metadata": {},
+ "source": [
+ "## Evaluating the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "03e94a7f-2d26-48eb-ab17-cb58b14b93f3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "res = trainer.evaluate()\n",
+ "clear_output()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "23675ccb-071c-4a4f-8e42-1a71dc628a5c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>eval_loss</th>\n",
+ " <th>eval_rouge1</th>\n",
+ " <th>eval_rouge2</th>\n",
+ " <th>eval_rougeL</th>\n",
+ " <th>eval_rougeLsum</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>t5-small</th>\n",
+ " <td>1.764253</td>\n",
+ " <td>100.0</td>\n",
+ " <td>100.0</td>\n",
+ " <td>100.0</td>\n",
+ " <td>100.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " eval_loss eval_rouge1 eval_rouge2 eval_rougeL eval_rougeLsum\n",
+ "t5-small 1.764253 100.0 100.0 100.0 100.0"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cols = [\"eval_loss\", \"eval_rouge1\", \"eval_rouge2\", \"eval_rougeL\", \"eval_rougeLsum\"]\n",
+ "filtered_scores = dict((x , res[x]) for x in cols)\n",
+ "pd.DataFrame([filtered_scores], index=[model_ckpt])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "7c59a731",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import pipeline\n",
+ "\n",
+ "summarizer_pipeline = pipeline(\"summarization\",\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " device=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "5138f2bc",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dialogue: Adelina: Hi handsome. Where you you come from?\r\n",
+ "Cyprien: What do you mean?\r\n",
+ "Adelina: What do you mean, \"what do you mean\"? It's a simple question, where do you come from?\r\n",
+ "Cyprien: Well I was born in Jarrow, live in London now, so you could say I came from either of those places\r\n",
+ "Cyprien: I was educated in Loughborouogh, so in a sense I came from there.\r\n",
+ "Adelina: OK. \r\n",
+ "Cyprien: In another sense I come from my mother's vagina, but I dare say everyone can say that.\r\n",
+ "Adelina: Are you all right?\r\n",
+ "Cyprien: IN another sense I come from the atoms in the air that I breath or the food I eat, which comes to me from many places, so all I can say is \"I come from Planet Earth\".\r\n",
+ "Adelina: OK, bye. If you're gonna be a dick...\r\n",
+ "Cyprien: Wait, what you got against earthlings?\n",
+ "-------------------------\n",
+ "True Summary: Cyprien irritates Adelina by giving too many responses.\n",
+ "-------------------------\n",
+ "Model Summary: Cyprien came from Jarrow, live in London. She came from Loughborouogh, and came from her mother's vagina.\n",
+ "-------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "rand_idx = np.random.randint(low=0, high=len(samsum[\"test\"]))\n",
+ "sample = samsum[\"test\"][rand_idx]\n",
+ "\n",
+ "dialog = sample[\"dialogue\"]\n",
+ "true_summary = sample[\"summary\"]\n",
+ "\n",
+ "model_summary = summarizer_pipeline(dialog)\n",
+ "clear_output()\n",
+ "\n",
+ "print(f\"Dialogue: {dialog}\")\n",
+ "print(\"-\"*25)\n",
+ "print(f\"True Summary: {true_summary}\")\n",
+ "print(\"-\"*25)\n",
+ "print(f\"Model Summary: {model_summary[0]['summary_text']}\")\n",
+ "print(\"-\"*25)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "f051655f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Your max_length is set to 200, but you input_length is only 94. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=47)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Original Text:\n",
+ "\n",
+ "Andy: I need you to come in to work on the weekend.\n",
+ "David: Why boss? I have plans to go on a concert I might not be able to come on the weekend.\n",
+ "Andy: It's important we need to get our paperwork all sorted out for this year. Corporate needs it.\n",
+ "David: But I already made plans and this is news to me on very short notice.\n",
+ "Andy: Be there or you'r fired\n",
+ "\n",
+ "\n",
+ " -------------------------------------------------- \n",
+ "\n",
+ "Generated Summary: \n",
+ "[{'summary_text': 'David has plans to go on a concert. Andy needs to get his paperwork all sorted out for this year. David already made plans.'}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def create_summary(input_text, model_pipeline=summarizer_pipeline):\n",
+ " summary = model_pipeline(input_text)\n",
+ " return summary\n",
+ "\n",
+ "text = '''\n",
+ "Andy: I need you to come in to work on the weekend.\n",
+ "David: Why boss? I have plans to go on a concert I might not be able to come on the weekend.\n",
+ "Andy: It's important we need to get our paperwork all sorted out for this year. Corporate needs it.\n",
+ "David: But I already made plans and this is news to me on very short notice.\n",
+ "Andy: Be there or you'r fired\n",
+ "'''\n",
+ "\n",
+ "print(f\"Original Text:\\n{text}\")\n",
+ "print('\\n', '-'*50, '\\n')\n",
+ "\n",
+ "summary = create_summary(text)\n",
+ "\n",
+ "print(f\"Generated Summary: \\n{summary}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ad5d29a0",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
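
The notebook saves checkpoints under `./t5_samsum` (the `output_dir` above) but never reloads them. Below is a minimal inference sketch; it assumes the training run left a loadable checkpoint there (e.g. via `trainer.save_model("./t5_samsum")` or one of the per-epoch `checkpoint-*` folders — the exact directory name is an assumption, not something this commit guarantees).

# Minimal inference sketch. "./t5_samsum" is a placeholder taken from
# output_dir above; point ckpt_dir at whatever checkpoint folder the run produced.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ckpt_dir = "./t5_samsum"  # or e.g. a "./t5_samsum/checkpoint-<step>" subfolder

tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir)

def summarize(dialogue: str) -> str:
    # T5 needs the same "summarize: " task prefix used during preprocessing
    inputs = tokenizer("summarize: " + dialogue, return_tensors="pt",
                       truncation=True, max_length=512)
    output_ids = model.generate(**inputs, max_new_tokens=95, num_beams=4)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(summarize("Ollie: Could you translate a letter?\nGabi: Send it to me. I'll try."))
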
app.py ADDED
@@ -0,0 +1,38 @@
+ from flask import Flask, render_template, request, jsonify
+ from tf_model_api.model_api import ModelAPI
+ 
+ app = Flask(__name__)
+ # Create the model class object
+ summarizer_model = ModelAPI()
+ 
+ @app.route('/')
+ def index():
+     data = {
+         'prompts': ''
+     }
+     return render_template('index.html', data=data)
+ 
+ @app.route('/create-summary', methods=['POST'])
+ def create_summary_response():
+     """
+     Create a summary using the input received
+     from the user.
+     """
+ 
+     data = request.get_json()  # Extract the JSON data from the request
+     text = data.get('text')  # Get the 'text' field from the JSON data
+ 
+     summary = summarizer_model.get_summary(text)
+     if summary:
+         result = {
+             'status': 'success',
+             'result': summary}
+         return jsonify(result), 200
+     else:
+         result = {
+             'status': 'fail'
+         }
+         return jsonify(result), 400
+ 
+ if __name__ == '__main__':
+     app.run()
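
app.py imports `ModelAPI` from `tf_model_api.model_api`, which is not among the files shown here (the commit message says 6 files, but only 5 appear). A hypothetical sketch of that wrapper, inferred purely from app.py's two call sites (`ModelAPI()` and `.get_summary(text)`), might look like:

# Hypothetical tf_model_api/model_api.py -- this module is NOT part of the
# files shown in this commit; names and behavior are inferred from app.py.
from transformers import pipeline

class ModelAPI:
    def __init__(self, model_dir="./t5_samsum"):
        # model_dir is a placeholder; point it at the fine-tuned checkpoint
        self.summarizer = pipeline("summarization", model=model_dir, tokenizer=model_dir)

    def get_summary(self, text):
        # Return None for empty input so app.py's `if summary:` branch sends a 400
        if not text:
            return None
        return self.summarizer(text)[0]["summary_text"]

With the server running (python app.py), the endpoint could be exercised like this, assuming the default Flask dev server address:

import requests

resp = requests.post("http://127.0.0.1:5000/create-summary",
                     json={"text": "Andy: Be there or you're fired"})
print(resp.status_code, resp.json())
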
requirements.txt ADDED
@@ -0,0 +1,155 @@
+ absl-py==2.0.0
+ accelerate==0.23.0
+ aiohttp==3.8.5
+ aiosignal==1.3.1
+ anyio==4.0.0
+ appnope==0.1.3
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.2.3
+ asttokens==2.4.0
+ async-lru==2.0.4
+ async-timeout==4.0.3
+ attrs==23.1.0
+ Babel==2.12.1
+ backcall==0.2.0
+ beautifulsoup4==4.12.2
+ bleach==6.0.0
+ blinker==1.6.2
+ Brotli==1.1.0
+ certifi==2023.7.22
+ cffi==1.15.1
+ charset-normalizer==3.2.0
+ click==8.1.7
+ comm==0.1.4
+ contourpy==1.1.1
+ cycler==0.11.0
+ datasets==2.14.5
+ debugpy==1.8.0
+ decorator==5.1.1
+ defusedxml==0.7.1
+ dill==0.3.7
+ evaluate==0.4.0
+ executing==1.2.0
+ fastjsonschema==2.18.0
+ filelock==3.12.4
+ Flask==2.3.3
+ fonttools==4.42.1
+ fqdn==1.5.1
+ frozenlist==1.4.0
+ fsspec==2023.6.0
+ huggingface-hub==0.17.3
+ idna==3.4
+ inflate64==0.3.1
+ ipykernel==6.25.2
+ ipython==8.15.0
+ ipython-genutils==0.2.0
+ ipywidgets==8.1.1
+ isoduration==20.11.0
+ itsdangerous==2.1.2
+ jedi==0.19.0
+ Jinja2==3.1.2
+ joblib==1.3.2
+ json5==0.9.14
+ jsonpointer==2.4
+ jsonschema==4.19.1
+ jsonschema-specifications==2023.7.1
+ jupyter==1.0.0
+ jupyter-console==6.6.3
+ jupyter-events==0.7.0
+ jupyter-lsp==2.2.0
+ jupyter_client==8.3.1
+ jupyter_core==5.3.2
+ jupyter_server==2.7.3
+ jupyter_server_terminals==0.4.4
+ jupyterlab==4.0.6
+ jupyterlab-pygments==0.2.2
+ jupyterlab-widgets==3.0.9
+ jupyterlab_server==2.25.0
+ kiwisolver==1.4.5
+ MarkupSafe==2.1.3
+ matplotlib==3.8.0
+ matplotlib-inline==0.1.6
+ mistune==3.0.1
+ mpmath==1.3.0
+ multidict==6.0.4
+ multiprocess==0.70.15
+ multivolumefile==0.2.3
+ nbclient==0.8.0
+ nbconvert==7.8.0
+ nbformat==5.9.2
+ nest-asyncio==1.5.8
+ networkx==3.1
+ nltk==3.8.1
+ notebook==7.0.4
+ notebook_shim==0.2.3
+ numpy==1.26.0
+ overrides==7.4.0
+ packaging==23.1
+ pandas==2.1.1
+ pandocfilters==1.5.0
+ parso==0.8.3
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ Pillow==10.0.1
+ platformdirs==3.10.0
+ prometheus-client==0.17.1
+ prompt-toolkit==3.0.39
+ psutil==5.9.5
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ py7zr==0.20.6
+ pyarrow==13.0.0
+ pybcj==1.0.1
+ pycparser==2.21
+ pycryptodomex==3.19.0
+ Pygments==2.16.1
+ pyparsing==3.1.1
+ pyppmd==1.0.0
+ python-dateutil==2.8.2
+ python-json-logger==2.0.7
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ pyzmq==25.1.1
+ pyzstd==0.15.9
+ qtconsole==5.4.4
+ QtPy==2.4.0
+ referencing==0.30.2
+ regex==2023.8.8
+ requests==2.31.0
+ responses==0.18.0
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rouge-score==0.1.2
+ rpds-py==0.10.3
+ safetensors==0.3.3
+ seaborn==0.12.2
+ Send2Trash==1.8.2
+ six==1.16.0
+ sniffio==1.3.0
+ soupsieve==2.5
+ stack-data==0.6.2
+ sympy==1.12
+ terminado==0.17.1
+ texttable==1.6.7
+ tinycss2==1.2.1
+ tokenizers==0.13.3
+ torch==2.0.1
+ torchaudio==2.0.2
+ torchvision==0.15.2
+ tornado==6.3.3
+ tqdm==4.66.1
+ traitlets==5.10.1
+ transformers==4.33.3
+ typing_extensions==4.8.0
+ tzdata==2023.3
+ uri-template==1.3.0
+ urllib3==2.0.5
+ wcwidth==0.2.6
+ webcolors==1.13
+ webencodings==0.5.1
+ websocket-client==1.6.3
+ Werkzeug==2.3.7
+ widgetsnbextension==4.0.9
+ xxhash==3.3.0
+ yarl==1.9.2