krotima1 commited on
Commit
33705cf
1 Parent(s): 821ca2c

add summarizer example

Browse files
Files changed (1) hide show
  1. Summarizer.ipynb +237 -0
Summarizer.ipynb ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "6f213e31",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Summarizer\n",
9
+ "This script is used for summarizing Czech news texts as well as for generating news headlines or abstracts. It can be considered as a demonstration for the application of our summarization models.\n"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "id": "5b4d4b6e",
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "#dependencies\n",
20
+ "import torch as pt\n",
21
+ "import numpy as np\n",
22
+ "\n",
23
+ "from collections import OrderedDict\n",
24
+ "\n",
25
+ "from transformers import AutoModelForSeq2SeqLM\n",
26
+ "from transformers import AutoTokenizer\n",
27
+ "\n",
28
+ "from sentence_splitter import SentenceSplitter, split_text_into_sentences\n",
29
+ "\n",
30
+ "#init Summarizer\n",
31
+ "#comment cuda and delete .to(cuda) if using cpuUse\n",
32
+ "class Summarizer:\n",
33
+ " def __init__(self,model, tokenizer, inference_cfg):\n",
34
+ " self.model = model\n",
35
+ " self.model.cuda()\n",
36
+ " self.tokenizer = tokenizer\n",
37
+ " self.inference_cfg = inference_cfg\n",
38
+ " self.enc_max_len = 512\n",
39
+ " \n",
40
+ " #tokenize & summarize input texts\n",
41
+ " def __call__(self, texts, inference_cfg=None):\n",
42
+ " if type(texts) == str:\n",
43
+ " texts = [texts]\n",
44
+ " assert type(texts) == list and type(texts[0]) == str, \"Expected string or list of strings\"\n",
45
+ " summaries = []\n",
46
+ " self.inference_cfg = inference_cfg if inference_cfg is not None else self.inference_cfg\n",
47
+ " for text in texts:\n",
48
+ " text = self.tokenizer.eos_token.join(SentenceSplitter(language='cs').split(text))\n",
49
+ " ttext = self.tokenizer(text,max_length = self.enc_max_len, truncation=True, padding=\"max_length\",return_tensors=\"pt\")\n",
50
+ " summaries.append(self._summarize(ttext,**self.inference_cfg)[0])\n",
51
+ " return summaries\n",
52
+ " \n",
53
+ " #summarize batch of data\n",
54
+ " def _summarize(self, data, num_beams=1, do_sample=False, \n",
55
+ " top_k=50, \n",
56
+ " top_p=1.0,\n",
57
+ " temperature=1.0,\n",
58
+ " repetition_penalty=1.0,\n",
59
+ " no_repeat_ngram_size = None,\n",
60
+ " max_length=1024,\n",
61
+ " min_length=10,\n",
62
+ " decode_decoder_ids = False,\n",
63
+ " early_stopping = False,**kwargs):\n",
64
+ " summary = model.generate(input_ids=data[\"input_ids\"].to(\"cuda\"),attention_mask=data[\"attention_mask\"].to(\"cuda\"),\n",
65
+ " num_beams= num_beams,\n",
66
+ " do_sample= do_sample,\n",
67
+ " top_k=top_k,\n",
68
+ " top_p=top_p,\n",
69
+ " temperature=temperature,\n",
70
+ " repetition_penalty=repetition_penalty,\n",
71
+ " max_length=max_length,\n",
72
+ " min_length=min_length,\n",
73
+ " early_stopping=early_stopping,\n",
74
+ " forced_bos_token_id=tokenizer.lang_code_to_id['cs_CZ'])\n",
75
+ " return self.tokenizer.batch_decode(summary,skip_special_tokens=True)\n",
76
+ "\n",
77
+ "\n"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "id": "e3b06d20",
83
+ "metadata": {},
84
+ "source": [
85
+ "# Use\n",
86
+ "- Load Czech summarization model from https://huggingface.co/krotima1\n",
87
+ "- Summarize Czech news texts\n",
88
+ "- Play with summarization parameters"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "id": "0e52c368",
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "data": {
99
+ "application/vnd.jupyter.widget-view+json": {
100
+ "model_id": "915f91449a8a43458945118791b2654a",
101
+ "version_major": 2,
102
+ "version_minor": 0
103
+ },
104
+ "text/plain": [
105
+ "Downloading: 0%| | 0.00/480 [00:00<?, ?B/s]"
106
+ ]
107
+ },
108
+ "metadata": {},
109
+ "output_type": "display_data"
110
+ },
111
+ {
112
+ "data": {
113
+ "application/vnd.jupyter.widget-view+json": {
114
+ "model_id": "432868a8c6164fe8b681d999a5fef236",
115
+ "version_major": 2,
116
+ "version_minor": 0
117
+ },
118
+ "text/plain": [
119
+ "Downloading: 0%| | 0.00/8.66M [00:00<?, ?B/s]"
120
+ ]
121
+ },
122
+ "metadata": {},
123
+ "output_type": "display_data"
124
+ },
125
+ {
126
+ "data": {
127
+ "application/vnd.jupyter.widget-view+json": {
128
+ "model_id": "819f555f4bce4fdea3db21fe6c298975",
129
+ "version_major": 2,
130
+ "version_minor": 0
131
+ },
132
+ "text/plain": [
133
+ "Downloading: 0%| | 0.00/495 [00:00<?, ?B/s]"
134
+ ]
135
+ },
136
+ "metadata": {},
137
+ "output_type": "display_data"
138
+ }
139
+ ],
140
+ "source": [
141
+ "# Summarization config, setting up hyperparameters of inference methods used during summarization\n",
142
+ "def summ_config():\n",
143
+ " cfg = OrderedDict([\n",
144
+ " # summarization model - checkpoint from website https://huggingface.co/krotima1\n",
145
+ " (\"model_name\", \"krotima1/mbart-at2h-c\"),\n",
146
+ " \n",
147
+ " #inference configuration of summ parameters\n",
148
+ " (\"inference_cfg\", OrderedDict([\n",
149
+ " (\"num_beams\", 4),\n",
150
+ " (\"top_k\", 40),\n",
151
+ " (\"top_p\", 0.92),\n",
152
+ " (\"do_sample\", True),\n",
153
+ " (\"temperature\", 0.89),\n",
154
+ " (\"repetition_penalty\", 1.2),\n",
155
+ " (\"no_repeat_ngram_size\", None),\n",
156
+ " (\"early_stopping\", True),\n",
157
+ " # (\"max_length\", 96),\n",
158
+ " (\"min_length\", 10),\n",
159
+ " ])),\n",
160
+ " #texts to summarize\n",
161
+ " (\"text\",\n",
162
+ " [\n",
163
+ " \"Input your Czech text\",\n",
164
+ " ]\n",
165
+ " ),\n",
166
+ " ])\n",
167
+ " return cfg\n",
168
+ "cfg = summ_config()\n",
169
+ "#load model & tokenizer\n",
170
+ "model = AutoModelForSeq2SeqLM.from_pretrained(cfg[\"model_name\"])\n",
171
+ "tokenizer = AutoTokenizer.from_pretrained(cfg[\"model_name\"])\n",
172
+ "#init summarizer\n",
173
+ "summarize = Summarizer(model, tokenizer, cfg[\"inference_cfg\"])\n",
174
+ "#summarize Czech texts - jdem na to\n",
175
+ "#cfg[\"text\"]=...\n",
176
+ "summarize(cfg[\"text\"])"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "id": "d89f8d1e",
182
+ "metadata": {},
183
+ "source": [
184
+ "#### Change text to summarize"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "id": "63103bac",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "cfg[\"text\"] = \"Nováčci, které stvořila soutěž Česko hledá SuperStar, vstoupí do konkurence s osvědčenými jmény domácího trhu. Jedním z nich je Lucie Bílá, která vydá kolekci hitů nazvanou Láska je láska. Naopak její kolegyně Anna K. bude mít zbrusu nové písně. Styl alba Anna K.: Noc na zemi prý připomíná rockový nářez s místy až mrazivým soundem.\\nDo společnosti hvězd se vypracoval drsný zpěvák Daniel Landa, jehož album Neofolk vyjde začátkem října. \\\"Na desce jsou písničky, které zdáli folk připomínají. Více než dříve kladu důraz na jejich texty. Neofolk mi přijde jako větší vykřičník a jistě i zlobivější název. Jinak je to bigbít i folk. Po obsahu,\\\" míní Landa, který pracuje na dvou muzikálech. Navíc v listopadu vydá záznam z vyšehradského koncertu, na kterém zpíval písně Karla Kryla.\\nKarel Plíhal si zase oblíbil básníka Josefa Kainara, jemuž vzdává hold na desce Nebe na zemi. Práce na tomto dlouho odkládaném a několikrát ohlašovaném albu se protáhly na šest let. Za tu dobu Plíhal nastudoval Kainarovu tvorbu a našel způsob, kterým by ji nejlépe vyjádřil. Stačí mu k tomu kytara, hlas a hostující zpěvačka Zuzana Navarová.\\nPetr Muk si vybral jiné oblíbence: britskou kapelu Erasure, se kterou kdysi jako vokalista kapely Oceán odehrál 27 koncertů ve Velké Británii a Irsku. Na minialbu Oh L\\\"Amour chce Muk představit některé jejich písně.\\nKoncem září se objeví novinka kapely Kryštof. Třetí album Mikrokosmos vznikalo pod dohledem Jana P. Muchowa, který si přál, aby kryštofovská muzika zněla maximálně živě, bez elektronických či jiných efektů. Hit Srdce, který se dostal do vysílání hudební televize MTV, byl vybrán jako propagační píseň.\\nŘíká se, že třetí album je rozhodující. Přesvědčí se o tom i kapela MIG 21, soustředěná kolem herce a zpěváka Jiřího Macháčka. Jejich nová deska Pop Pop Pop ukazuje směr, kterým MIG 21 poletí. Méně se ví o listopadovém albu Midlife skupiny Support Lesbiens, která letos propustila kytaristu a skladatele Yardu Helešice a buduje nový hudební profil.\\nRuku v ruce s nástupem hudebních DVD přibudou tituly bratří Ebenů, kapel Divokej Bill, Monkey Business nebo Kabát. Teplická formace se chce ohlédnout za dvěma masivními turné let 2003 a 2004, která jí zajistila postavení nejžádanější skupiny v Česku.\""
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "markdown",
199
+ "id": "ce6f9474",
200
+ "metadata": {},
201
+ "source": [
202
+ "#### summarize"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "ec6bcb51",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "summarize(cfg[\"text\"])"
213
+ ]
214
+ }
215
+ ],
216
+ "metadata": {
217
+ "kernelspec": {
218
+ "display_name": "Python 3",
219
+ "language": "python",
220
+ "name": "python3"
221
+ },
222
+ "language_info": {
223
+ "codemirror_mode": {
224
+ "name": "ipython",
225
+ "version": 3
226
+ },
227
+ "file_extension": ".py",
228
+ "mimetype": "text/x-python",
229
+ "name": "python",
230
+ "nbconvert_exporter": "python",
231
+ "pygments_lexer": "ipython3",
232
+ "version": "3.6.8"
233
+ }
234
+ },
235
+ "nbformat": 4,
236
+ "nbformat_minor": 5
237
+ }