Spaces:
Runtime error
Runtime error
chicham
commited on
Modify the way the results a shown (#4)
Browse files
app.py
CHANGED
@@ -2,8 +2,6 @@
|
|
2 |
from __future__ import annotations
|
3 |
|
4 |
import functools
|
5 |
-
from collections import defaultdict
|
6 |
-
from itertools import chain
|
7 |
from typing import Any
|
8 |
from typing import Callable
|
9 |
from typing import Mapping
|
@@ -13,7 +11,6 @@ import attr
|
|
13 |
import environ
|
14 |
import fasttext # not working with python3.9
|
15 |
import gradio as gr
|
16 |
-
from tokenizers.pre_tokenizers import Whitespace
|
17 |
from transformers.pipelines import pipeline
|
18 |
from transformers.pipelines.base import Pipeline
|
19 |
from transformers.pipelines.token_classification import AggregationStrategy
|
@@ -127,8 +124,8 @@ def predict(
|
|
127 |
supported_languages: tuple[str, ...] = ("fr", "de"),
|
128 |
) -> tuple[
|
129 |
Mapping[str, float],
|
130 |
-
str,
|
131 |
Mapping[str, float],
|
|
|
132 |
Sequence[tuple[str, str | None]],
|
133 |
Sequence[tuple[str, str | None]],
|
134 |
]:
|
@@ -189,27 +186,23 @@ def predict(
|
|
189 |
predict_fn: Callable,
|
190 |
query: str,
|
191 |
) -> Sequence[tuple[str, str | None]]:
|
192 |
-
def get_entity(pred: Mapping[str, str]):
|
193 |
-
return pred.get("entity", pred.get("entity_group", None))
|
194 |
|
195 |
-
|
196 |
-
mapping.update(**{pred["word"]: get_entity(pred) for pred in predict_fn(query)})
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
return res
|
206 |
|
207 |
languages = predict_lang(query)
|
208 |
translation = translate_query(query, languages)
|
209 |
classifications = classify_query(translation, categories)
|
210 |
general_entities = extract_entities(models.ner, query)
|
211 |
recipe_entities = extract_entities(models.recipe, translation)
|
212 |
-
return languages,
|
213 |
|
214 |
|
215 |
def main():
|
@@ -254,7 +247,7 @@ def main():
|
|
254 |
load_fn=lambda: pipeline(
|
255 |
"ner",
|
256 |
model=cfg.ner.general,
|
257 |
-
aggregation_strategy=AggregationStrategy.
|
258 |
),
|
259 |
),
|
260 |
recipe=Predictor(
|
@@ -282,15 +275,15 @@ def main():
|
|
282 |
type="auto",
|
283 |
label="Language identification",
|
284 |
),
|
285 |
-
gr.outputs.Textbox(
|
286 |
-
label="English query",
|
287 |
-
type="auto",
|
288 |
-
),
|
289 |
gr.outputs.Label(
|
290 |
num_top_classes=cfg.classification.max_results,
|
291 |
type="auto",
|
292 |
label="Predicted categories",
|
293 |
),
|
|
|
|
|
|
|
|
|
294 |
gr.outputs.HighlightedText(label="NER generic"),
|
295 |
gr.outputs.HighlightedText(label="NER Recipes"),
|
296 |
],
|
|
|
2 |
from __future__ import annotations
|
3 |
|
4 |
import functools
|
|
|
|
|
5 |
from typing import Any
|
6 |
from typing import Callable
|
7 |
from typing import Mapping
|
|
|
11 |
import environ
|
12 |
import fasttext # not working with python3.9
|
13 |
import gradio as gr
|
|
|
14 |
from transformers.pipelines import pipeline
|
15 |
from transformers.pipelines.base import Pipeline
|
16 |
from transformers.pipelines.token_classification import AggregationStrategy
|
|
|
124 |
supported_languages: tuple[str, ...] = ("fr", "de"),
|
125 |
) -> tuple[
|
126 |
Mapping[str, float],
|
|
|
127 |
Mapping[str, float],
|
128 |
+
str,
|
129 |
Sequence[tuple[str, str | None]],
|
130 |
Sequence[tuple[str, str | None]],
|
131 |
]:
|
|
|
186 |
predict_fn: Callable,
|
187 |
query: str,
|
188 |
) -> Sequence[tuple[str, str | None]]:
|
|
|
|
|
189 |
|
190 |
+
predictions = predict_fn(query)
|
|
|
191 |
|
192 |
+
if len(predictions) == 0:
|
193 |
+
return [(query, None)]
|
194 |
+
else:
|
195 |
+
return [
|
196 |
+
(pred["word"], pred.get("entity_group", pred.get("entity", None)))
|
197 |
+
for pred in predictions
|
198 |
+
]
|
|
|
199 |
|
200 |
languages = predict_lang(query)
|
201 |
translation = translate_query(query, languages)
|
202 |
classifications = classify_query(translation, categories)
|
203 |
general_entities = extract_entities(models.ner, query)
|
204 |
recipe_entities = extract_entities(models.recipe, translation)
|
205 |
+
return languages, classifications, translation, general_entities, recipe_entities
|
206 |
|
207 |
|
208 |
def main():
|
|
|
247 |
load_fn=lambda: pipeline(
|
248 |
"ner",
|
249 |
model=cfg.ner.general,
|
250 |
+
aggregation_strategy=AggregationStrategy.SIMPLE,
|
251 |
),
|
252 |
),
|
253 |
recipe=Predictor(
|
|
|
275 |
type="auto",
|
276 |
label="Language identification",
|
277 |
),
|
|
|
|
|
|
|
|
|
278 |
gr.outputs.Label(
|
279 |
num_top_classes=cfg.classification.max_results,
|
280 |
type="auto",
|
281 |
label="Predicted categories",
|
282 |
),
|
283 |
+
gr.outputs.Textbox(
|
284 |
+
label="English query",
|
285 |
+
type="auto",
|
286 |
+
),
|
287 |
gr.outputs.HighlightedText(label="NER generic"),
|
288 |
gr.outputs.HighlightedText(label="NER Recipes"),
|
289 |
],
|