avsolatorio commited on
Commit
d4e7f01
·
1 Parent(s): aadf93b

Use zerogpu

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (2) hide show
  1. app.py +9 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import json
3
  import gradio as gr
@@ -23,6 +24,7 @@ def get_model(model_name: str = None):
23
 
24
  if _MODEL.get(model_name) is None:
25
  _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
 
26
 
27
  return _MODEL[model_name]
28
 
@@ -34,15 +36,20 @@ def get_country(country_name: str):
34
  return None
35
 
36
 
37
- def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
 
38
  model = get_model(model_name)
39
 
40
  if isinstance(labels, str):
41
  labels = [i.strip() for i in labels.split(",")]
42
 
43
- _entities = model.predict_entities(query, labels, threshold=threshold)
 
 
 
44
 
45
  entities = []
 
46
 
47
  for entity in _entities:
48
  if entity["label"] == "country":
 
1
+ import spaces
2
  import os
3
  import json
4
  import gradio as gr
 
24
 
25
  if _MODEL.get(model_name) is None:
26
  _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
27
+ _MODEL[model_name].to("cuda")
28
 
29
  return _MODEL[model_name]
30
 
 
36
  return None
37
 
38
 
39
+ @spaces.GPU
40
+ def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None):
41
  model = get_model(model_name)
42
 
43
  if isinstance(labels, str):
44
  labels = [i.strip() for i in labels.split(",")]
45
 
46
+ return model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
47
+
48
+
49
+ def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
50
 
51
  entities = []
52
+ _entities = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)
53
 
54
  for entity in _entities:
55
  if entity["label"] == "country":
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  gliner
2
  pycountry
3
  scipy==1.12
4
- gradio
 
 
1
  gliner
2
  pycountry
3
  scipy==1.12
4
+ gradio
5
+ spaces