avsolatorio commited on
Commit
81fbf66
1 Parent(s): 71ef356

Add time logs and duration

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -36,20 +36,24 @@ def get_country(country_name: str):
36
  return None
37
 
38
 
39
- @spaces.GPU
40
- def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
41
- model = get_model(model_name)
42
 
43
  if isinstance(labels, str):
44
  labels = [i.strip() for i in labels.split(",")]
45
 
46
- return model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
 
 
 
 
47
 
48
 
49
  def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
50
 
51
  entities = []
52
- _entities = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)
53
 
54
  for entity in _entities:
55
  if entity["label"] == "country":
@@ -61,7 +65,7 @@ def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, ne
61
  entities.append(entity)
62
 
63
  payload = {"query": query, "entities": entities}
64
- print(f"{datetime.now()} :: {json.dumps(payload)}\n")
65
 
66
  return payload
67
 
 
36
  return None
37
 
38
 
39
+ @spaces.GPU(duration=3)
40
+ def predict_entities(model, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
41
+ start = datetime.now()
42
 
43
  if isinstance(labels, str):
44
  labels = [i.strip() for i in labels.split(",")]
45
 
46
+ entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
47
+
48
+ print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")
49
+
50
+ return entities
51
 
52
 
53
  def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
54
 
55
  entities = []
56
+ _entities = predict_entities(model=get_model(model_name), query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)
57
 
58
  for entity in _entities:
59
  if entity["label"] == "country":
 
65
  entities.append(entity)
66
 
67
  payload = {"query": query, "entities": entities}
68
+ print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")
69
 
70
  return payload
71