Spaces: Running on Zero
| import spaces | |
| import os | |
| import json | |
| import gradio as gr | |
| import pycountry | |
| import torch | |
| from datetime import datetime | |
| from typing import Dict, Union | |
| from gliner import GLiNER | |
# Loaded GLiNER models, keyed by the model id passed to get_model().
_MODEL = {}
# Optional cache directory for downloaded weights (None -> library default).
_CACHE_DIR = os.getenv("CACHE_DIR")

# Score threshold used for the warm-up run and as the UI slider default.
THRESHOLD = 0.3
# Entity types detected by default.
LABELS = ["country", "year", "statistical indicator", "geographic region"]
# Sample query shown in the UI and used to warm up every model.
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
# Checkpoints offered in the model selector.
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
    "urchade/gliner_multi-v2.1",
    "urchade/gliner_large-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")
def get_model(model_name: Union[str, None] = None):
    """Return the GLiNER model ``model_name``, loading and caching it on first use.

    Args:
        model_name: Hugging Face id of the checkpoint. ``None`` selects the
            first entry of ``MODELS`` (``"urchade/gliner_base"``).

    Returns:
        The cached ``GLiNER`` instance, moved to CUDA when a GPU is available.

    Raises:
        ValueError: if ``model_name`` is not listed in ``MODELS``. The name can
            arrive through the public Gradio API, so it is whitelisted rather
            than letting any caller make the server download and load an
            arbitrary Hub repository via ``GLiNER.from_pretrained``.
    """
    start = datetime.now()
    if model_name is None:
        model_name = MODELS[0]
    if model_name not in MODELS:
        raise ValueError(f"Unknown model {model_name!r}; expected one of {MODELS}")

    # _MODEL is only mutated (never rebound), so no `global` declaration is needed.
    model = _MODEL.get(model_name)
    if model is None:
        model = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
    # Move to the GPU at most once; later calls find the parameters already on CUDA.
    if torch.cuda.is_available() and not next(model.parameters()).device.type.startswith("cuda"):
        model = model.to("cuda")
    _MODEL[model_name] = model

    now = datetime.now()
    print(f"{now} :: get_model :: {now - start}")
    return model
def get_country(country_name: str):
    """Fuzzy-match ``country_name`` against pycountry's country database.

    Returns whatever ``pycountry.countries.search_fuzzy`` returns for the name
    (the matching country records), or ``None`` when that lookup raises
    ``LookupError``.
    """
    countries = pycountry.countries
    try:
        matches = countries.search_fuzzy(country_name)
    except LookupError:
        matches = None
    return matches
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Run GLiNER zero-shot entity extraction over ``query``.

    Args:
        model_name: checkpoint id forwarded to ``get_model`` (``None`` = default).
        query: text to annotate.
        labels: entity types to look for, either as a list or as a
            comma-separated string (as typed into the UI textbox).
        threshold: score threshold forwarded to GLiNER; lower values return
            more (possibly false-positive) entities.
        nested_ner: request nested entities; passed to GLiNER as
            ``flat_ner=not nested_ner``.

    Returns:
        The entities returned by ``model.predict_entities``, or an empty list
        when no non-blank label is supplied.
    """
    start = datetime.now()
    model = get_model(model_name)
    if isinstance(labels, str):
        # Strip whitespace and drop blank items so input such as
        # "country, year," or "country,,year" never sends an empty label.
        labels = [label.strip() for label in labels.split(",")]
        labels = [label for label in labels if label]
    if labels:
        entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
    else:
        # Nothing to look for: skip the model call entirely.
        entities = []
    now = datetime.now()
    print(f"{now} :: predict_entities :: {now - start}")
    return entities
def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities from ``query`` and normalize "country" mentions.

    Every entity from ``predict_entities`` is kept in order. Entities labelled
    ``"country"`` whose text ``get_country`` can match additionally receive a
    ``"normalized"`` key holding the matched pycountry records converted to
    dicts. The entity dicts are updated in place.

    Returns:
        ``{"query": query, "entities": [...]}``; the payload is also printed
        as a JSON log line.
    """
    raw_entities = predict_entities(
        model_name=model_name,
        query=query,
        labels=labels,
        threshold=threshold,
        nested_ner=nested_ner,
    )
    kept = []
    for ent in raw_entities:
        # Only country mentions are looked up; every other label passes through.
        matches = get_country(ent["text"]) if ent["label"] == "country" else None
        if matches:
            ent["normalized"] = [dict(match) for match in matches]
        kept.append(ent)
    payload = {"query": query, "entities": kept}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")
    return payload
