sam-pointer-bart-base-v0.3 / rendering_utils.py
ArneBinder's picture
Upload 10 files
1681237 verified
raw
history blame contribute delete
No virus
4.12 kB
import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Union
from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer
logger = logging.getLogger(__name__)
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the document's binary relations as a scrollable HTML table.

    The entries of ``document.binary_relations`` come first, followed by the
    entries of ``document.binary_relations.predictions``. Each row holds the
    string form of the relation head, the string form of the tail, and the
    relation label.

    Args:
        document: Document whose ``binary_relations`` layer is rendered.
        **render_kwargs: Accepted but not used by this renderer.

    Returns:
        The table's HTML wrapped in a size-limited ``<div>`` with overflow
        scrolling.
    """
    # Function-scope import: prettytable is only loaded when this function runs.
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"

    relations = [*document.binary_relations, *document.binary_relations.predictions]
    for rel in relations:
        table.add_row([str(rel.head), str(rel.tail), rel.label])

    opening = "<div style='max-width:100%; max-height:360px; overflow:auto'>"
    closing = "</div>"
    return opening + table.get_html_string(format=True) + closing
def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options: Optional[dict] = None,
    **render_kwargs,
):
    """Render the document's labeled spans as displaCy-style entity HTML.

    The entries of ``document.labeled_spans`` and of its ``predictions`` are
    combined, sorted by ``(start, end)``, and rendered with ``EntityRenderer``.
    The result is wrapped in a size-limited ``<div>`` with overflow scrolling.
    When ``inject_relations`` is true, the HTML is post-processed by
    ``inject_relation_data`` so that each entity element also carries its
    relation information.

    Args:
        document: Document providing ``text``, ``labeled_spans`` and
            ``binary_relations`` (each including ``predictions``).
        inject_relations: Whether to add relation data to the entity elements.
        colors_hover: Forwarded to ``inject_relation_data`` as
            ``additional_colors``.
        entity_options: Options passed to ``EntityRenderer``. ``None`` (the
            default) means an empty options dict; a fresh dict is created per
            call so no state is shared between calls.
        **render_kwargs: Accepted but not used by this renderer.

    Returns:
        The rendered HTML string.
    """
    # A `{}` default would be one dict object shared by every call (and handed
    # to the renderer); build a new one per call instead.
    options = {} if entity_options is None else entity_options

    # inject_relation_data() pairs the i-th rendered entity element with the
    # i-th annotation, so the renderer is given the very same ordering.
    # NOTE(review): assumes EntityRenderer emits entities in the order it
    # receives them (or re-sorts by the same key) - confirm in
    # rendering_utils_displacy.
    spans = sorted(
        list(document.labeled_spans) + list(document.labeled_spans.predictions),
        key=lambda span: (span.start, span.end),
    )

    spacy_doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label} for entity in spans
        ],
        "title": None,
    }

    renderer = EntityRenderer(options=options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()
    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"

    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            sorted_entities=spans,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
def inject_relation_data(
    html: str,
    sorted_entities,
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach ids, colors and relation data to the entity elements in ``html``.

    Every element with class ``entity`` is paired, by position, with the
    annotation at the same index in ``sorted_entities`` and receives:

    * ``id``: the value of ``labeled_span_to_id(annotation)``,
    * ``data-color-original``: the ``background`` value parsed from the
      element's ``style`` attribute,
    * ``data-color-<key>``: one attribute per entry of ``additional_colors``
      (dict values are JSON-encoded, other values are used as given),
    * ``data-label``: the annotation's label,
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for the relations in which the
      annotation is the head / the tail, respectively.

    Args:
        html: HTML containing the rendered entity elements.
        sorted_entities: Annotations in the same order as the entity elements
            appear in ``html``.
        binary_relations: Relations whose ``head``/``tail`` are looked up among
            the annotations.
        additional_colors: Optional mapping from a color key to a color value.

    Returns:
        The modified HTML serialized back to a string.

    Raises:
        KeyError: If an entity element has no ``style`` attribute.
        IndexError: If an entity element's style has no ``background:`` entry.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index relations from both ends so each entity can list its partners.
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    entities = soup.find_all(class_="entity")
    if len(entities) != len(sorted_entities):
        # Positional pairing is only meaningful when both sides line up; report
        # the discrepancy instead of failing with a bare IndexError.
        logger.warning(
            "Entity count mismatch: %d rendered elements vs. %d annotations",
            len(entities),
            len(sorted_entities),
        )

    for entity, annotation in zip(entities, sorted_entities):
        entity["id"] = labeled_span_to_id(annotation)

        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color

        if additional_colors is not None:
            for key, color in additional_colors.items():
                entity[f"data-color-{key}"] = (
                    json.dumps(color) if isinstance(color, dict) else color
                )

        # Sanity check: compare the annotation's string form with the node that
        # was parsed right after the element's opening tag.
        # NOTE(review): presumably this is the entity's own leading text node
        # (entity.text would also include nested markup) - confirm against the
        # renderer's entity template.
        rendered_text = entity.next_element
        if str(annotation) != rendered_text:
            # Log the value that was actually compared.
            logger.warning("Entity text mismatch: %s != %s", annotation, rendered_text)

        entity["data-label"] = annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(annotation, [])
            ]
        )

    return str(soup)