File size: 3,984 Bytes
bc6f57a bfcba2d bc6f57a 5003662 bc6f57a 5003662 bc6f57a 5003662 bc6f57a 25fcabc bc6f57a 25fcabc bc6f57a 25fcabc bc6f57a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import json
from collections import defaultdict
from typing import Dict, List, Optional, Union
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the binary relations of a document as a scrollable HTML table.

    Rows are produced for every relation in ``document.binary_relations``,
    followed by every relation in ``document.binary_relations.predictions``.
    Each row holds the string form of the head span, the string form of the
    tail span, and the relation label.

    Args:
        document: Document whose relations are listed.
        **render_kwargs: Accepted for interface compatibility; not used.

    Returns:
        The table HTML (as produced by ``PrettyTable.get_html_string``) wrapped
        in a ``<div>`` limited to 360px height with scrolling enabled.
    """
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"

    relation_layer = document.binary_relations
    # main layer first, then its predictions -- same row order as before
    for annotations in (relation_layer, relation_layer.predictions):
        for rel in annotations:
            table.add_row([str(rel.head), str(rel.tail), rel.label])

    table_html = table.get_html_string(format=True)
    return f"<div style='max-width:100%; max-height:360px; overflow:auto'>{table_html}</div>"
def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options=None,
    **render_kwargs,
):
    """Render the labeled spans of a document as displaCy-style entity HTML.

    Spans from ``document.labeled_spans`` and from
    ``document.labeled_spans.predictions`` are rendered together with
    ``EntityRenderer``. The result is wrapped in a ``<div>`` limited to 360px
    height with scrolling enabled.

    Args:
        document: Document whose spans (and, optionally, relations) are rendered.
        inject_relations: If True, pass the HTML through ``inject_relation_data``.
            That call adds ids, colors, labels and relation head/tail data to
            every rendered entity element.
        colors_hover: Forwarded to ``inject_relation_data`` as
            ``additional_colors``.
        entity_options: Options passed to ``EntityRenderer``. ``None`` (the
            default) means an empty dict, created fresh for each call.
        **render_kwargs: Accepted for interface compatibility; not used.

    Returns:
        The rendered HTML string.
    """
    # Avoid a shared mutable default argument: build a new dict per call.
    if entity_options is None:
        entity_options = {}

    # Sort once, and use this single ordering both for rendering and for the
    # injection step. inject_relation_data pairs the i-th rendered entity
    # element with the i-th annotation, so the two orders must agree.
    # Concatenating the main layer with its predictions is generally not sorted.
    # Pre-sorting makes the rendered order correct whether or not EntityRenderer
    # sorts its input itself; that is not visible from here.
    spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
    sorted_entities = sorted(spans, key=lambda span: (span.start, span.end))

    spacy_doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label}
            for entity in sorted_entities
        ],
        "title": None,
    }

    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()
    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"

    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            sorted_entities=sorted_entities,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
def inject_relation_data(
    html: str,
    sorted_entities,
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Annotate rendered entity elements with ids, colors, labels and relations.

    Every element with CSS class ``entity`` in ``html`` is paired, in document
    order, with the annotation at the same position in ``sorted_entities``.
    The following attributes are set on each element:

    * ``id``: ``str(annotation._id)``
    * ``data-color-original``: the value of the ``background:`` declaration
      in the element's inline ``style``
    * ``data-color-<key>``: one per entry in ``additional_colors``; dict values
      are JSON-encoded
    * ``data-label``: ``annotation.label``
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for the relations in which the
      annotation is the head, or the tail, respectively

    Args:
        html: HTML containing the entity elements. Each one must carry an
            inline ``style`` with a ``background:`` declaration.
        sorted_entities: Span annotations in the same order as the entity
            elements appear in ``html``. ``str(annotation)`` must equal the
            element's leading text node.
        binary_relations: Relations whose ``head``/``tail`` are matched against
            the annotations (they are used as dict keys).
        additional_colors: Optional mapping from key to color (string or dict).

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If ``html`` contains more entity elements than there are
            annotations, or if an element's text does not match its paired
            annotation.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index relations by argument. For every span, collect the spans it points
    # to (tails) and the spans pointing to it (heads), each with the label.
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    # The extra color attributes are the same for every entity, so serialize
    # them once instead of once per element.
    color_attributes = {}
    if additional_colors is not None:
        color_attributes = {
            f"data-color-{key}": json.dumps(color) if isinstance(color, dict) else color
            for key, color in additional_colors.items()
        }

    entities = soup.find_all(class_="entity")
    if len(entities) > len(sorted_entities):
        raise ValueError(
            f"Found {len(entities)} entity elements in the HTML, but only "
            f"{len(sorted_entities)} entity annotations"
        )

    # Extra annotations (more annotations than elements) are ignored, as before.
    for entity, annotation in zip(entities, sorted_entities):
        entity["id"] = str(annotation._id)
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        for attribute, value in color_attributes.items():
            entity[attribute] = value

        # Sanity check: the element's first child node is the rendered span text.
        # Report that same node on mismatch. entity.text would also include the
        # label badge, which makes the message misleading.
        rendered_text = entity.next
        if str(annotation) != rendered_text:
            raise ValueError(f"Entity text mismatch: {annotation} != {rendered_text}")

        entity["data-label"] = annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": str(tail._id), "label": label}
                for tail, label in entity2tails.get(annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": str(head._id), "label": label}
                for head, label in entity2heads.get(annotation, [])
            ]
        )

    return str(soup)
|