# ArneBinder's picture
# update from https://github.com/ArneBinder/pie-document-level/pull/397
# ced4316 verified
import json
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Union
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from .rendering_utils_displacy import EntityRenderer
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# String identifiers of the render modes; AVAILABLE_RENDER_MODES collects both of them.
RENDER_WITH_DISPLACY = "displacy"
RENDER_WITH_PRETTY_TABLE = "pretty_table"
AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE]
# adjusted from rendering_utils_displacy.TPL_ENT
# HTML template for one entity <mark>. In addition to the text/label/colour placeholders it
# carries {entity_id}, {slice_idx} and {highlight_mode} as data-* attributes; render_displacy
# supplies values under exactly these names in each entity's "params" dict.
# NOTE(review): presumably EntityRenderer formats the template with those params - confirm
# against rendering_utils_displacy.
TPL_ENT_WITH_ID = """
<mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" data-highlight-mode="{highlight_mode}" style="background: {bg}; border-width: {border_width}; border-color: {border_color}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
{text}
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""
# JavaScript source (a single arrow function, kept as a Python string) that adds hover and
# click behaviour to the rendered entity markup:
#   * every element with class "entity" gets a pointer cursor and, once per element (guarded
#     by the "data-has-listener" attribute), mouseover/mouseout listeners;
#   * on mouseover all entities are first reset to their "data-color-original" /
#     "data-border-color-original" values, then the parts of the hovered entity, the tails
#     and heads listed in its "data-relation-tails" / "data-relation-heads" JSON attributes,
#     and all remaining entities are recoloured from the "data-color-selected", "-tail",
#     "-head" and "-other" attributes (a JSON object value is indexed by label);
#   * the hovered entity id is written into the "#hover_adu_id textarea" element and an
#     "input" event is dispatched on it;
#   * clicking the ".entities" container copies that value into "#selected_adu_id textarea".
# The data-* attributes read here are the ones written by inject_relation_data.
# NOTE(review): the click listener on ".entities" is added on every invocation of this
# function; unlike the per-entity listeners it has no "data-has-listener" guard - confirm
# whether repeated registration is intended.
HIGHLIGHT_SPANS_JS = """
() => {
function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
var color = entity.getAttribute('data-color-' + colorAttributeKey);
// if color is a json string, parse it and use the value at colorDictKey
try {
const colors = JSON.parse(color);
color = colors[colorDictKey];
} catch (e) {}
if (color) {
//const highlightMode = entity.getAttribute('data-highlight-mode');
//if (highlightMode === 'fill') {
entity.style.backgroundColor = color;
entity.style.color = '#000';
//}
entity.style.borderColor = color;
}
}
function highlightRelationArguments(entityId) {
const entities = document.querySelectorAll('.entity');
// reset all entities
entities.forEach(entity => {
const color = entity.getAttribute('data-color-original');
entity.style.backgroundColor = color;
const borderColor = entity.getAttribute('data-border-color-original');
entity.style.borderColor = borderColor;
entity.style.color = '';
});
if (entityId !== null) {
var visitedEntities = new Set();
// highlight selected entity
// get all elements with attribute data-entity-id==entityId
const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
selectedEntityParts.forEach(selectedEntityPart => {
const label = selectedEntityPart.getAttribute('data-label');
maybeSetColor(selectedEntityPart, 'selected', label);
visitedEntities.add(selectedEntityPart);
}); // <-- Corrected closing parenthesis here
// if there is at least one part, get the first one and ...
if (selectedEntityParts.length > 0) {
const selectedEntity = selectedEntityParts[0];
// ... highlight tails and ...
const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
relationTailsAndLabels.forEach(relationTail => {
const tailEntityId = relationTail['entity-id'];
const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
tailEntityParts.forEach(tailEntity => {
const label = relationTail['label'];
maybeSetColor(tailEntity, 'tail', label);
visitedEntities.add(tailEntity);
}); // <-- Corrected closing parenthesis here
}); // <-- Corrected closing parenthesis here
// .. highlight heads
const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
relationHeadsAndLabels.forEach(relationHead => {
const headEntityId = relationHead['entity-id'];
const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
headEntityParts.forEach(headEntity => {
const label = relationHead['label'];
maybeSetColor(headEntity, 'head', label);
visitedEntities.add(headEntity);
}); // <-- Corrected closing parenthesis here
}); // <-- Corrected closing parenthesis here
}
// highlight other entities
entities.forEach(entity => {
if (!visitedEntities.has(entity)) {
const label = entity.getAttribute('data-label');
maybeSetColor(entity, 'other', label);
}
});
}
}
function setHoverAduId(entityId) {
// get the textarea element that holds the reference adu id
let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
// set the value of the input field
hoverAduIdDiv.value = entityId;
// trigger an input event to update the state
var event = new Event('input');
hoverAduIdDiv.dispatchEvent(event);
}
function setReferenceAduIdFromHover() {
// get the hover adu id
const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
// get the value of the input field
const entityId = hoverAduIdDiv.value;
// get the textarea element that holds the reference adu id
let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
// set the value of the input field
referenceAduIdDiv.value = entityId;
// trigger an input event to update the state
var event = new Event('input');
referenceAduIdDiv.dispatchEvent(event);
}
const entities = document.querySelectorAll('.entity');
entities.forEach(entity => {
// make the cursor a pointer
entity.style.cursor = 'pointer';
const alreadyHasListener = entity.getAttribute('data-has-listener');
if (alreadyHasListener) {
return;
}
entity.addEventListener('mouseover', () => {
const entityId = entity.getAttribute('data-entity-id');
highlightRelationArguments(entityId);
setHoverAduId(entityId);
});
entity.addEventListener('mouseout', () => {
highlightRelationArguments(null);
});
entity.setAttribute('data-has-listener', 'true');
});
const entityContainer = document.querySelector('.entities');
if (entityContainer) {
entityContainer.addEventListener('click', () => {
setReferenceAduIdFromHover();
});
// make the cursor a pointer
// entityContainer.style.cursor = 'pointer';
}
}
"""
def render_pretty_table(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    **render_kwargs,
) -> str:
    """Render the binary relations as a scrollable HTML table.

    The table has the columns ``head``, ``tail`` and ``relation`` and contains
    exactly one row per entry of ``binary_relations``.

    Args:
        text: Not used by this renderer; accepted so the signature matches
            ``render_displacy``.
        spans: Not used by this renderer.
        span_id2idx: Not used by this renderer.
        binary_relations: The relations to list. Head and tail are shown via
            ``str()``, the label is shown as is.
        **render_kwargs: Ignored.

    Returns:
        The HTML of the table, wrapped in a ``<div>`` that limits the height
        to 360px and scrolls on overflow.
    """
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    # One row per relation. The sequence was previously concatenated with
    # itself, which listed every relation twice.
    for relation in binary_relations:
        table.add_row([str(relation.head), str(relation.tail), relation.label])
    html = table.get_html_string(format=True)
    return "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
def render_displacy(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    highlight_span_ids: Optional[List[str]] = None,
    inject_relations=True,
    colors_hover=None,
    entity_options: Optional[Dict[str, Any]] = None,
    **render_kwargs,
) -> str:
    """Render the text with its spans as entity-highlighted HTML.

    Every span referenced in ``span_id2idx`` becomes one entity (a
    ``LabeledSpan``) or one entity per slice (a ``LabeledMultiSpan``). Each
    entity carries its span id, slice index and highlight mode as template
    parameters, so that the relation data can be attached afterwards.

    Args:
        text: The text the span offsets refer to.
        spans: The span annotations; indexed via the values of ``span_id2idx``.
        span_id2idx: Maps a span id to the index of that span in ``spans``.
        binary_relations: Relations between the spans; only used when
            ``inject_relations`` is true.
        highlight_span_ids: Ids of the spans to render with highlight mode
            ``"fill"``; all others get ``"border"``. With ``None``, every span
            gets ``"fill"``.
        inject_relations: Whether to post-process the HTML with
            ``inject_relation_data``.
        colors_hover: Passed to ``inject_relation_data`` as
            ``additional_colors``.
        entity_options: Options for the ``EntityRenderer``. The mapping is
            copied, never modified; its ``"template"`` entry is always
            replaced by ``TPL_ENT_WITH_ID``. ``None`` means no extra options.
        **render_kwargs: Ignored.

    Returns:
        The rendered HTML, wrapped in a ``<div>`` that limits the height to
        360px and scrolls on overflow.

    Raises:
        ValueError: If a referenced span is neither a ``LabeledSpan`` nor a
            ``LabeledMultiSpan``.
    """
    ents: List[Dict[str, Any]] = []
    for entity_id, idx in span_id2idx.items():
        labeled_span = spans[idx]
        highlight_mode = (
            "fill" if highlight_span_ids is None or entity_id in highlight_span_ids else "border"
        )
        # A contiguous span is treated as a multi-span with a single slice, so
        # both kinds share the entity construction below.
        if isinstance(labeled_span, LabeledSpan):
            slices = [(labeled_span.start, labeled_span.end)]
        elif isinstance(labeled_span, LabeledMultiSpan):
            slices = labeled_span.slices
        else:
            raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
        # pass the ID as a parameter to the entity. The id is required to fetch the
        # entity annotations on hover and to inject the relation data.
        for slice_idx, (start, end) in enumerate(slices):
            ents.append(
                {
                    "start": start,
                    "end": end,
                    "label": labeled_span.label,
                    "params": {
                        "entity_id": entity_id,
                        "slice_idx": slice_idx,
                        "highlight_mode": highlight_mode,
                    },
                }
            )

    spacy_doc = {
        "text": text,
        # the ents MUST be sorted by start and end
        "ents": sorted(ents, key=lambda ent: (ent["start"], ent["end"])),
        "title": None,
    }

    # Work on a private copy: the caller's options stay untouched, and no
    # mutable default object is shared between calls.
    options: Dict[str, Any] = {} if entity_options is None else entity_options.copy()
    # use the custom template with the entity ID
    options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()
    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"

    if inject_relations:
        html = inject_relation_data(
            html,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
def inject_relation_data(
    html: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach colour, label and relation data-attributes to the entity elements.

    Every element with class ``entity`` in ``html`` receives:

    * ``data-color-original`` / ``data-border-color-original``: the
      ``background`` and ``border-color`` values of its inline style,
    * ``data-color-<key>`` for each entry of ``additional_colors`` (dict
      values are stored as JSON),
    * ``data-label``: the label of the span named by ``data-entity-id``,
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for the relations in which the
      span is the head / the tail. Partners without an id in ``span_id2idx``
      are left out.

    A warning is logged if the element text does not start with the resolved
    span text (newlines removed).

    Args:
        html: The HTML containing the entity elements.
        spans: The span annotations; indexed via the values of ``span_id2idx``.
        span_id2idx: Maps a span id to the index of that span in ``spans``.
        binary_relations: The relations between the spans.
        additional_colors: Optional extra colours, keyed by attribute suffix.

    Returns:
        The modified HTML.

    Raises:
        ValueError: If a span is neither a ``LabeledSpan`` nor a
            ``LabeledMultiSpan``.
    """
    from bs4 import BeautifulSoup

    def css_value(style: str, prop: str) -> str:
        # Text between "<prop>:" and the following ";" of an inline style.
        return style.split(f"{prop}:")[1].split(";")[0].strip()

    def partners_as_json(partners) -> str:
        # Serialize the (span, relation label) pairs whose span has a known id.
        refs = []
        for partner, relation_label in partners:
            if partner in id_by_span:
                refs.append({"entity-id": id_by_span[partner], "label": relation_label})
        return json.dumps(refs)

    # Parse the HTML using BeautifulSoup
    document = BeautifulSoup(html, "html.parser")

    # Index the relations from both directions.
    tails_by_head = defaultdict(list)
    heads_by_tail = defaultdict(list)
    for rel in binary_relations:
        heads_by_tail[rel.tail].append((rel.head, rel.label))
        tails_by_head[rel.head].append((rel.tail, rel.label))

    id_by_span = {spans[position]: identifier for identifier, position in span_id2idx.items()}
    extra_colors = {} if additional_colors is None else additional_colors

    for mark in document.find_all(class_="entity"):
        inline_style = mark["style"]
        background = css_value(inline_style, "background")
        border = css_value(inline_style, "border-color")
        mark["data-color-original"] = background
        mark["data-border-color-original"] = border
        for key, color in extra_colors.items():
            if isinstance(color, dict):
                mark[f"data-color-{key}"] = json.dumps(color)
            else:
                mark[f"data-color-{key}"] = color

        annotation = spans[span_id2idx[mark["data-entity-id"]]]

        # sanity check.
        if isinstance(annotation, LabeledSpan):
            resolved_text = annotation.resolve()[1]
        elif isinstance(annotation, LabeledMultiSpan):
            resolved_text = annotation.resolve()[1][int(mark["data-slice-idx"])]
        else:
            raise ValueError(f"Unsupported entity type: {type(annotation)}")
        flattened_text = resolved_text.replace("\n", "")
        # Just check the start, because the text has the label attached to the end
        if not mark.text.startswith(flattened_text):
            logger.warning(f"Entity text mismatch: {annotation} != {mark.text}")

        mark["data-label"] = annotation.label
        mark["data-relation-tails"] = partners_as_json(tails_by_head.get(annotation, []))
        mark["data-relation-heads"] = partners_as_json(heads_by_tail.get(annotation, []))

    # Return the modified HTML as a string
    return str(document)