# NOTE: `spaces` is deliberately the first import; ZeroGPU Spaces presumably
# need it loaded before torch initialises CUDA — confirm before reordering.
import spaces
import os
import json
import gradio as gr
import pycountry
import torch
from datetime import datetime
from typing import Dict, Union
from gliner import GLiNER

# In-process model cache: Hugging Face model id -> loaded GLiNER instance.
_MODEL = {}
# Optional download cache location taken from the CACHE_DIR env var (None if unset).
_CACHE_DIR = os.getenv("CACHE_DIR")

# Defaults shared by the warm-up call and the UI widgets.
THRESHOLD = 0.3
LABELS = [
    "country",
    "year",
    "statistical indicator",
    "geographic region",
]
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")


def get_model(model_name: str = None):
    """Fetch a GLiNER model from the in-process cache, loading it if needed.

    The first request for a model id loads it with ``GLiNER.from_pretrained``
    (using ``_CACHE_DIR``) and stores it in ``_MODEL``. When CUDA is available
    and the cached model's parameters are not on a CUDA device, the model is
    moved to ``"cuda"`` and the cache entry is replaced. Every call prints the
    time it took.

    Args:
        model_name: Hugging Face model id; ``None`` selects
            ``"urchade/gliner_base"``.

    Returns:
        The cached GLiNER model for the requested id.
    """
    began = datetime.now()
    key = "urchade/gliner_base" if model_name is None else model_name

    if _MODEL.get(key) is None:
        _MODEL[key] = GLiNER.from_pretrained(key, cache_dir=_CACHE_DIR)

    cached = _MODEL[key]
    # The parameter device is only inspected when CUDA exists (short-circuit).
    needs_move = torch.cuda.is_available() and not next(cached.parameters()).device.type.startswith("cuda")
    if needs_move:
        _MODEL[key] = cached.to("cuda")

    print(f"{datetime.now()} :: get_model :: {datetime.now() - began}")
    return _MODEL[key]
# Warm up every model at import time so the first user request does not pay
# the load/initialisation cost. Module-level names are kept private so they do
# not leak into (or get shadowed by) the UI variables defined below.
print("Initializing models...")
for _warmup_name in MODELS:
    get_model(model_name=_warmup_name).predict_entities(QUERY, LABELS, threshold=THRESHOLD)


def get_country(country_name: str):
    """Return the pycountry matches for *country_name*, or ``None``.

    ``pycountry.countries.search_fuzzy`` raises ``LookupError`` when nothing
    matches; that case is mapped to ``None`` so callers can test truthiness.
    """
    try:
        return pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        return None


def _split_labels(labels: Union[str, list, None]) -> list:
    """Normalise *labels* into a list of non-empty, whitespace-stripped labels.

    A string is treated as a comma-separated list. Blank entries (for example
    from a trailing comma or ``"a, , b"``) are dropped so they are never sent
    to the model. ``None`` or an empty value yields ``[]``.
    """
    if not labels:
        return []
    if isinstance(labels, str):
        labels = labels.split(",")
    stripped = (str(label).strip() for label in labels)
    return [label for label in stripped if label]


@spaces.GPU(enable_queue=True, duration=5)
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Run GLiNER entity prediction for *query* against *labels*.

    Args:
        model_name: Model id passed to ``get_model`` (``None`` = its default).
        query: Text to analyse.
        labels: Entity labels, as a list or a comma-separated string.
        threshold: Score threshold forwarded to ``predict_entities``.
        nested_ner: When True, request nested spans (``flat_ner=False``).

    Returns:
        The list of entity dicts produced by the model, or ``[]`` when there is
        no query text or no usable label.
    """
    start = datetime.now()
    labels = _split_labels(labels)
    if not labels or not query or not query.strip():
        # Nothing to extract; don't hand the model an empty input.
        return []
    model = get_model(model_name)
    entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")
    return entities


def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities from *query* and attach country normalisations.

    Every entity whose label is ``"country"`` and whose text fuzzily matches
    a pycountry record gets a ``"normalized"`` key holding those records as
    dicts. Other entities are returned unchanged.

    Returns:
        ``{"query": query, "entities": [...]}``; the payload is also logged.
    """
    if query and query.strip():
        entities = list(predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner))
    else:
        # Blank query: skip the GPU-decorated call entirely.
        entities = []

    for entity in entities:
        if entity["label"] != "country":
            continue
        matches = get_country(entity["text"])
        if matches:
            entity["normalized"] = [dict(match) for match in matches]

    payload = {"query": query, "entities": entities}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")
    return payload


with gr.Blocks(title="GLiNER-query-parser") as demo:
    # NOTE(review): the license sentence below is displayed for every model in
    # MODELS; confirm it against each model card.
    gr.Markdown(
        """
# GLiNER-based Query Parser (a zero-shot NER model)

This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output.

A special case is the "country" entity, which is normalized to the matching ISO 3166-1 country records (including the alpha-2 code) using the `pycountry` library.

This GLiNER model is licensed under the Apache 2.0 license.

## Links

* Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
* All GLiNER models: https://huggingface.co/models?library=gliner
* Paper: https://arxiv.org/abs/2311.08526
* Repository: https://github.com/urchade/GLiNER
"""
    )

    query = gr.Textbox(
        value=QUERY, label="query", placeholder="Enter your query here"
    )
    with gr.Row():
        model_name = gr.Radio(
            choices=MODELS,
            value="urchade/gliner_base",
            label="Model",
        )
        entities = gr.Textbox(
            value=", ".join(LABELS),
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=THRESHOLD,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Re-run the parser on submit and whenever an input changes. The
    # registration order is kept as before so the auto-generated API endpoint
    # names (e.g. "/parse_query", used by the client example below) stay the same.
    _parse_inputs = [query, entities, threshold, is_nested, model_name]
    for _trigger in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        _trigger(fn=parse_query, inputs=_parse_inputs, outputs=output)

demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)

# Example of calling this Space through the Gradio client:
#
#     from gradio_client import Client
#
#     client = Client("avsolatorio/query-parser")
#     result = client.predict(
#         query="gdp, m3, and child mortality of india and southeast asia 2024",
#         labels="country, year, statistical indicator, region",
#         threshold=0.3,
#         nested_ner=False,
#         api_name="/parse_query",
#     )
#     print(result)