speqtr committed on
Commit
b2e35f0
1 Parent(s): 83c8492

configured ner with better accuracy

Browse files
introduck/inference.py CHANGED
@@ -1,7 +1,22 @@
1
  import spacy
2
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def extract_contacts_from_email(payload: str) -> list[list[str]]:
 
 
 
5
  return [["John Doe", "Acme Corp.", "john@example.com", "ai, nn"]]
6
 
7
 
@@ -14,18 +29,13 @@ def generate_rejection_reply(payload: str) -> str:
14
 
15
 
16
  def highlight_named_entities(payload: str, labels: list[str] = None) -> dict:
 
 
 
17
  if labels is None:
18
  labels = ["ORG", "PERSON"]
19
 
20
- nlp: spacy.Language | None = None
21
- if not hasattr(highlight_named_entities, "nlp"):
22
- nlp = spacy.load(name="en_core_web_sm")
23
- highlight_named_entities.nlp = nlp
24
- print(f"Loaded {nlp.meta.get('name', 'unknown')} model from {nlp.path}")
25
- else:
26
- nlp = highlight_named_entities.nlp
27
- print(f"Reused {nlp.meta.get('name', 'unknown')} model from {nlp.path}")
28
-
29
  doc = nlp(payload)
30
 
31
  entities: list = []
 
1
  import spacy
2
 
3
 
4
def _load_spacy_model(name: str = "en_core_web_trf") -> spacy.Language:
    """Return a spaCy pipeline, loading each requested model at most once.

    Loaded pipelines are memoised in a dict stored on the function object,
    so repeated calls with the same ``name`` reuse the already loaded
    pipeline instead of calling ``spacy.load`` again.

    Args:
        name: Model name handed to ``spacy.load``. Defaults to
            ``"en_core_web_trf"``, the model this function previously
            had hard-coded, so existing zero-argument callers are unaffected.

    Returns:
        The loaded (or previously cached) pipeline for ``name``.

    Raises:
        Whatever ``spacy.load`` raises when the model cannot be loaded
        (presumably ``OSError`` for a missing package -- confirm against
        the spaCy documentation). Nothing is cached in that case.
    """
    # The cache lives on the function itself, keeping the block free of
    # module-level state; it is created lazily on the first call.
    pipelines = getattr(_load_spacy_model, "_pipelines", None)
    if pipelines is None:
        pipelines = {}
        _load_spacy_model._pipelines = pipelines

    if name not in pipelines:
        nlp = spacy.load(name=name)
        print(f"Loaded {nlp.meta.get('name', 'unknown')} model from {nlp.path}")
        # Store only after a successful load so a failure can be retried.
        pipelines[name] = nlp

    return pipelines[name]
14
+
15
+
16
  def extract_contacts_from_email(payload: str) -> list[list[str]]:
17
+ if not payload:
18
+ return []
19
+
20
  return [["John Doe", "Acme Corp.", "john@example.com", "ai, nn"]]
21
 
22
 
 
29
 
30
 
31
  def highlight_named_entities(payload: str, labels: list[str] = None) -> dict:
32
+ if not payload:
33
+ return {"text": "", "entities": []}
34
+
35
  if labels is None:
36
  labels = ["ORG", "PERSON"]
37
 
38
+ nlp: spacy.Language = _load_spacy_model()
 
 
 
 
 
 
 
 
39
  doc = nlp(payload)
40
 
41
  entities: list = []
introduck/routes.py CHANGED
@@ -7,7 +7,7 @@ from introduck.inference import extract_contacts_from_email
7
  from introduck.inference import generate_acceptance_reply
8
  from introduck.inference import generate_rejection_reply
9
  from introduck.inference import highlight_named_entities
10
- from introduck.utils import create_email_message
11
 
12
  _INTRO_SUBJECT_EXAMPLE: str = "Could you make an intro?"