def annotate_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Reshape ``parse_query`` output into the dict fed to ``gr.HighlightedText``.

    Each parsed entity is mapped to a dict with the keys ``entity``, ``word``,
    ``start``, ``end`` and ``score``. These come from the entity's ``label``,
    ``text``, ``start``, ``end`` and ``score`` fields. Any ``normalized``
    country data is not carried over.
    """
    parsed = parse_query(query, labels, threshold, nested_ner, model_name)
    # (output key, source key) pairs, in the order the output dicts are built.
    key_map = (
        ("entity", "label"),
        ("word", "text"),
        ("start", "start"),
        ("end", "end"),
        ("score", "score"),
    )
    highlighted = []
    for ent in parsed["entities"]:
        highlighted.append({out_key: ent[src_key] for out_key, src_key in key_map})
    return {"text": query, "entities": highlighted}
# Warm-up: run one prediction per checkpoint on the sample query so every model
# in MODELS is loaded (and moved to CUDA if available) before the UI is built.
print("Initializing models...")
for _warmup_model in MODELS:
    predict_entities(model_name=_warmup_model, query=QUERY, labels=LABELS, threshold=THRESHOLD)
with gr.Blocks(title="GLiNER-query-parser") as demo:
    # One link per checkpoint offered in the model selector below.
    model_links = ", ".join(f"https://huggingface.co/{name}" for name in MODELS)
    gr.Markdown(
        f"""
# GLiNER-based Query Parser (a zero-shot NER model)

This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model identifies instances of these entities in the query, and the parsed entities are displayed in the output. A special case is the "country" entity, which is normalized with the `pycountry` library to the matching ISO 3166-1 country records (including the alpha-2 code). Each GLiNER model is released under the license stated on its model card.

## Links

* Models: {model_links}
* All GLiNER models: https://huggingface.co/models?library=gliner
* Paper: https://arxiv.org/abs/2311.08526
* Repository: https://github.com/urchade/GLiNER
"""
    )

    query = gr.Textbox(value=QUERY, label="query", placeholder="Enter your query here")

    with gr.Row():
        model_name = gr.Radio(
            choices=MODELS,
            value=MODELS[0],
            label="Model",
        )
        entities = gr.Textbox(
            value=", ".join(LABELS),
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=THRESHOLD,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.HighlightedText(label="Annotated entities")
    submit_btn = gr.Button("Submit")
    json_output = gr.JSON(label="Extracted entities")
    json_button = gr.Button("Get JSON")

    # Argument order must match annotate_query/parse_query's positional
    # parameters: (query, labels, threshold, nested_ner, model_name).
    inputs = [query, entities, threshold, is_nested, model_name]

    # Re-annotate whenever any control is submitted or changed. The listeners
    # are still registered one at a time, in the original order, so the
    # auto-generated API endpoints are the same as with individual calls.
    for listener in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        listener(fn=annotate_query, inputs=inputs, outputs=output)

    json_button.click(fn=parse_query, inputs=inputs, outputs=json_output)
# Enable request queueing (default concurrency limit of 5), then start the app.
demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)

# Example client usage of the public API endpoint:
#
#     from gradio_client import Client
#     client = Client("avsolatorio/query-parser")
#     result = client.predict(
#         query="gdp, m3, and child mortality of india and southeast asia 2024",
#         labels="country, year, statistical indicator, region",
#         threshold=0.3,
#         nested_ner=False,
#         api_name="/parse_query"
#     )
#     print(result)