avsolatorio commited on
Commit
78d1da2
1 Parent(s): 78762bd

Update example and model

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -7,14 +7,15 @@ from gliner import GLiNER
7
 
8
  _MODEL = {}
9
  _CACHE_DIR = os.environ.get("CACHE_DIR", None)
 
 
10
 
11
  print(f"Cache directory: {_CACHE_DIR}")
12
 
13
 
14
  def get_model(model_name: str = None):
15
  if model_name is None:
16
- # model_name = "urchade/gliner_base"
17
- model_name = "urchade/gliner_medium-v2.1"
18
 
19
  global _MODEL
20
 
@@ -31,7 +32,7 @@ def get_country(country_name: str):
31
  return None
32
 
33
 
34
- def parse_query(query: str, labels: Union[str, list], threshold: float = 0.5, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
35
  model = get_model(model_name)
36
 
37
  if isinstance(labels, str):
@@ -70,16 +71,16 @@ with gr.Blocks(title="GLiNER-query-parser") as demo:
70
  )
71
 
72
  query = gr.Textbox(
73
- value="gdp of the philippines in 2024", label="query", placeholder="Enter your query here"
74
  )
75
  with gr.Row() as row:
76
  model_name = gr.Radio(
77
- choices=["urchade/gliner_medium-v2.1", "urchade/gliner_base"],
78
- value="urchade/gliner_medium-v2.1",
79
  label="Model",
80
  )
81
  entities = gr.Textbox(
82
- value="country, year, indicator",
83
  label="entities",
84
  placeholder="Enter the entities to detect here (comma separated)",
85
  scale=2,
@@ -87,7 +88,7 @@ with gr.Blocks(title="GLiNER-query-parser") as demo:
87
  threshold = gr.Slider(
88
  0,
89
  1,
90
- value=0.5,
91
  step=0.01,
92
  label="Threshold",
93
  info="Lower threshold may extract more false-positive entities from the query.",
 
7
 
8
  _MODEL = {}
9
  _CACHE_DIR = os.environ.get("CACHE_DIR", None)
10
+ LABELS = ["country", "year", "statistical indicator", "geographic region"]
11
+ QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
12
 
13
  print(f"Cache directory: {_CACHE_DIR}")
14
 
15
 
16
  def get_model(model_name: str = None):
17
  if model_name is None:
18
+ model_name = "urchade/gliner_base"
 
19
 
20
  global _MODEL
21
 
 
32
  return None
33
 
34
 
35
+ def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
36
  model = get_model(model_name)
37
 
38
  if isinstance(labels, str):
 
71
  )
72
 
73
  query = gr.Textbox(
74
+ value=QUERY, label="query", placeholder="Enter your query here"
75
  )
76
  with gr.Row() as row:
77
  model_name = gr.Radio(
78
+ choices=["urchade/gliner_base", "urchade/gliner_medium-v2.1"],
79
+ value="urchade/gliner_base",
80
  label="Model",
81
  )
82
  entities = gr.Textbox(
83
+ value=", ".join(LABELS),
84
  label="entities",
85
  placeholder="Enter the entities to detect here (comma separated)",
86
  scale=2,
 
88
  threshold = gr.Slider(
89
  0,
90
  1,
91
+ value=0.3,
92
  step=0.01,
93
  label="Threshold",
94
  info="Lower threshold may extract more false-positive entities from the query.",