speqtr committed on
Commit
a88e758
1 Parent(s): 49b42fc

custom endpoint handler

Browse files
Files changed (5) hide show
  1. .gitignore +3 -0
  2. handler.py +49 -0
  3. handler_test.py +10 -0
  4. requirements-dev.txt +2 -0
  5. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .idea/
2
+ .pytest_cache/
3
+ .venv/
handler.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spacy
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+
7
def _load_spacy_model(name: str | Path) -> spacy.Language:
    """Load a spaCy pipeline once per ``name`` and reuse it on later calls.

    Loaded pipelines are cached on the function object, keyed by the string
    form of ``name``. A repeated call with the same name returns the already
    loaded pipeline; a call with a different name loads that pipeline instead
    of silently handing back whichever one happened to be loaded first.

    Args:
        name: Package name or filesystem path passed to ``spacy.load``.

    Returns:
        The loaded pipeline, with the components listed below excluded.
    """
    # The cache lives on the function itself so no module-level global is
    # needed; ``setdefault`` creates it on the first call only.
    cache: dict[str, spacy.Language] = _load_spacy_model.__dict__.setdefault(
        "_models", {}
    )
    key = str(name)  # str and Path spellings of the same location share an entry

    if key not in cache:
        # pipeline info https://spacy.io/models/en#en_core_web_lg
        # Components left out when loading; going by the variable name, the
        # intent is to keep only the NER component.
        all_except_ner: list[str] = [
            "tok2vec",
            "tagger",
            "parser",
            "attribute_ruler",
            "lemmatizer",
        ]

        nlp = spacy.load(name=name, exclude=all_except_ner)
        cache[key] = nlp
        print(f"Loaded {nlp.meta.get('name', 'unknown')} model from {nlp.path}")

    return cache[key]
22
+
23
+
24
+ class EndpointHandler:
25
+ def __init__(self, name: str | Path = "en_core_web_lg"):
26
+ self._nlp: spacy.Language = _load_spacy_model(name=name)
27
+
28
+ def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
29
+ inputs: str = data.pop("inputs", "")
30
+ if not inputs:
31
+ return {}
32
+
33
+ outputs: list[dict[str, Any]] = []
34
+
35
+ doc = self._nlp(text=inputs)
36
+ for ent in doc.ents:
37
+ if ent.label_ != "PERSON":
38
+ continue
39
+
40
+ entity: dict = {
41
+ "qid": None,
42
+ "entity": ent.label_,
43
+ "text": ent.text,
44
+ "start": ent.start_char,
45
+ "end": ent.end_char}
46
+
47
+ outputs.append(entity)
48
+
49
+ return {"outputs": outputs}
handler_test.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+
4
class TestEndpointHandler:
    """Tests for ``EndpointHandler`` using its default pipeline."""

    # Assigned in setup_class so the handler is not built while this module
    # is merely being imported/collected.
    _handler: EndpointHandler

    @classmethod
    def setup_class(cls):
        """Build one handler shared by every test in this class."""
        cls._handler = EndpointHandler()

    def test_endpoint_handler(self):
        """A question naming one person yields exactly one PERSON entity."""
        text = "Who is John Doe?"
        result: dict = self._handler(data={"inputs": text})
        outputs: list = result.get("outputs", [])
        assert len(outputs) == 1

        entity = outputs[0]
        assert entity["entity"] == "PERSON"
        # The reported offsets must slice the entity text back out of the input.
        assert text[entity["start"]:entity["end"]] == entity["text"]
requirements-dev.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ -r ./requirements.txt
2
+ pytest
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ spacy==3.4.0
2
+
3
+ # trained pipeline
4
+ https://huggingface.co/spacy/en_core_web_lg/resolve/main/en_core_web_lg-any-py3-none-any.whl