surdan commited on
Commit
38ca16f
1 Parent(s): 7a10124

Upload Inference.ipynb

Browse files
Files changed (1) hide show
  1. Inference.ipynb +248 -0
Inference.ipynb ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "80b213e0",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# !pip install termcolor==1.1.0 transformers==4.18.0"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "73f81039",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "from transformers import pipeline\n",
21
+ "from termcolor import colored\n",
22
+ "import torch"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "44668ca1",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "class Ner_Extractor:\n",
33
+ " \"\"\"\n",
34
+ " Labeling each token in sentence as named entity\n",
35
+ "\n",
36
+ " :param model_checkpoint: name or path to model \n",
37
+ " :type model_checkpoint: string\n",
38
+ " \"\"\"\n",
39
+ " \n",
40
+ " def __init__(self, model_checkpoint: str):\n",
41
+ " self.token_pred_pipeline = pipeline(\"token-classification\", \n",
42
+ " model=model_checkpoint, \n",
43
+ " aggregation_strategy=\"average\")\n",
44
+ " \n",
45
+ " @staticmethod\n",
46
+ " def text_color(txt, txt_c=\"blue\", txt_hglt=\"on_yellow\"):\n",
47
+ " \"\"\"\n",
48
+ " Coloring part of text \n",
49
+ " \n",
50
+ " :param txt: part of text from sentence \n",
51
+ " :type txt: string\n",
52
+ " :param txt_c: text color \n",
53
+ " :type txt_c: string \n",
54
+ " :param txt_hglt: color of text highlighting \n",
55
+ " :type txt_hglt: string\n",
56
+ " :return: string with color labeling\n",
57
+ " :rtype: string\n",
58
+ " \"\"\"\n",
59
+ " return colored(txt, txt_c, txt_hglt)\n",
60
+ " \n",
61
+ " @staticmethod\n",
62
+ " def concat_entities(ner_result):\n",
63
+ " \"\"\"\n",
64
+ " Concatenation entities from model output on grouped entities\n",
65
+ " \n",
66
+ " :param ner_result: output from model pipeline \n",
67
+ " :type ner_result: list\n",
68
+ " :return: list of grouped entities with start - end position in text\n",
69
+ " :rtype: list\n",
70
+ " \"\"\"\n",
71
+ " entities = []\n",
72
+ " prev_entity = None\n",
73
+ " prev_end = 0\n",
74
+ " for i in range(len(ner_result)):\n",
75
+ " \n",
76
+ " if (ner_result[i][\"entity_group\"] == prev_entity) &\\\n",
77
+ " (ner_result[i][\"start\"] == prev_end):\n",
78
+ " \n",
79
+ " entities[i-1][2] = ner_result[i][\"end\"]\n",
80
+ " prev_entity = ner_result[i][\"entity_group\"]\n",
81
+ " prev_end = ner_result[i][\"end\"]\n",
82
+ " else:\n",
83
+ " entities.append([ner_result[i][\"entity_group\"], \n",
84
+ " ner_result[i][\"start\"], \n",
85
+ " ner_result[i][\"end\"]])\n",
86
+ " prev_entity = ner_result[i][\"entity_group\"]\n",
87
+ " prev_end = ner_result[i][\"end\"]\n",
88
+ " \n",
89
+ " return entities\n",
90
+ " \n",
91
+ " \n",
92
+ " def colored_text(self, text: str, entities: list):\n",
93
+ " \"\"\"\n",
94
+ " Highlighting in the text named entities\n",
95
+ " \n",
96
+ " :param text: sentence or a part of corpus\n",
97
+ " :type text: string\n",
98
+ " :param entities: concated entities on groups with start - end position in text\n",
99
+ " :type entities: list\n",
100
+ " :return: Highlighted sentence\n",
101
+ " :rtype: string\n",
102
+ " \"\"\"\n",
103
+ " colored_text = \"\"\n",
104
+ " init_pos = 0\n",
105
+ " for ent in entities:\n",
106
+ " if ent[1] > init_pos:\n",
107
+ " colored_text += text[init_pos: ent[1]]\n",
108
+ " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n",
109
+ " init_pos = ent[2]\n",
110
+ " else:\n",
111
+ " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n",
112
+ " init_pos = ent[2]\n",
113
+ " \n",
114
+ " return colored_text\n",
115
+ " \n",
116
+ " \n",
117
+ " def get_entities(self, text: str):\n",
118
+ " \"\"\"\n",
119
+ " Extracting entities from text with them position in text\n",
120
+ " \n",
121
+ " :param text: input sentence for preparing\n",
122
+ " :type text: string\n",
123
+ " :return: list with entities from text\n",
124
+ " :rtype: list\n",
125
+ " \"\"\"\n",
126
+ " assert len(text) > 0, text\n",
127
+ " entities = self.token_pred_pipeline(text)\n",
128
+ " concat_ent = self.concat_entities(entities)\n",
129
+ " \n",
130
+ " return concat_ent\n",
131
+ " \n",
132
+ " \n",
133
+ " def show_ents_on_text(self, text: str):\n",
134
+ " \"\"\"\n",
135
+ " Highlighting named entities in input text \n",
136
+ " \n",
137
+ " :param text: input sentence for preparing\n",
138
+ " :type text: string\n",
139
+ " :return: Highlighting text\n",
140
+ " :rtype: string\n",
141
+ " \"\"\"\n",
142
+ " assert len(text) > 0, text\n",
143
+ " entities = self.get_entities(text)\n",
144
+ " \n",
145
+ " return self.colored_text(text, entities)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "id": "aaa0a5bd",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "seqs_example = [\"Из Дзюбы вышел бы отличный бразилец». Интервью Клаудиньо\",\n",
156
+ " \"Самый яркий бразилец «Зенита» рассказал о встрече с Пеле\",\n",
157
+ " \"Стали известны подробности нового иска РФС к УЕФА и ФИФА\",\n",
158
+ " \"Реванш «Баварии», голы от «Реала» с «Челси»: ставим на ЛЧ\",\n",
159
+ " \"Кварацхелия не вернется в «Рубин» и станет игроком «Наполи»\",\n",
160
+ " \"«Манчестер Сити» сделал грандиозное предложение по Холанду\",\n",
161
+ " \"В России хотят возродить Кубок лиги. Он проводился в 2003 году\",\n",
162
+ " \"Экс-игрок «Реала» находится в критическом состоянии после ДТП\",\n",
163
+ " \"Аршавин посмеялся над показателями Глушакова в игре с ЦСКА\",\n",
164
+ " \"Арьен Роббен пробежал 42-километровый марафон\"\n",
165
+ " ]"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "id": "380d9824",
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "%%time\n",
176
+ "## init model for inference\n",
177
+ "extractor = Ner_Extractor(model_checkpoint = \"surdan/LaBSE_ner_nerel\")"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "37ebcf51",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "%%time\n",
188
+ "## get highlighting sentences\n",
189
+ "show_entities_in_text = (extractor.show_ents_on_text(i) for i in seqs_example)"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "e03b28c7",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "%%time\n",
200
+ "## get list of entities from sentence\n",
201
+ "l_entities = [extractor.get_entities(i) for i in seqs_example]\n",
202
+ "len(l_entities), len(seqs_example)"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "a2d4ae84",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "## print highlighting sentences\n",
213
+ "for i in range(len(seqs_example)):\n",
214
+ " print(next(show_entities_in_text, \"End of generator\"))\n",
215
+ " print(\"-*-\"*25)"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "id": "9ce3e083",
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": []
225
+ }
226
+ ],
227
+ "metadata": {
228
+ "kernelspec": {
229
+ "display_name": "Python 3 (ipykernel)",
230
+ "language": "python",
231
+ "name": "python3"
232
+ },
233
+ "language_info": {
234
+ "codemirror_mode": {
235
+ "name": "ipython",
236
+ "version": 3
237
+ },
238
+ "file_extension": ".py",
239
+ "mimetype": "text/x-python",
240
+ "name": "python",
241
+ "nbconvert_exporter": "python",
242
+ "pygments_lexer": "ipython3",
243
+ "version": "3.8.10"
244
+ }
245
+ },
246
+ "nbformat": 4,
247
+ "nbformat_minor": 5
248
+ }