Spaces:
Running
Running
Commit
·
19f61e3
1
Parent(s):
fe64f77
Upload 6 files
Browse files- .DS_Store +0 -0
- .gitignore +3 -0
- Text_Summarization_T5.ipynb +791 -0
- app.py +38 -0
- requirements.txt +155 -0
.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
*.bin
|
3 |
+
*.pt
|
Text_Summarization_T5.ipynb
ADDED
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "c08e675e-437e-4e7d-baee-bd55dda74611",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Abstractive Text Summarization with T5\n",
|
9 |
+
"\n",
|
10 |
+
"This implementation uses HuggingFace, especially utilizing `AutoModelForSeq2SeqLM` and `AutoTokenizer`. "
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"id": "a910e4b5-040d-4499-b5c2-32f3e1ac1c34",
|
16 |
+
"metadata": {},
|
17 |
+
"source": [
|
18 |
+
"## Importing libraries"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 1,
|
24 |
+
"id": "d22ee5a9-1981-4883-a926-db37905ec8b6",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"Setup done!\n"
|
32 |
+
]
|
33 |
+
}
|
34 |
+
],
|
35 |
+
"source": [
|
36 |
+
"# Installs\n",
|
37 |
+
"!pip install -q evaluate py7zr rouge_score absl-py\n",
|
38 |
+
"\n",
|
39 |
+
"# Imports here\n",
|
40 |
+
"import numpy as np\n",
|
41 |
+
"import pandas as pd\n",
|
42 |
+
"import matplotlib.pyplot as plt\n",
|
43 |
+
"import seaborn as sns\n",
|
44 |
+
"import nltk\n",
|
45 |
+
"from nltk.tokenize import sent_tokenize\n",
|
46 |
+
"nltk.download(\"punkt\")\n",
|
47 |
+
"\n",
|
48 |
+
"import torch\n",
|
49 |
+
"import torch.nn as nn\n",
|
50 |
+
"\n",
|
51 |
+
"import datasets\n",
|
52 |
+
"import transformers\n",
|
53 |
+
"from transformers import (\n",
|
54 |
+
" AutoModelForSeq2SeqLM,\n",
|
55 |
+
" Seq2SeqTrainingArguments,\n",
|
56 |
+
" Seq2SeqTrainer,\n",
|
57 |
+
" AutoTokenizer\n",
|
58 |
+
")\n",
|
59 |
+
"import evaluate\n",
|
60 |
+
"\n",
|
61 |
+
"# Quality of life fixes\n",
|
62 |
+
"import warnings\n",
|
63 |
+
"warnings.filterwarnings('ignore')\n",
|
64 |
+
"from pprint import pprint\n",
|
65 |
+
"\n",
|
66 |
+
"import os\n",
|
67 |
+
"os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
|
68 |
+
"\n",
|
69 |
+
"from IPython.display import clear_output\n",
|
70 |
+
"\n",
|
71 |
+
"print(f\"PyTorch version: {torch.__version__}\")\n",
|
72 |
+
"print(f\"Transformers version: {transformers.__version__}\")\n",
|
73 |
+
"print(f\"Datasets version: {datasets.__version__}\")\n",
|
74 |
+
"print(f\"Evaluate version: {evaluate.__version__}\")\n",
|
75 |
+
"\n",
|
76 |
+
"# Get the samsum dataset\n",
|
77 |
+
"samsum = datasets.load_dataset('samsum')\n",
|
78 |
+
"clear_output()\n",
|
79 |
+
"print(\"Setup done!\")"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": 2,
|
85 |
+
"id": "bafa753c-0746-4ece-b5eb-4511c9138b09",
|
86 |
+
"metadata": {},
|
87 |
+
"outputs": [
|
88 |
+
{
|
89 |
+
"data": {
|
90 |
+
"text/plain": [
|
91 |
+
"'4.27.4'"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
"execution_count": 2,
|
95 |
+
"metadata": {},
|
96 |
+
"output_type": "execute_result"
|
97 |
+
}
|
98 |
+
],
|
99 |
+
"source": [
|
100 |
+
"# Verify transformers version\n",
|
101 |
+
"transformers.__version__"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "markdown",
|
106 |
+
"id": "f15204cc-0f21-4dc9-a8e4-429c57b227a9",
|
107 |
+
"metadata": {},
|
108 |
+
"source": [
|
109 |
+
"## Playing around with the dataset"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"execution_count": 3,
|
115 |
+
"id": "ba5c1425-a776-4201-97e2-bd420ec112fe",
|
116 |
+
"metadata": {},
|
117 |
+
"outputs": [
|
118 |
+
{
|
119 |
+
"data": {
|
120 |
+
"text/plain": [
|
121 |
+
"DatasetDict({\n",
|
122 |
+
" train: Dataset({\n",
|
123 |
+
" features: ['id', 'dialogue', 'summary'],\n",
|
124 |
+
" num_rows: 14732\n",
|
125 |
+
" })\n",
|
126 |
+
" test: Dataset({\n",
|
127 |
+
" features: ['id', 'dialogue', 'summary'],\n",
|
128 |
+
" num_rows: 819\n",
|
129 |
+
" })\n",
|
130 |
+
" validation: Dataset({\n",
|
131 |
+
" features: ['id', 'dialogue', 'summary'],\n",
|
132 |
+
" num_rows: 818\n",
|
133 |
+
" })\n",
|
134 |
+
"})"
|
135 |
+
]
|
136 |
+
},
|
137 |
+
"execution_count": 3,
|
138 |
+
"metadata": {},
|
139 |
+
"output_type": "execute_result"
|
140 |
+
}
|
141 |
+
],
|
142 |
+
"source": [
|
143 |
+
"# The samsum dataset shape\n",
|
144 |
+
"samsum"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"cell_type": "code",
|
149 |
+
"execution_count": 4,
|
150 |
+
"id": "5d53736c-a8c7-4fe3-b8f1-566c1d99162b",
|
151 |
+
"metadata": {},
|
152 |
+
"outputs": [
|
153 |
+
{
|
154 |
+
"name": "stdout",
|
155 |
+
"output_type": "stream",
|
156 |
+
"text": [
|
157 |
+
"Dialogue:\n",
|
158 |
+
"Ollie: How is your Hebrew?\r\n",
|
159 |
+
"Gabi: Not great. \r\n",
|
160 |
+
"Ollie: Could you translate a letter?\r\n",
|
161 |
+
"Gabi: From Hebrew to English maybe, the opposite I don’t think so\r\n",
|
162 |
+
"Gabi: My writing sucks\r\n",
|
163 |
+
"Ollie: Please help me. I don’t have anyone else to ask\r\n",
|
164 |
+
"Gabi: Send it to me. I’ll try. \n",
|
165 |
+
"\n",
|
166 |
+
" -------------------------------------------------- \n",
|
167 |
+
"\n",
|
168 |
+
"Summary:\n",
|
169 |
+
"Gabi knows a bit of Hebrew, though her writing isn't great. She will try to help Ollie translate a letter.\n"
|
170 |
+
]
|
171 |
+
}
|
172 |
+
],
|
173 |
+
"source": [
|
174 |
+
"rand_idx = np.random.randint(0, len(samsum['train']))\n",
|
175 |
+
"\n",
|
176 |
+
"print(f\"Dialogue:\\n{samsum['train'][rand_idx]['dialogue']}\")\n",
|
177 |
+
"print('\\n', '-'*50, '\\n')\n",
|
178 |
+
"print(f\"Summary:\\n{samsum['train'][rand_idx]['summary']}\")"
|
179 |
+
]
|
180 |
+
},
|
181 |
+
{
|
182 |
+
"cell_type": "markdown",
|
183 |
+
"id": "8f95359e-c9c4-4ed5-9130-5e2b4a0a83ad",
|
184 |
+
"metadata": {},
|
185 |
+
"source": [
|
186 |
+
"## Preprocessing data"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "markdown",
|
191 |
+
"id": "50b572e6-b37a-4688-94c9-9c45a2c67c51",
|
192 |
+
"metadata": {},
|
193 |
+
"source": [
|
194 |
+
" I'm using the T5 Transformers model (Text-to-Text Transfer Transformer)"
|
195 |
+
]
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"cell_type": "code",
|
199 |
+
"execution_count": 5,
|
200 |
+
"id": "13634dfe-5b1a-4515-9476-8ac0637d0362",
|
201 |
+
"metadata": {},
|
202 |
+
"outputs": [],
|
203 |
+
"source": [
|
204 |
+
"model_ckpt = 't5-small'\n",
|
205 |
+
"\n",
|
206 |
+
"# TODO: Create the Tokenizer AutoTokenizer pretrained checkpoint\n",
|
207 |
+
"tokenizer = AutoTokenizer.from_pretrained('t5-small')"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "code",
|
212 |
+
"execution_count": 6,
|
213 |
+
"id": "6b0be9fc-029b-4057-9d08-29235e5b4573",
|
214 |
+
"metadata": {},
|
215 |
+
"outputs": [
|
216 |
+
{
|
217 |
+
"name": "stderr",
|
218 |
+
"output_type": "stream",
|
219 |
+
"text": [
|
220 |
+
"Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-78c13bd5dd6a016a.arrow\n"
|
221 |
+
]
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"name": "stdout",
|
225 |
+
"output_type": "stream",
|
226 |
+
"text": [
|
227 |
+
"Max source length: 512\n"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"data": {
|
232 |
+
"application/vnd.jupyter.widget-view+json": {
|
233 |
+
"model_id": "",
|
234 |
+
"version_major": 2,
|
235 |
+
"version_minor": 0
|
236 |
+
},
|
237 |
+
"text/plain": [
|
238 |
+
"Map: 0%| | 0/15551 [00:00<?, ? examples/s]"
|
239 |
+
]
|
240 |
+
},
|
241 |
+
"metadata": {},
|
242 |
+
"output_type": "display_data"
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"name": "stdout",
|
246 |
+
"output_type": "stream",
|
247 |
+
"text": [
|
248 |
+
"Max target length: 95\n"
|
249 |
+
]
|
250 |
+
}
|
251 |
+
],
|
252 |
+
"source": [
|
253 |
+
"from datasets import concatenate_datasets\n",
|
254 |
+
"# Find the max lengths of the source and target samples\n",
|
255 |
+
"# The maximum total input sequence length after tokenization. \n",
|
256 |
+
"# Sequences that are longer than this will be truncated, sequences shorter are be padded.\n",
|
257 |
+
"tokenized_inputs = concatenate_datasets([samsum[\"train\"], samsum[\"test\"]]).map(lambda x: tokenizer(x[\"dialogue\"], truncation=True), batched=True, remove_columns=[\"dialogue\", \"summary\"])\n",
|
258 |
+
"max_source_length = max([len(x) for x in tokenized_inputs[\"input_ids\"]])\n",
|
259 |
+
"print(f\"Max source length: {max_source_length}\")\n",
|
260 |
+
"\n",
|
261 |
+
"# The maximum total sequence length for target text after tokenization. \n",
|
262 |
+
"# Sequences that are longer than this will be truncated, sequences shorter are be padded.\n",
|
263 |
+
"tokenized_targets = concatenate_datasets([samsum[\"train\"], samsum[\"test\"]]).map(lambda x: tokenizer(x[\"summary\"], truncation=True), batched=True, remove_columns=[\"dialogue\", \"summary\"])\n",
|
264 |
+
"max_target_length = max([len(x) for x in tokenized_targets[\"input_ids\"]])\n",
|
265 |
+
"print(f\"Max target length: {max_target_length}\")"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": 7,
|
271 |
+
"id": "c43b0864-8b92-4cb9-b159-bc8ec15bcc2d",
|
272 |
+
"metadata": {},
|
273 |
+
"outputs": [
|
274 |
+
{
|
275 |
+
"name": "stderr",
|
276 |
+
"output_type": "stream",
|
277 |
+
"text": [
|
278 |
+
"Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-073bbcc8f496f07c.arrow\n"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"data": {
|
283 |
+
"application/vnd.jupyter.widget-view+json": {
|
284 |
+
"model_id": "",
|
285 |
+
"version_major": 2,
|
286 |
+
"version_minor": 0
|
287 |
+
},
|
288 |
+
"text/plain": [
|
289 |
+
"Map: 0%| | 0/819 [00:00<?, ? examples/s]"
|
290 |
+
]
|
291 |
+
},
|
292 |
+
"metadata": {},
|
293 |
+
"output_type": "display_data"
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"name": "stderr",
|
297 |
+
"output_type": "stream",
|
298 |
+
"text": [
|
299 |
+
"Loading cached processed dataset at C:\\Users\\QXLVR\\.cache\\huggingface\\datasets\\samsum\\samsum\\0.0.0\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e\\cache-a43b31cabc78c9c3.arrow\n"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"name": "stdout",
|
304 |
+
"output_type": "stream",
|
305 |
+
"text": [
|
306 |
+
"Keys of tokenized dataset: ['input_ids', 'attention_mask', 'labels']\n"
|
307 |
+
]
|
308 |
+
}
|
309 |
+
],
|
310 |
+
"source": [
|
311 |
+
"def preprocess_function(\n",
|
312 |
+
" sample, \n",
|
313 |
+
" padding=\"max_length\", \n",
|
314 |
+
" max_source_length=max_source_length,\n",
|
315 |
+
" max_target_length=max_target_length\n",
|
316 |
+
"):\n",
|
317 |
+
" '''\n",
|
318 |
+
" A preprocessing function that will be applied across the dataset.\n",
|
319 |
+
" The inputs and targets will be tokenized and padded/truncated to the max lengths.\n",
|
320 |
+
"\n",
|
321 |
+
" Args:\n",
|
322 |
+
" sample: A dictionary containing the source and target texts (keys are \"dialogue\" and \"summary\") in a list.\n",
|
323 |
+
" padding: Whether to pad the inputs and targets to the max lengths.\n",
|
324 |
+
" max_source_length: The maximum length of the source text.\n",
|
325 |
+
" max_target_length: The maximum length of the target text.\n",
|
326 |
+
" '''\n",
|
327 |
+
" # Add prefix to the input for t5\n",
|
328 |
+
" inputs = ['summarize: ' + s for s in sample['dialogue']]\n",
|
329 |
+
" \n",
|
330 |
+
" # Tokenize inputs, specifying the padding, truncation and max_length\n",
|
331 |
+
" model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)\n",
|
332 |
+
"\n",
|
333 |
+
" # Tokenize targets with the `text_target` keyword argument\n",
|
334 |
+
" labels = tokenizer(text_target=sample['summary'], max_length=max_target_length, padding=padding, truncation=True)\n",
|
335 |
+
"\n",
|
336 |
+
" # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore padding in the loss\n",
|
337 |
+
" if padding == \"max_length\":\n",
|
338 |
+
" labels[\"input_ids\"] = [\n",
|
339 |
+
" [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels[\"input_ids\"]\n",
|
340 |
+
" ]\n",
|
341 |
+
"\n",
|
342 |
+
" # Format and return\n",
|
343 |
+
" model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
|
344 |
+
" return model_inputs\n",
|
345 |
+
"\n",
|
346 |
+
"# Map this preprocessing function to our datasets using .map on the samsum variable\n",
|
347 |
+
"tokenized_dataset = samsum.map(preprocess_function, batched=True, remove_columns=[\"dialogue\", \"summary\", \"id\"])\n",
|
348 |
+
"print(f\"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}\")"
|
349 |
+
]
|
350 |
+
},
|
351 |
+
{
|
352 |
+
"cell_type": "code",
|
353 |
+
"execution_count": 8,
|
354 |
+
"id": "3becd236-0097-4ae5-9bd6-a91ed332e748",
|
355 |
+
"metadata": {},
|
356 |
+
"outputs": [
|
357 |
+
{
|
358 |
+
"data": {
|
359 |
+
"text/plain": [
|
360 |
+
"DatasetDict({\n",
|
361 |
+
" train: Dataset({\n",
|
362 |
+
" features: ['input_ids', 'attention_mask', 'labels'],\n",
|
363 |
+
" num_rows: 14732\n",
|
364 |
+
" })\n",
|
365 |
+
" test: Dataset({\n",
|
366 |
+
" features: ['input_ids', 'attention_mask', 'labels'],\n",
|
367 |
+
" num_rows: 819\n",
|
368 |
+
" })\n",
|
369 |
+
" validation: Dataset({\n",
|
370 |
+
" features: ['input_ids', 'attention_mask', 'labels'],\n",
|
371 |
+
" num_rows: 818\n",
|
372 |
+
" })\n",
|
373 |
+
"})"
|
374 |
+
]
|
375 |
+
},
|
376 |
+
"execution_count": 8,
|
377 |
+
"metadata": {},
|
378 |
+
"output_type": "execute_result"
|
379 |
+
}
|
380 |
+
],
|
381 |
+
"source": [
|
382 |
+
"tokenized_dataset"
|
383 |
+
]
|
384 |
+
},
|
385 |
+
{
|
386 |
+
"cell_type": "code",
|
387 |
+
"execution_count": 9,
|
388 |
+
"id": "20110839-bb02-4d64-8de7-53253e3f7fe0",
|
389 |
+
"metadata": {},
|
390 |
+
"outputs": [],
|
391 |
+
"source": [
|
392 |
+
"metric = evaluate.load(\"rouge\")\n",
|
393 |
+
"clear_output()"
|
394 |
+
]
|
395 |
+
},
|
396 |
+
{
|
397 |
+
"cell_type": "code",
|
398 |
+
"execution_count": 10,
|
399 |
+
"id": "ca00f91d-8453-4496-a064-525ef437198f",
|
400 |
+
"metadata": {},
|
401 |
+
"outputs": [],
|
402 |
+
"source": [
|
403 |
+
"def postprocess_text(preds, labels):\n",
|
404 |
+
" '''\n",
|
405 |
+
" A simple post-processing function to clean up the predictions and labels\n",
|
406 |
+
"\n",
|
407 |
+
" Args:\n",
|
408 |
+
" preds: List[str] of predictions\n",
|
409 |
+
" labels: List[str] of labels\n",
|
410 |
+
" '''\n",
|
411 |
+
" \n",
|
412 |
+
" # strip whitespace on all sentences in preds and labels\n",
|
413 |
+
" preds = [p.strip(' ') for p in preds]\n",
|
414 |
+
" labels = [l.strip(' ') for l in preds]\n",
|
415 |
+
" \n",
|
416 |
+
" # rougeLSum expects newline after each sentence\n",
|
417 |
+
" preds = [\"\\n\".join(sent_tokenize(pred)) for pred in preds]\n",
|
418 |
+
" labels = [\"\\n\".join(sent_tokenize(label)) for label in labels]\n",
|
419 |
+
"\n",
|
420 |
+
" return preds, labels\n",
|
421 |
+
"\n",
|
422 |
+
"def compute_metrics(eval_preds):\n",
|
423 |
+
" \n",
|
424 |
+
" # Fetch the predictions and labels\n",
|
425 |
+
" preds, labels = eval_preds\n",
|
426 |
+
" if isinstance(preds, tuple):\n",
|
427 |
+
" preds = preds[0]\n",
|
428 |
+
" \n",
|
429 |
+
" # Decode the predictions back to text\n",
|
430 |
+
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
431 |
+
" \n",
|
432 |
+
" # Replace -100 in the labels as we can't decode them.\n",
|
433 |
+
" labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
|
434 |
+
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
435 |
+
"\n",
|
436 |
+
" # Some simple post-processing for ROUGE\n",
|
437 |
+
" decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)\n",
|
438 |
+
"\n",
|
439 |
+
" # Compute ROUGE on the decoded predictions and the decoder labels\n",
|
440 |
+
" result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)\n",
|
441 |
+
" \n",
|
442 |
+
" result = {k: round(v * 100, 4) for k, v in result.items()}\n",
|
443 |
+
" prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]\n",
|
444 |
+
" result[\"gen_len\"] = np.mean(prediction_lens)\n",
|
445 |
+
" return result"
|
446 |
+
]
|
447 |
+
},
|
448 |
+
{
|
449 |
+
"cell_type": "markdown",
|
450 |
+
"id": "7b244846-2ebf-4019-a577-3ef07e350f7c",
|
451 |
+
"metadata": {},
|
452 |
+
"source": [
|
453 |
+
"## Creating the model"
|
454 |
+
]
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"cell_type": "code",
|
458 |
+
"execution_count": 11,
|
459 |
+
"id": "49c1ac7c-6400-4a67-b32b-5bdc7330d790",
|
460 |
+
"metadata": {},
|
461 |
+
"outputs": [],
|
462 |
+
"source": [
|
463 |
+
"# the AutoModelForSeq2SeqLM class and use the model_ckpt variable)\n",
|
464 |
+
"model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)\n",
|
465 |
+
"\n",
|
466 |
+
"clear_output()"
|
467 |
+
]
|
468 |
+
},
|
469 |
+
{
|
470 |
+
"cell_type": "code",
|
471 |
+
"execution_count": 12,
|
472 |
+
"id": "e027b290-c04f-4241-b238-41787f32abe0",
|
473 |
+
"metadata": {},
|
474 |
+
"outputs": [],
|
475 |
+
"source": [
|
476 |
+
"# we want to ignore tokenizer pad token in the loss\n",
|
477 |
+
"label_pad_token_id = -100\n",
|
478 |
+
"\n",
|
479 |
+
"# Data Collator, specifying the tokenizer, model, and label_pad_token_id\n",
|
480 |
+
"# pad_to_multiple_of=8 to speed up training\n",
|
481 |
+
"data_collator = transformers.DataCollatorForSeq2Seq(\n",
|
482 |
+
" tokenizer,\n",
|
483 |
+
" model=model,\n",
|
484 |
+
" label_pad_token_id=label_pad_token_id,\n",
|
485 |
+
" pad_to_multiple_of=8\n",
|
486 |
+
")"
|
487 |
+
]
|
488 |
+
},
|
489 |
+
{
|
490 |
+
"cell_type": "code",
|
491 |
+
"execution_count": 13,
|
492 |
+
"id": "0d20ee86-ac8c-4ae7-9e7c-92283e879e00",
|
493 |
+
"metadata": {},
|
494 |
+
"outputs": [
|
495 |
+
{
|
496 |
+
"name": "stderr",
|
497 |
+
"output_type": "stream",
|
498 |
+
"text": [
|
499 |
+
"Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n"
|
500 |
+
]
|
501 |
+
}
|
502 |
+
],
|
503 |
+
"source": [
|
504 |
+
"import logging\n",
|
505 |
+
"logging.getLogger(\"transformers\").setLevel(logging.WARNING)\n",
|
506 |
+
"\n",
|
507 |
+
"\n",
|
508 |
+
"# Define training hyperparameters in Seq2SeqTrainingArguments\n",
|
509 |
+
"training_args = Seq2SeqTrainingArguments(\n",
|
510 |
+
" output_dir=\"./t5_samsum\", # the output directory\n",
|
511 |
+
" logging_strategy=\"epoch\",\n",
|
512 |
+
" save_strategy=\"epoch\",\n",
|
513 |
+
" evaluation_strategy=\"epoch\",\n",
|
514 |
+
" learning_rate=2e-5,\n",
|
515 |
+
" num_train_epochs=5,\n",
|
516 |
+
" predict_with_generate=True,\n",
|
517 |
+
" per_device_train_batch_size=8,\n",
|
518 |
+
" per_device_eval_batch_size=8,\n",
|
519 |
+
" weight_decay=0.01,\n",
|
520 |
+
" load_best_model_at_end=True,\n",
|
521 |
+
" logging_steps=50,\n",
|
522 |
+
" logging_first_step=False,\n",
|
523 |
+
" fp16=False\n",
|
524 |
+
")\n",
|
525 |
+
"\n",
|
526 |
+
"# index into the tokenized_dataset variable to get the training and validation data\n",
|
527 |
+
"training_data = tokenized_dataset['train']\n",
|
528 |
+
"eval_data = tokenized_dataset['validation']\n",
|
529 |
+
"\n",
|
530 |
+
"# Create the Trainer for the model\n",
|
531 |
+
"trainer = Seq2SeqTrainer(\n",
|
532 |
+
" model=model, # the model to be trained\n",
|
533 |
+
" args=training_args, # training arguments\n",
|
534 |
+
" train_dataset=training_data, # the training dataset\n",
|
535 |
+
" eval_dataset=eval_data, # the validation dataset\n",
|
536 |
+
" tokenizer=tokenizer, # the tokenizer we used to tokenize our data\n",
|
537 |
+
" compute_metrics=compute_metrics, # the function we defined above to compute metrics\n",
|
538 |
+
" data_collator=data_collator # the data collator we defined above\n",
|
539 |
+
")"
|
540 |
+
]
|
541 |
+
},
|
542 |
+
{
|
543 |
+
"cell_type": "code",
|
544 |
+
"execution_count": 14,
|
545 |
+
"id": "a3b5f21d-b4cb-4f8b-a7fc-cf132ef43c65",
|
546 |
+
"metadata": {},
|
547 |
+
"outputs": [
|
548 |
+
{
|
549 |
+
"name": "stdout",
|
550 |
+
"output_type": "stream",
|
551 |
+
"text": [
|
552 |
+
"TrainOutput(global_step=9210, training_loss=1.9861197174436753, metrics={'train_runtime': 3551.1547, 'train_samples_per_second': 20.743, 'train_steps_per_second': 2.594, 'total_flos': 9969277096427520.0, 'train_loss': 1.9861197174436753, 'epoch': 5.0})\n"
|
553 |
+
]
|
554 |
+
}
|
555 |
+
],
|
556 |
+
"source": [
|
557 |
+
"# Train the model (this will take a while!)\n",
|
558 |
+
"results = trainer.train()\n",
|
559 |
+
"clear_output()\n",
|
560 |
+
"pprint(results)"
|
561 |
+
]
|
562 |
+
},
|
563 |
+
{
|
564 |
+
"cell_type": "markdown",
|
565 |
+
"id": "ddf8c308",
|
566 |
+
"metadata": {},
|
567 |
+
"source": [
|
568 |
+
"## Evaluating the model"
|
569 |
+
]
|
570 |
+
},
|
571 |
+
{
|
572 |
+
"cell_type": "code",
|
573 |
+
"execution_count": 15,
|
574 |
+
"id": "03e94a7f-2d26-48eb-ab17-cb58b14b93f3",
|
575 |
+
"metadata": {},
|
576 |
+
"outputs": [],
|
577 |
+
"source": [
|
578 |
+
"res = trainer.evaluate()\n",
|
579 |
+
"clear_output()"
|
580 |
+
]
|
581 |
+
},
|
582 |
+
{
|
583 |
+
"cell_type": "code",
|
584 |
+
"execution_count": 18,
|
585 |
+
"id": "23675ccb-071c-4a4f-8e42-1a71dc628a5c",
|
586 |
+
"metadata": {},
|
587 |
+
"outputs": [
|
588 |
+
{
|
589 |
+
"data": {
|
590 |
+
"text/html": [
|
591 |
+
"<div>\n",
|
592 |
+
"<style scoped>\n",
|
593 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
594 |
+
" vertical-align: middle;\n",
|
595 |
+
" }\n",
|
596 |
+
"\n",
|
597 |
+
" .dataframe tbody tr th {\n",
|
598 |
+
" vertical-align: top;\n",
|
599 |
+
" }\n",
|
600 |
+
"\n",
|
601 |
+
" .dataframe thead th {\n",
|
602 |
+
" text-align: right;\n",
|
603 |
+
" }\n",
|
604 |
+
"</style>\n",
|
605 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
606 |
+
" <thead>\n",
|
607 |
+
" <tr style=\"text-align: right;\">\n",
|
608 |
+
" <th></th>\n",
|
609 |
+
" <th>eval_loss</th>\n",
|
610 |
+
" <th>eval_rouge1</th>\n",
|
611 |
+
" <th>eval_rouge2</th>\n",
|
612 |
+
" <th>eval_rougeL</th>\n",
|
613 |
+
" <th>eval_rougeLsum</th>\n",
|
614 |
+
" </tr>\n",
|
615 |
+
" </thead>\n",
|
616 |
+
" <tbody>\n",
|
617 |
+
" <tr>\n",
|
618 |
+
" <th>t5-small</th>\n",
|
619 |
+
" <td>1.764253</td>\n",
|
620 |
+
" <td>100.0</td>\n",
|
621 |
+
" <td>100.0</td>\n",
|
622 |
+
" <td>100.0</td>\n",
|
623 |
+
" <td>100.0</td>\n",
|
624 |
+
" </tr>\n",
|
625 |
+
" </tbody>\n",
|
626 |
+
"</table>\n",
|
627 |
+
"</div>"
|
628 |
+
],
|
629 |
+
"text/plain": [
|
630 |
+
" eval_loss eval_rouge1 eval_rouge2 eval_rougeL eval_rougeLsum\n",
|
631 |
+
"t5-small 1.764253 100.0 100.0 100.0 100.0"
|
632 |
+
]
|
633 |
+
},
|
634 |
+
"execution_count": 18,
|
635 |
+
"metadata": {},
|
636 |
+
"output_type": "execute_result"
|
637 |
+
}
|
638 |
+
],
|
639 |
+
"source": [
|
640 |
+
"cols = [\"eval_loss\", \"eval_rouge1\", \"eval_rouge2\", \"eval_rougeL\", \"eval_rougeLsum\"]\n",
|
641 |
+
"filtered_scores = dict((x , res[x]) for x in cols)\n",
|
642 |
+
"pd.DataFrame([filtered_scores], index=[model_ckpt])"
|
643 |
+
]
|
644 |
+
},
|
645 |
+
{
|
646 |
+
"cell_type": "code",
|
647 |
+
"execution_count": 20,
|
648 |
+
"id": "7c59a731",
|
649 |
+
"metadata": {},
|
650 |
+
"outputs": [],
|
651 |
+
"source": [
|
652 |
+
"from transformers import pipeline\n",
|
653 |
+
"\n",
|
654 |
+
"summarizer_pipeline = pipeline(\"summarization\",\n",
|
655 |
+
" model=model,\n",
|
656 |
+
" tokenizer=tokenizer,\n",
|
657 |
+
" device=0)"
|
658 |
+
]
|
659 |
+
},
|
660 |
+
{
|
661 |
+
"cell_type": "code",
|
662 |
+
"execution_count": 22,
|
663 |
+
"id": "5138f2bc",
|
664 |
+
"metadata": {},
|
665 |
+
"outputs": [
|
666 |
+
{
|
667 |
+
"name": "stdout",
|
668 |
+
"output_type": "stream",
|
669 |
+
"text": [
|
670 |
+
"Dialogue: Adelina: Hi handsome. Where you you come from?\r\n",
|
671 |
+
"Cyprien: What do you mean?\r\n",
|
672 |
+
"Adelina: What do you mean, \"what do you mean\"? It's a simple question, where do you come from?\r\n",
|
673 |
+
"Cyprien: Well I was born in Jarrow, live in London now, so you could say I came from either of those places\r\n",
|
674 |
+
"Cyprien: I was educated in Loughborouogh, so in a sense I came from there.\r\n",
|
675 |
+
"Adelina: OK. \r\n",
|
676 |
+
"Cyprien: In another sense I come from my mother's vagina, but I dare say everyone can say that.\r\n",
|
677 |
+
"Adelina: Are you all right?\r\n",
|
678 |
+
"Cyprien: IN another sense I come from the atoms in the air that I breath or the food I eat, which comes to me from many places, so all I can say is \"I come from Planet Earth\".\r\n",
|
679 |
+
"Adelina: OK, bye. If you're gonna be a dick...\r\n",
|
680 |
+
"Cyprien: Wait, what you got against earthlings?\n",
|
681 |
+
"-------------------------\n",
|
682 |
+
"True Summary: Cyprien irritates Adelina by giving too many responses.\n",
|
683 |
+
"-------------------------\n",
|
684 |
+
"Model Summary: Cyprien came from Jarrow, live in London. She came from Loughborouogh, and came from her mother's vagina.\n",
|
685 |
+
"-------------------------\n"
|
686 |
+
]
|
687 |
+
}
|
688 |
+
],
|
689 |
+
"source": [
|
690 |
+
"rand_idx = np.random.randint(low=0, high=len(samsum[\"test\"]))\n",
|
691 |
+
"sample = samsum[\"test\"][rand_idx]\n",
|
692 |
+
"\n",
|
693 |
+
"dialog = sample[\"dialogue\"]\n",
|
694 |
+
"true_summary = sample[\"summary\"]\n",
|
695 |
+
"\n",
|
696 |
+
"model_summary = summarizer_pipeline(dialog)\n",
|
697 |
+
"clear_output()\n",
|
698 |
+
"\n",
|
699 |
+
"print(f\"Dialogue: {dialog}\")\n",
|
700 |
+
"print(\"-\"*25)\n",
|
701 |
+
"print(f\"True Summary: {true_summary}\")\n",
|
702 |
+
"print(\"-\"*25)\n",
|
703 |
+
"print(f\"Model Summary: {model_summary[0]['summary_text']}\")\n",
|
704 |
+
"print(\"-\"*25)"
|
705 |
+
]
|
706 |
+
},
|
707 |
+
{
|
708 |
+
"cell_type": "code",
|
709 |
+
"execution_count": 24,
|
710 |
+
"id": "f051655f",
|
711 |
+
"metadata": {},
|
712 |
+
"outputs": [
|
713 |
+
{
|
714 |
+
"name": "stderr",
|
715 |
+
"output_type": "stream",
|
716 |
+
"text": [
|
717 |
+
"Your max_length is set to 200, but you input_length is only 94. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=47)\n"
|
718 |
+
]
|
719 |
+
},
|
720 |
+
{
|
721 |
+
"name": "stdout",
|
722 |
+
"output_type": "stream",
|
723 |
+
"text": [
|
724 |
+
"Original Text:\n",
|
725 |
+
"\n",
|
726 |
+
"Andy: I need you to come in to work on the weekend.\n",
|
727 |
+
"David: Why boss? I have plans to go on a concert I might not be able to come on the weekend.\n",
|
728 |
+
"Andy: It's important we need to get our paperwork all sorted out for this year. Corporate needs it.\n",
|
729 |
+
"David: But I already made plans and this is news to me on very short notice.\n",
|
730 |
+
"Andy: Be there or you'r fired\n",
|
731 |
+
"\n",
|
732 |
+
"\n",
|
733 |
+
" -------------------------------------------------- \n",
|
734 |
+
"\n",
|
735 |
+
"Generated Summary: \n",
|
736 |
+
"[{'summary_text': 'David has plans to go on a concert. Andy needs to get his paperwork all sorted out for this year. David already made plans.'}]\n"
|
737 |
+
]
|
738 |
+
}
|
739 |
+
],
|
740 |
+
"source": [
|
741 |
+
"def create_summary(input_text, model_pipeline=summarizer_pipeline):\n",
|
742 |
+
" summary = model_pipeline(input_text)\n",
|
743 |
+
" return summary\n",
|
744 |
+
"\n",
|
745 |
+
"text = '''\n",
|
746 |
+
"Andy: I need you to come in to work on the weekend.\n",
|
747 |
+
"David: Why boss? I have plans to go on a concert I might not be able to come on the weekend.\n",
|
748 |
+
"Andy: It's important we need to get our paperwork all sorted out for this year. Corporate needs it.\n",
|
749 |
+
"David: But I already made plans and this is news to me on very short notice.\n",
|
750 |
+
"Andy: Be there or you'r fired\n",
|
751 |
+
"'''\n",
|
752 |
+
"\n",
|
753 |
+
"print(f\"Original Text:\\n{text}\")\n",
|
754 |
+
"print('\\n', '-'*50, '\\n')\n",
|
755 |
+
"\n",
|
756 |
+
"summary = create_summary(text)\n",
|
757 |
+
"\n",
|
758 |
+
"print(f\"Generated Summary: \\n{summary}\")"
|
759 |
+
]
|
760 |
+
},
|
761 |
+
{
|
762 |
+
"cell_type": "code",
|
763 |
+
"execution_count": null,
|
764 |
+
"id": "ad5d29a0",
|
765 |
+
"metadata": {},
|
766 |
+
"outputs": [],
|
767 |
+
"source": []
|
768 |
+
}
|
769 |
+
],
|
770 |
+
"metadata": {
|
771 |
+
"kernelspec": {
|
772 |
+
"display_name": "Python 3 (ipykernel)",
|
773 |
+
"language": "python",
|
774 |
+
"name": "python3"
|
775 |
+
},
|
776 |
+
"language_info": {
|
777 |
+
"codemirror_mode": {
|
778 |
+
"name": "ipython",
|
779 |
+
"version": 3
|
780 |
+
},
|
781 |
+
"file_extension": ".py",
|
782 |
+
"mimetype": "text/x-python",
|
783 |
+
"name": "python",
|
784 |
+
"nbconvert_exporter": "python",
|
785 |
+
"pygments_lexer": "ipython3",
|
786 |
+
"version": "3.9.0"
|
787 |
+
}
|
788 |
+
},
|
789 |
+
"nbformat": 4,
|
790 |
+
"nbformat_minor": 5
|
791 |
+
}
|
app.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, render_template, request, jsonify
|
2 |
+
from tf_model_api.model_api import ModelAPI
|
3 |
+
|
4 |
+
app = Flask(__name__)
|
5 |
+
# Create the model class object
|
6 |
+
summarizer_model = ModelAPI()
|
7 |
+
|
8 |
+
@app.route('/')
def index():
    """Serve the landing page with an empty prompt payload."""
    # The template expects a `data` mapping with a 'prompts' key.
    context = {'prompts': ''}
    return render_template('index.html', data=context)
|
14 |
+
|
15 |
+
@app.route('/create-summary', methods=['POST'])
def creat_summary_response():
    # NOTE(review): function name has a typo ("creat"); kept as-is because it is
    # also the Flask endpoint name (url_for target) — renaming would change the API.
    """
    Create a summary of the text received from the user.

    Expects a JSON body of the form {"text": "..."}.

    Returns:
        200 with {'status': 'success', 'result': <summary>} when a summary
        was produced, 400 with {'status': 'fail'} when the request carries
        no usable text or the model returned nothing.
    """
    # silent=True makes get_json() return None for a missing/invalid JSON body
    # instead of raising — the original crashed with AttributeError (HTTP 500)
    # on non-JSON requests because `data` could be None.
    data = request.get_json(silent=True)
    text = data.get('text') if data else None

    if not text:
        # Fail fast rather than handing None/empty input to the model.
        return jsonify({'status': 'fail'}), 400

    summary = summarizer_model.get_summary(text)
    if summary:
        result = {
            'status': 'success',
            'result': summary}
        return jsonify(result), 200
    else:
        result = {
            'status': 'fail'
        }
        return jsonify(result), 400
|
36 |
+
|
37 |
+
# Run the Flask development server when executed directly (not when imported).
if __name__ == '__main__':
    app.run()
|
requirements.txt
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.0.0
|
2 |
+
accelerate==0.23.0
|
3 |
+
aiohttp==3.8.5
|
4 |
+
aiosignal==1.3.1
|
5 |
+
anyio==4.0.0
|
6 |
+
appnope==0.1.3
|
7 |
+
argon2-cffi==23.1.0
|
8 |
+
argon2-cffi-bindings==21.2.0
|
9 |
+
arrow==1.2.3
|
10 |
+
asttokens==2.4.0
|
11 |
+
async-lru==2.0.4
|
12 |
+
async-timeout==4.0.3
|
13 |
+
attrs==23.1.0
|
14 |
+
Babel==2.12.1
|
15 |
+
backcall==0.2.0
|
16 |
+
beautifulsoup4==4.12.2
|
17 |
+
bleach==6.0.0
|
18 |
+
blinker==1.6.2
|
19 |
+
Brotli==1.1.0
|
20 |
+
certifi==2023.7.22
|
21 |
+
cffi==1.15.1
|
22 |
+
charset-normalizer==3.2.0
|
23 |
+
click==8.1.7
|
24 |
+
comm==0.1.4
|
25 |
+
contourpy==1.1.1
|
26 |
+
cycler==0.11.0
|
27 |
+
datasets==2.14.5
|
28 |
+
debugpy==1.8.0
|
29 |
+
decorator==5.1.1
|
30 |
+
defusedxml==0.7.1
|
31 |
+
dill==0.3.7
|
32 |
+
evaluate==0.4.0
|
33 |
+
executing==1.2.0
|
34 |
+
fastjsonschema==2.18.0
|
35 |
+
filelock==3.12.4
|
36 |
+
Flask==2.3.3
|
37 |
+
fonttools==4.42.1
|
38 |
+
fqdn==1.5.1
|
39 |
+
frozenlist==1.4.0
|
40 |
+
fsspec==2023.6.0
|
41 |
+
huggingface-hub==0.17.3
|
42 |
+
idna==3.4
|
43 |
+
inflate64==0.3.1
|
44 |
+
ipykernel==6.25.2
|
45 |
+
ipython==8.15.0
|
46 |
+
ipython-genutils==0.2.0
|
47 |
+
ipywidgets==8.1.1
|
48 |
+
isoduration==20.11.0
|
49 |
+
itsdangerous==2.1.2
|
50 |
+
jedi==0.19.0
|
51 |
+
Jinja2==3.1.2
|
52 |
+
joblib==1.3.2
|
53 |
+
json5==0.9.14
|
54 |
+
jsonpointer==2.4
|
55 |
+
jsonschema==4.19.1
|
56 |
+
jsonschema-specifications==2023.7.1
|
57 |
+
jupyter==1.0.0
|
58 |
+
jupyter-console==6.6.3
|
59 |
+
jupyter-events==0.7.0
|
60 |
+
jupyter-lsp==2.2.0
|
61 |
+
jupyter_client==8.3.1
|
62 |
+
jupyter_core==5.3.2
|
63 |
+
jupyter_server==2.7.3
|
64 |
+
jupyter_server_terminals==0.4.4
|
65 |
+
jupyterlab==4.0.6
|
66 |
+
jupyterlab-pygments==0.2.2
|
67 |
+
jupyterlab-widgets==3.0.9
|
68 |
+
jupyterlab_server==2.25.0
|
69 |
+
kiwisolver==1.4.5
|
70 |
+
MarkupSafe==2.1.3
|
71 |
+
matplotlib==3.8.0
|
72 |
+
matplotlib-inline==0.1.6
|
73 |
+
mistune==3.0.1
|
74 |
+
mpmath==1.3.0
|
75 |
+
multidict==6.0.4
|
76 |
+
multiprocess==0.70.15
|
77 |
+
multivolumefile==0.2.3
|
78 |
+
nbclient==0.8.0
|
79 |
+
nbconvert==7.8.0
|
80 |
+
nbformat==5.9.2
|
81 |
+
nest-asyncio==1.5.8
|
82 |
+
networkx==3.1
|
83 |
+
nltk==3.8.1
|
84 |
+
notebook==7.0.4
|
85 |
+
notebook_shim==0.2.3
|
86 |
+
numpy==1.26.0
|
87 |
+
overrides==7.4.0
|
88 |
+
packaging==23.1
|
89 |
+
pandas==2.1.1
|
90 |
+
pandocfilters==1.5.0
|
91 |
+
parso==0.8.3
|
92 |
+
pexpect==4.8.0
|
93 |
+
pickleshare==0.7.5
|
94 |
+
Pillow==10.0.1
|
95 |
+
platformdirs==3.10.0
|
96 |
+
prometheus-client==0.17.1
|
97 |
+
prompt-toolkit==3.0.39
|
98 |
+
psutil==5.9.5
|
99 |
+
ptyprocess==0.7.0
|
100 |
+
pure-eval==0.2.2
|
101 |
+
py7zr==0.20.6
|
102 |
+
pyarrow==13.0.0
|
103 |
+
pybcj==1.0.1
|
104 |
+
pycparser==2.21
|
105 |
+
pycryptodomex==3.19.0
|
106 |
+
Pygments==2.16.1
|
107 |
+
pyparsing==3.1.1
|
108 |
+
pyppmd==1.0.0
|
109 |
+
python-dateutil==2.8.2
|
110 |
+
python-json-logger==2.0.7
|
111 |
+
pytz==2023.3.post1
|
112 |
+
PyYAML==6.0.1
|
113 |
+
pyzmq==25.1.1
|
114 |
+
pyzstd==0.15.9
|
115 |
+
qtconsole==5.4.4
|
116 |
+
QtPy==2.4.0
|
117 |
+
referencing==0.30.2
|
118 |
+
regex==2023.8.8
|
119 |
+
requests==2.31.0
|
120 |
+
responses==0.18.0
|
121 |
+
rfc3339-validator==0.1.4
|
122 |
+
rfc3986-validator==0.1.1
|
123 |
+
rouge-score==0.1.2
|
124 |
+
rpds-py==0.10.3
|
125 |
+
safetensors==0.3.3
|
126 |
+
seaborn==0.12.2
|
127 |
+
Send2Trash==1.8.2
|
128 |
+
six==1.16.0
|
129 |
+
sniffio==1.3.0
|
130 |
+
soupsieve==2.5
|
131 |
+
stack-data==0.6.2
|
132 |
+
sympy==1.12
|
133 |
+
terminado==0.17.1
|
134 |
+
texttable==1.6.7
|
135 |
+
tinycss2==1.2.1
|
136 |
+
tokenizers==0.13.3
|
137 |
+
torch==2.0.1
|
138 |
+
torchaudio==2.0.2
|
139 |
+
torchvision==0.15.2
|
140 |
+
tornado==6.3.3
|
141 |
+
tqdm==4.66.1
|
142 |
+
traitlets==5.10.1
|
143 |
+
transformers==4.33.3
|
144 |
+
typing_extensions==4.8.0
|
145 |
+
tzdata==2023.3
|
146 |
+
uri-template==1.3.0
|
147 |
+
urllib3==2.0.5
|
148 |
+
wcwidth==0.2.6
|
149 |
+
webcolors==1.13
|
150 |
+
webencodings==0.5.1
|
151 |
+
websocket-client==1.6.3
|
152 |
+
Werkzeug==2.3.7
|
153 |
+
widgetsnbextension==4.0.9
|
154 |
+
xxhash==3.3.0
|
155 |
+
yarl==1.9.2
|