File size: 3,984 Bytes
bc6f57a
 
 
 
 
 
bfcba2d
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5003662
bc6f57a
 
 
5003662
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
5003662
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25fcabc
 
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
25fcabc
bc6f57a
 
 
 
 
25fcabc
bc6f57a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
from collections import defaultdict
from typing import Dict, List, Optional, Union

from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer


def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the binary relations of *document* as a scrollable HTML table.

    Gold relations are listed first, followed by predicted relations. Each row holds
    ``str()`` of the head, ``str()`` of the tail and the relation label.

    Args:
        document: Document whose ``binary_relations`` (and their ``predictions``) are shown.
        **render_kwargs: Accepted but ignored (presumably so all renderers share one
            call signature — confirm with callers).

    Returns:
        The table HTML wrapped in a ``<div>`` limited to 360px height with scrolling.
    """
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"

    gold_relations = list(document.binary_relations)
    predicted_relations = list(document.binary_relations.predictions)
    for rel in gold_relations + predicted_relations:
        table.add_row([str(rel.head), str(rel.tail), rel.label])

    table_html = table.get_html_string(format=True)
    return f"<div style='max-width:100%; max-height:360px; overflow:auto'>{table_html}</div>"


def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options=None,
    **render_kwargs,
):
    """Render the labeled spans of *document* as displaCy-style entity HTML.

    Gold spans and predicted spans are rendered together, ordered by ``(start, end)``.

    Args:
        document: Document providing ``text``, ``labeled_spans`` (plus ``predictions``)
            and, when *inject_relations* is set, ``binary_relations`` (plus ``predictions``).
        inject_relations: If true, post-process the HTML with ``inject_relation_data`` so
            every rendered entity carries its id, label, colors and relation data.
        colors_hover: Forwarded to ``inject_relation_data`` as ``additional_colors``.
        entity_options: Options passed to ``EntityRenderer``; ``None`` means ``{}``.
            (Previously this defaulted to a single dict object shared across calls.)
        **render_kwargs: Accepted but ignored.

    Returns:
        The rendered HTML wrapped in a ``<div>`` limited to 360px height with scrolling.
    """
    if entity_options is None:
        entity_options = {}

    # inject_relation_data pairs the i-th rendered entity element with the i-th
    # annotation, so the renderer and the injector must see the spans in the SAME
    # order. Sort once and use this list for both (displaCy-style renderers also
    # expect position-ordered spans).
    spans = sorted(
        list(document.labeled_spans) + list(document.labeled_spans.predictions),
        key=lambda span: (span.start, span.end),
    )
    spacy_doc = {
        "text": document.text,
        "ents": [{"start": span.start, "end": span.end, "label": span.label} for span in spans],
        "title": None,
    }

    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            sorted_entities=spans,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html


def inject_relation_data(
    html: str,
    sorted_entities,
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach ids, colors, labels and relation data to the entity elements of *html*.

    The i-th element with class ``entity`` in *html* is paired with
    ``sorted_entities[i]``. The caller must supply the annotations in rendering order.
    For each paired element these attributes are set:
    ``id``, ``data-color-original``, ``data-color-<key>`` (one per entry of
    *additional_colors*), ``data-label``, ``data-relation-tails`` and
    ``data-relation-heads``. The two relation attributes are JSON lists of
    ``{"entity-id": ..., "label": ...}`` objects.

    Args:
        html: HTML markup containing elements with class ``entity`` whose ``style``
            attribute contains a ``background:`` declaration.
        sorted_entities: Annotations positionally matching the rendered entity
            elements. Each needs ``_id`` and ``label`` attributes and must be hashable.
        binary_relations: Relations whose ``head``/``tail`` are matched against the
            annotations in *sorted_entities* by equality/hash.
        additional_colors: Optional mapping ``key -> color``. ``dict`` values are stored
            JSON-encoded; other values are stored as-is.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If the first node inside an entity element differs from ``str()``
            of its paired annotation (i.e. the positional pairing is off).
        IndexError: If *html* contains more entity elements than *sorted_entities*, or
            an entity's ``style`` has no ``background:`` declaration.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index relations by endpoint: an entity lists its outgoing relations as "tails"
    # and its incoming relations as "heads".
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    # The extra color attributes are identical for every entity; serialize them once.
    color_attributes = {}
    if additional_colors is not None:
        color_attributes = {
            f"data-color-{key}": json.dumps(color) if isinstance(color, dict) else color
            for key, color in additional_colors.items()
        }

    # NOTE(review): a count mismatch is not checked up front; fewer rendered entities
    # than annotations passes silently, more raises IndexError below.
    entities = soup.find_all(class_="entity")
    for idx, entity in enumerate(entities):
        annotation = sorted_entities[idx]
        entity["id"] = str(annotation._id)
        # Keep the renderer's background color (presumably so client-side code can
        # restore it after switching to one of the additional colors — confirm).
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        for attribute, value in color_attributes.items():
            entity[attribute] = value

        # Sanity check: the element's leading text node must equal the annotation text,
        # otherwise the positional pairing with sorted_entities is wrong.
        rendered_text = entity.next
        if str(annotation) != rendered_text:
            raise ValueError(
                f"Entity text mismatch: {str(annotation)!r} != {str(rendered_text)!r}"
            )

        entity["data-label"] = annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": str(tail._id), "label": label}
                for tail, label in entity2tails.get(annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": str(head._id), "label": label}
                for head, label in entity2heads.get(annotation, [])
            ]
        )

    return str(soup)