avsolatorio commited on
Commit
49fed2a
1 Parent(s): 2d67881

Initialize models

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -11,8 +11,10 @@ from gliner import GLiNER
11
 
12
  _MODEL = {}
13
  _CACHE_DIR = os.environ.get("CACHE_DIR", None)
 
14
  LABELS = ["country", "year", "statistical indicator", "geographic region"]
15
  QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
 
16
 
17
  print(f"Cache directory: {_CACHE_DIR}")
18
 
@@ -36,6 +38,13 @@ def get_model(model_name: str = None):
36
  return _MODEL[model_name]
37
 
38
 
 
 
 
 
 
 
 
39
  def get_country(country_name: str):
40
  try:
41
  return pycountry.countries.search_fuzzy(country_name)
@@ -43,7 +52,7 @@ def get_country(country_name: str):
43
  return None
44
 
45
 
46
- @spaces.GPU(enable_queue=True, duration=3)
47
  def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
48
  start = datetime.now()
49
  model = get_model(model_name)
@@ -99,7 +108,7 @@ with gr.Blocks(title="GLiNER-query-parser") as demo:
99
  )
100
  with gr.Row() as row:
101
  model_name = gr.Radio(
102
- choices=["urchade/gliner_base", "urchade/gliner_medium-v2.1"],
103
  value="urchade/gliner_base",
104
  label="Model",
105
  )
@@ -112,7 +121,7 @@ with gr.Blocks(title="GLiNER-query-parser") as demo:
112
  threshold = gr.Slider(
113
  0,
114
  1,
115
- value=0.3,
116
  step=0.01,
117
  label="Threshold",
118
  info="Lower threshold may extract more false-positive entities from the query.",
 
11
 
12
  _MODEL = {}
13
  _CACHE_DIR = os.environ.get("CACHE_DIR", None)
14
+ THRESHOLD = 0.3
15
  LABELS = ["country", "year", "statistical indicator", "geographic region"]
16
  QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
17
+ MODELS = ["urchade/gliner_base", "urchade/gliner_medium-v2.1"]
18
 
19
  print(f"Cache directory: {_CACHE_DIR}")
20
 
 
38
  return _MODEL[model_name]
39
 
40
 
41
+ # Initialize model here.
42
+ print("Initializing models...")
43
+ for model_name in MODELS:
44
+ model = get_model(model_name=model_name)
45
+ model.predict_entities(QUERY, LABELS, threshold=THRESHOLD)
46
+
47
+
48
  def get_country(country_name: str):
49
  try:
50
  return pycountry.countries.search_fuzzy(country_name)
 
52
  return None
53
 
54
 
55
+ @spaces.GPU(enable_queue=True, duration=5)
56
  def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
57
  start = datetime.now()
58
  model = get_model(model_name)
 
108
  )
109
  with gr.Row() as row:
110
  model_name = gr.Radio(
111
+ choices=MODELS,
112
  value="urchade/gliner_base",
113
  label="Model",
114
  )
 
121
  threshold = gr.Slider(
122
  0,
123
  1,
124
+ value=THRESHOLD,
125
  step=0.01,
126
  label="Threshold",
127
  info="Lower threshold may extract more false-positive entities from the query.",