import json
from collections import defaultdict
from typing import Dict, List, Optional, Union
from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the document's binary relations as an HTML table.

    Gold relations (``document.binary_relations``) are listed first, followed by the
    predicted ones (``document.binary_relations.predictions``). Each row holds the
    string form of the head span, the string form of the tail span, and the label.

    Args:
        document: Document whose binary relations are rendered.
        **render_kwargs: Ignored by this renderer.

    Returns:
        The table produced by ``PrettyTable.get_html_string(format=True)``, with a
        newline before and after it.
    """
    # Imported locally so prettytable is only required when this function is called.
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    relations = list(document.binary_relations) + list(document.binary_relations.predictions)
    for relation in relations:
        table.add_row([str(relation.head), str(relation.tail), relation.label])

    html = table.get_html_string(format=True)
    # NOTE(review): the wrapper literals around `html` were unparseable (raw line breaks
    # inside the string literals, possibly remnants of stripped wrapper markup); restored
    # here as newline strings. Confirm the intended wrapper.
    html = "\n" + html + "\n"

    return html
def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options=None,
    **render_kwargs,
):
    """Render the document's labeled spans as entity-highlighted HTML.

    Both gold spans (``document.labeled_spans``) and predicted spans
    (``document.labeled_spans.predictions``) are passed to ``EntityRenderer``.
    When ``inject_relations`` is true, the rendered entity elements are annotated
    with relation metadata via ``inject_relation_data``.

    Args:
        document: Document whose spans (and optionally relations) are rendered.
        inject_relations: Whether to attach relation data attributes to the
            rendered entity elements.
        colors_hover: Optional mapping forwarded to ``inject_relation_data`` as
            ``additional_colors``; only used when ``inject_relations`` is true.
        entity_options: Options passed to ``EntityRenderer``. ``None`` (the
            default) means an empty dict.
        **render_kwargs: Ignored by this renderer.

    Returns:
        The rendered HTML string.
    """
    if entity_options is None:
        # Fresh dict per call instead of a mutable default shared across calls.
        entity_options = {}

    spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
    spacy_doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label} for entity in spans
        ],
        "title": None,
    }

    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()
    # NOTE(review): the wrapper literals around `html` were unparseable (raw line breaks
    # inside the string literals, possibly remnants of stripped wrapper markup); restored
    # here as an empty prefix and a trailing newline. Confirm the intended wrapper.
    html = html + "\n"

    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        # inject_relation_data pairs rendered entity elements with spans by position,
        # so the spans are given in (start, end) order -- presumably the order in which
        # EntityRenderer emits them; confirm against the renderer.
        sorted_entities = sorted(spans, key=lambda x: (x.start, x.end))
        html = inject_relation_data(
            html,
            sorted_entities=sorted_entities,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )

    return html
def labeled_span_to_id(span: LabeledSpan) -> str:
    """Return a string id for *span* of the form ``span-<start>-<end>-<label>``.

    Spans with equal start, end and label map to the same id. The id is used as
    the HTML ``id`` of rendered entity elements.
    """
    template = "span-{}-{}-{}"
    return template.format(span.start, span.end, span.label)
def labeled_span_from_id(span_id: str) -> LabeledSpan:
    """Parse an id of the form ``span-<start>-<end>-<label>`` into a LabeledSpan.

    This is the inverse of ``labeled_span_to_id``. Only the first three hyphens
    are treated as separators, so a label that itself contains hyphens
    (e.g. ``B-PER``) is recovered intact instead of being truncated.

    Raises:
        IndexError: if *span_id* has fewer than four hyphen-separated parts.
        ValueError: if the start or end part is not an integer.
    """
    # maxsplit=3 keeps everything after the third hyphen as the label.
    parts = span_id.split("-", 3)
    return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])
def inject_relation_data(
    html: str,
    sorted_entities: List[LabeledSpan],
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach span ids, colors and relation metadata to rendered entity elements.

    Every element with CSS class ``entity`` in *html* is matched, by position, to
    the span at the same index in *sorted_entities*. Each matched element gets:

    * ``id``: the span id from ``labeled_span_to_id``;
    * ``data-color-original``: the value following ``background:`` in its inline style;
    * ``data-color-<key>``: one per entry of *additional_colors* (dict values are
      JSON-encoded, other values are set as-is);
    * ``data-label``: the span label;
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for the relations in which this span is
      the head / the tail.

    Args:
        html: HTML containing the rendered entity elements.
        sorted_entities: Spans in the same order as the entity elements in *html*.
        binary_relations: Relations whose endpoints are spans from *sorted_entities*.
        additional_colors: Optional extra colors to attach to every entity element.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: if the first node inside an entity element does not equal
            ``str()`` of its matched span.
        IndexError: if *html* contains more entity elements than there are spans,
            or an element's ``style`` has no ``background:`` entry.
        KeyError: if an entity element has no ``style`` attribute.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index relations by their endpoints so each entity's lookups are cheap.
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    # The extra color attributes are the same for every entity, so encode them once.
    extra_color_attrs = {}
    if additional_colors is not None:
        extra_color_attrs = {
            f"data-color-{key}": json.dumps(color) if isinstance(color, dict) else color
            for key, color in additional_colors.items()
        }

    for idx, entity in enumerate(soup.find_all(class_="entity")):
        entity_annotation = sorted_entities[idx]
        entity["id"] = labeled_span_to_id(entity_annotation)
        # Remember the renderer's background color (inline "background: <color>;").
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        for attr_name, attr_value in extra_color_attrs.items():
            entity[attr_name] = attr_value

        # Sanity check that the element and the span were paired correctly:
        # `entity.next` is the first node inside the element, which should be the
        # span text. Report that same value on mismatch (entity.text would also
        # include any other content of the element).
        if str(entity_annotation) != entity.next:
            raise ValueError(f"Entity text mismatch: {entity_annotation} != {entity.next}")

        entity["data-label"] = entity_annotation.label
        # .get() on the defaultdicts avoids inserting empty entries.
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    return str(soup)