chicham commited on
Commit
3ad6577
1 Parent(s): 442640c

Modify the way the results a shown (#4)

Browse files
Files changed (1) hide show
  1. app.py +15 -22
app.py CHANGED
@@ -2,8 +2,6 @@
2
  from __future__ import annotations
3
 
4
  import functools
5
- from collections import defaultdict
6
- from itertools import chain
7
  from typing import Any
8
  from typing import Callable
9
  from typing import Mapping
@@ -13,7 +11,6 @@ import attr
13
  import environ
14
  import fasttext # not working with python3.9
15
  import gradio as gr
16
- from tokenizers.pre_tokenizers import Whitespace
17
  from transformers.pipelines import pipeline
18
  from transformers.pipelines.base import Pipeline
19
  from transformers.pipelines.token_classification import AggregationStrategy
@@ -127,8 +124,8 @@ def predict(
127
  supported_languages: tuple[str, ...] = ("fr", "de"),
128
  ) -> tuple[
129
  Mapping[str, float],
130
- str,
131
  Mapping[str, float],
 
132
  Sequence[tuple[str, str | None]],
133
  Sequence[tuple[str, str | None]],
134
  ]:
@@ -189,27 +186,23 @@ def predict(
189
  predict_fn: Callable,
190
  query: str,
191
  ) -> Sequence[tuple[str, str | None]]:
192
- def get_entity(pred: Mapping[str, str]):
193
- return pred.get("entity", pred.get("entity_group", None))
194
 
195
- mapping = defaultdict(lambda: None)
196
- mapping.update(**{pred["word"]: get_entity(pred) for pred in predict_fn(query)})
197
 
198
- query_processed = Whitespace().pre_tokenize_str(query)
199
- res = tuple(
200
- chain.from_iterable(
201
- ((word, mapping[word]), (" ", None)) for word, _ in query_processed
202
- ),
203
- )
204
- print(res)
205
- return res
206
 
207
  languages = predict_lang(query)
208
  translation = translate_query(query, languages)
209
  classifications = classify_query(translation, categories)
210
  general_entities = extract_entities(models.ner, query)
211
  recipe_entities = extract_entities(models.recipe, translation)
212
- return languages, translation, classifications, general_entities, recipe_entities
213
 
214
 
215
  def main():
@@ -254,7 +247,7 @@ def main():
254
  load_fn=lambda: pipeline(
255
  "ner",
256
  model=cfg.ner.general,
257
- aggregation_strategy=AggregationStrategy.MAX,
258
  ),
259
  ),
260
  recipe=Predictor(
@@ -282,15 +275,15 @@ def main():
282
  type="auto",
283
  label="Language identification",
284
  ),
285
- gr.outputs.Textbox(
286
- label="English query",
287
- type="auto",
288
- ),
289
  gr.outputs.Label(
290
  num_top_classes=cfg.classification.max_results,
291
  type="auto",
292
  label="Predicted categories",
293
  ),
 
 
 
 
294
  gr.outputs.HighlightedText(label="NER generic"),
295
  gr.outputs.HighlightedText(label="NER Recipes"),
296
  ],
2
  from __future__ import annotations
3
 
4
  import functools
 
 
5
  from typing import Any
6
  from typing import Callable
7
  from typing import Mapping
11
  import environ
12
  import fasttext # not working with python3.9
13
  import gradio as gr
 
14
  from transformers.pipelines import pipeline
15
  from transformers.pipelines.base import Pipeline
16
  from transformers.pipelines.token_classification import AggregationStrategy
124
  supported_languages: tuple[str, ...] = ("fr", "de"),
125
  ) -> tuple[
126
  Mapping[str, float],
 
127
  Mapping[str, float],
128
+ str,
129
  Sequence[tuple[str, str | None]],
130
  Sequence[tuple[str, str | None]],
131
  ]:
186
  predict_fn: Callable,
187
  query: str,
188
  ) -> Sequence[tuple[str, str | None]]:
 
 
189
 
190
+ predictions = predict_fn(query)
 
191
 
192
+ if len(predictions) == 0:
193
+ return [(query, None)]
194
+ else:
195
+ return [
196
+ (pred["word"], pred.get("entity_group", pred.get("entity", None)))
197
+ for pred in predictions
198
+ ]
 
199
 
200
  languages = predict_lang(query)
201
  translation = translate_query(query, languages)
202
  classifications = classify_query(translation, categories)
203
  general_entities = extract_entities(models.ner, query)
204
  recipe_entities = extract_entities(models.recipe, translation)
205
+ return languages, classifications, translation, general_entities, recipe_entities
206
 
207
 
208
  def main():
247
  load_fn=lambda: pipeline(
248
  "ner",
249
  model=cfg.ner.general,
250
+ aggregation_strategy=AggregationStrategy.SIMPLE,
251
  ),
252
  ),
253
  recipe=Predictor(
275
  type="auto",
276
  label="Language identification",
277
  ),
 
 
 
 
278
  gr.outputs.Label(
279
  num_top_classes=cfg.classification.max_results,
280
  type="auto",
281
  label="Predicted categories",
282
  ),
283
+ gr.outputs.Textbox(
284
+ label="English query",
285
+ type="auto",
286
+ ),
287
  gr.outputs.HighlightedText(label="NER generic"),
288
  gr.outputs.HighlightedText(label="NER Recipes"),
289
  ],