import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Union
from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer
logger = logging.getLogger(__name__)

# Adjusted from rendering_utils_displacy.TPL_ENT: the entity element additionally carries
# id="{id}" (the ID is passed to the renderer via each entity's "params"). inject_relation_data
# locates entity elements by the CSS class "entity", reads their id attribute and parses the
# "background:" value from their inline style, so all three must be present in the markup.
# NOTE(review): the markup follows spaCy's displaCy TPL_ENT plus the id attribute and assumes
# the local renderer supplies {bg}, {text} and {label} like spaCy does — confirm against
# rendering_utils_displacy.TPL_ENT.
TPL_ENT_WITH_ID = """
<mark class="entity" id="{id}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
    {text}
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
) -> str:
    """Render the binary relations of a document as an HTML table.

    Gold relations (``document.binary_relations``) and predicted relations
    (``document.binary_relations.predictions``) are listed together, one row per relation
    with the string form of its head span, its tail span and its label.

    Args:
        document: The document whose binary relations are rendered.
        **render_kwargs: Accepted for a uniform renderer signature; not used here.

    Returns:
        The table HTML wrapped in a scrollable ``div``.
    """
    # local import: prettytable is only needed when this renderer is actually called
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    for relation in list(document.binary_relations) + list(document.binary_relations.predictions):
        table.add_row([str(relation.head), str(relation.tail), relation.label])

    html = table.get_html_string(format=True)
    # NOTE(review): wrapper styling chosen to match render_displacy; confirm the intended layout.
    return "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations: bool = True,
    colors_hover: Optional[Dict[str, Union[str, dict]]] = None,
    entity_options: Optional[dict] = None,
    **render_kwargs,
) -> str:
    """Render the labeled spans of a document as displaCy-style entity HTML.

    Gold spans (``document.labeled_spans``) and predicted spans
    (``document.labeled_spans.predictions``) are rendered together. Every entity element gets
    the ID produced by ``labeled_span_to_id`` so relation data can be attached to it later.

    Args:
        document: The document whose labeled spans (and optionally relations) are rendered.
        inject_relations: If True, attach color, label and relation data to the entity
            elements via ``inject_relation_data`` (gold and predicted relations combined).
        colors_hover: Passed as ``additional_colors`` to ``inject_relation_data``.
        entity_options: Options for ``EntityRenderer``. The ``"template"`` entry is always
            replaced by ``TPL_ENT_WITH_ID``; the caller's object is not modified.
        **render_kwargs: Accepted for a uniform renderer signature; not used here.

    Returns:
        The rendered HTML wrapped in a scrollable ``div``.
    """
    labeled_spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
    ents = [
        {
            "start": labeled_span.start,
            "end": labeled_span.end,
            "label": labeled_span.label,
            # pass the ID as a parameter to the entity. The id is required to fetch the entity
            # annotations on hover and to inject the relation data.
            "params": {"id": labeled_span_to_id(labeled_span)},
        }
        for labeled_span in labeled_spans
    ]
    # Gold and predicted spans are concatenated above, so they are not ordered by offset.
    # spaCy's EntityRenderer emits the text between consecutive entities in list order, and
    # displacy.render sorts manual input only because it is the public entry point; calling the
    # renderer directly skips that. Presumably the local EntityRenderer behaves the same —
    # sorting (stable, so gold stays before predictions on ties) is harmless either way.
    ents.sort(key=lambda ent: (ent["start"], ent["end"]))
    spacy_doc = {"text": document.text, "ents": ents, "title": None}

    # copy to avoid modifying the caller's options
    entity_options = {} if entity_options is None else entity_options.copy()
    # use the custom template with the entity ID
    entity_options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()
    # NOTE(review): wrapper styling chosen to match render_pretty_table; confirm the layout.
    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"

    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            labeled_spans=labeled_spans,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
def inject_relation_data(
    html: str,
    labeled_spans: List[LabeledSpan],
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach color, label and relation data attributes to the entity elements in ``html``.

    Every element with CSS class ``entity`` is expected to carry an ``id`` produced by
    ``labeled_span_to_id`` for one of ``labeled_spans``. For each such element this sets:

    - ``data-color-original``: the ``background`` value of its inline ``style``, if present;
    - ``data-color-<key>``: one attribute per ``additional_colors`` entry (dict values are
      JSON-encoded, string values are used as-is);
    - ``data-label``: the label of the matching labeled span;
    - ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for the relations in which the span is the
      head / the tail.

    Elements without an ``id``, or whose ``id`` matches none of ``labeled_spans``, keep their
    color attributes but get no label or relation data; a warning is logged for them.

    Args:
        html: HTML containing the entity elements.
        labeled_spans: The spans the entity elements were rendered from.
        binary_relations: Relations between those spans.
        additional_colors: Extra per-entity color attributes, keyed by attribute suffix.

    Returns:
        The modified HTML as a string.
    """
    # local import: bs4 is only needed when relation data is injected
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index relations in both directions, keyed by the span annotation objects.
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    ann_id2annotation = {labeled_span_to_id(entity): entity for entity in labeled_spans}

    # The additional color attributes are the same for every entity: serialize them once.
    color_attributes = {
        f"data-color-{key}": json.dumps(color) if isinstance(color, dict) else color
        for key, color in (additional_colors or {}).items()
    }

    for entity in soup.find_all(class_="entity"):
        # Remember the rendered background color (e.g. to restore it after hover effects).
        style_parts = entity.get("style", "").split("background:")
        if len(style_parts) > 1:
            entity["data-color-original"] = style_parts[1].split(";")[0].strip()
        for attribute_name, attribute_value in color_attributes.items():
            entity[attribute_name] = attribute_value

        entity_id = entity.get("id")
        entity_annotation = ann_id2annotation.get(entity_id)
        if entity_annotation is None:
            logger.warning("No labeled span found for entity element with id %r", entity_id)
            continue

        # Sanity check. Only the start is compared because the element text ends with the label.
        annotation_text_without_newline = str(entity_annotation).replace("\n", "")
        if not entity.text.startswith(annotation_text_without_newline):
            logger.warning("Entity text mismatch: %s != %s", entity_annotation, entity.text)

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    return str(soup)