13
  _INTRO_MESSAGE_EXAMPLE: str = """\
@@ -41,14 +41,14 @@ def _analyze_message(sender: str, recipients: str, subject: str, body: str):
41
  "subject": subject,
42
  "body": body}
43
 
44
- msg: str = create_email_message(data=msg_data)
45
 
46
  contacts: list = extract_contacts_from_email(payload=msg)
47
  acceptance_reply: str = generate_acceptance_reply(payload=msg)
48
  rejection_reply: str = generate_rejection_reply(payload=msg)
49
- text_with_named_entities: dict = highlight_named_entities(payload=msg)
50
 
51
- return contacts, acceptance_reply, rejection_reply, text_with_named_entities
52
 
53
 
54
  def _use_message_template() -> (str, str):
@@ -58,7 +58,12 @@ def _use_message_template() -> (str, str):
58
  def create_base_route(title: str = "") -> FastAPI:
59
  # TODO: Fix once resolved https://github.com/gradio-app/gradio/issues/1683
60
  # TODO: ...and remove elem_id="htxt" from HighlightedText for RFC822 output
61
- css_workaround: str = "#htxt span {white-space: pre}"
 
 
 
 
 
62
 
63
  base_blocks: gr.Blocks = gr.Blocks(
64
  analytics_enabled=False,
 
7
  from introduck.inference import generate_acceptance_reply
8
  from introduck.inference import generate_rejection_reply
9
  from introduck.inference import highlight_named_entities
10
+ from introduck.utils import dump_email_as_string
11
 
12
  _INTRO_SUBJECT_EXAMPLE: str = "Could you make an intro?"
13
  _INTRO_MESSAGE_EXAMPLE: str = """\
 
41
  "subject": subject,
42
  "body": body}
43
 
44
+ msg: str = dump_email_as_string(data=msg_data)
45
 
46
  contacts: list = extract_contacts_from_email(payload=msg)
47
  acceptance_reply: str = generate_acceptance_reply(payload=msg)
48
  rejection_reply: str = generate_rejection_reply(payload=msg)
49
+ highlighted_text: dict = highlight_named_entities(payload=msg)
50
 
51
+ return contacts, acceptance_reply, rejection_reply, highlighted_text
52
 
53
 
54
  def _use_message_template() -> (str, str):
 
58
  def create_base_route(title: str = "") -> FastAPI:
59
  # TODO: Fix once resolved https://github.com/gradio-app/gradio/issues/1683
60
  # TODO: ...and remove elem_id="htxt" from HighlightedText for RFC822 output
61
+ css_workaround: str = """
62
+ #htxt span {
63
+ font-family: monospace;
64
+ white-space: pre;
65
+ }
66
+ """
67
 
68
  base_blocks: gr.Blocks = gr.Blocks(
69
  analytics_enabled=False,
introduck/utils.py CHANGED
@@ -1,16 +1,13 @@
1
- from email.message import EmailMessage
2
-
3
-
4
- def create_email_message(data: dict) -> str:
5
- msg: EmailMessage = EmailMessage()
6
-
7
- body_text: str = data.get("body", "")
8
- # body_html: str = ""
9
-
10
- msg["From"] = data.get("from", "")
11
- msg["To"] = data.get("to", "")
12
- msg["Subject"] = data.get("subject", "")
13
- msg.set_content(body_text)
14
- # msg.add_alternative(body_html, subtype="html")
15
-
16
- return msg.as_string()
 
1
def _as_single_line(value: object, fallback: str) -> str:
    """Flatten *value* onto one line; return *fallback* if nothing remains."""
    if not value:
        return fallback
    # Drop line breaks (and the empty fragments between them) so the value
    # cannot spill onto further lines of the rendered header section.
    flattened = " ".join(filter(None, str(value).splitlines()))
    return flattened or fallback


def dump_email_as_string(data: dict) -> str:
    """Render an email-like dict as plain ``Header: value`` text.

    The output has ``From``, ``To`` and ``Subject`` lines, one blank
    separator line, then the body followed by a trailing newline.

    Args:
        data: Mapping read for the keys ``"from"``, ``"to"``, ``"subject"``
            and ``"body"``. Missing or falsy header values are rendered as
            ``"*"``; a missing or falsy body is rendered as ``"***"``.

    Returns:
        The rendered message text.

    Header values are flattened to a single line: an embedded line break
    would otherwise end the header section early or add an extra
    header-looking line to the output. The body is emitted unchanged.
    """
    sender = _as_single_line(data.get("from"), "*")
    recipients = _as_single_line(data.get("to"), "*")
    subject = _as_single_line(data.get("subject"), "*")
    body = data.get("body") or "***"

    return (
        f"From: {sender}\n"
        f"To: {recipients}\n"
        f"Subject: {subject}\n\n"
        f"{body}\n"
    )
 
 
 
requirements.txt CHANGED
@@ -4,3 +4,4 @@ spacy==3.4.0
4
  uvicorn[standard]
5
 
6
  https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl
 
 
4
  uvicorn[standard]
5
 
6
  https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl
7
+ https://huggingface.co/spacy/en_core_web_trf/resolve/main/en_core_web_trf-any-py3-none-any.whl