|
import json |
|
import logging |
|
from collections import defaultdict |
|
from typing import Any, Dict, List, Optional, Sequence, Union |
|
|
|
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan |
|
|
|
from .rendering_utils_displacy import EntityRenderer |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names of the render modes; one constant per rendering function in this module
# (render_displacy / render_pretty_table).
RENDER_WITH_DISPLACY = "displacy"
RENDER_WITH_PRETTY_TABLE = "pretty_table"

# All known render modes, in declaration order.
AVAILABLE_RENDER_MODES = [
    RENDER_WITH_DISPLACY,
    RENDER_WITH_PRETTY_TABLE,
]
|
|
|
|
|
# HTML template for a single rendered entity (one <mark> element per span slice).
# In addition to the visual placeholders ({bg}, {border_width}, {border_color},
# {text}, {label}) it exposes {entity_id}, {slice_idx} and {highlight_mode} as
# data-* attributes; inject_relation_data() and HIGHLIGHT_SPANS_JS look entities
# up via the "entity" class and the data-entity-id / data-slice-idx attributes.
# The label <span> comes AFTER {text}, so the element text starts with the span text.
TPL_ENT_WITH_ID = """
<mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" data-highlight-mode="{highlight_mode}" style="background: {bg}; border-width: {border_width}; border-color: {border_color}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
    {text}
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""
|
|
|
# JavaScript (an arrow function, kept as a string) that wires up hover/click
# behaviour for the rendered entities:
#   * every element with class "entity" gets mouseover/mouseout listeners
#     (guarded by the data-has-listener attribute so they are attached only once);
#   * on mouseover, the hovered entity, its relation tails and heads (read from the
#     data-relation-tails / data-relation-heads JSON attributes) and all remaining
#     entities are recolored from the data-color-selected / -tail / -head / -other
#     attributes, and the entity id is written to the "#hover_adu_id textarea";
#   * on mouseout, all entities are reset to data-color-original /
#     data-border-color-original;
#   * a click inside the ".entities" container copies the hovered id into the
#     "#selected_adu_id textarea" and dispatches an "input" event on it.
HIGHLIGHT_SPANS_JS = """
() => {
    function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
        var color = entity.getAttribute('data-color-' + colorAttributeKey);
        // if color is a json string, parse it and use the value at colorDictKey
        try {
            const colors = JSON.parse(color);
            color = colors[colorDictKey];
        } catch (e) {}
        if (color) {
            //const highlightMode = entity.getAttribute('data-highlight-mode');
            //if (highlightMode === 'fill') {
            entity.style.backgroundColor = color;
            entity.style.color = '#000';
            //}
            entity.style.borderColor = color;
        }
    }

    function highlightRelationArguments(entityId) {
        const entities = document.querySelectorAll('.entity');
        // reset all entities
        entities.forEach(entity => {
            const color = entity.getAttribute('data-color-original');
            entity.style.backgroundColor = color;
            const borderColor = entity.getAttribute('data-border-color-original');
            entity.style.borderColor = borderColor;
            entity.style.color = '';
        });

        if (entityId !== null) {
            var visitedEntities = new Set();
            // highlight selected entity
            // get all elements with attribute data-entity-id==entityId
            const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
            selectedEntityParts.forEach(selectedEntityPart => {
                const label = selectedEntityPart.getAttribute('data-label');
                maybeSetColor(selectedEntityPart, 'selected', label);
                visitedEntities.add(selectedEntityPart);
            }); // <-- Corrected closing parenthesis here
            // if there is at least one part, get the first one and ...
            if (selectedEntityParts.length > 0) {
                const selectedEntity = selectedEntityParts[0];

                // ... highlight tails and ...
                const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
                relationTailsAndLabels.forEach(relationTail => {
                    const tailEntityId = relationTail['entity-id'];
                    const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
                    tailEntityParts.forEach(tailEntity => {
                        const label = relationTail['label'];
                        maybeSetColor(tailEntity, 'tail', label);
                        visitedEntities.add(tailEntity);
                    }); // <-- Corrected closing parenthesis here
                }); // <-- Corrected closing parenthesis here
                // .. highlight heads
                const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
                relationHeadsAndLabels.forEach(relationHead => {
                    const headEntityId = relationHead['entity-id'];
                    const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
                    headEntityParts.forEach(headEntity => {
                        const label = relationHead['label'];
                        maybeSetColor(headEntity, 'head', label);
                        visitedEntities.add(headEntity);
                    }); // <-- Corrected closing parenthesis here
                }); // <-- Corrected closing parenthesis here
            }

            // highlight other entities
            entities.forEach(entity => {
                if (!visitedEntities.has(entity)) {
                    const label = entity.getAttribute('data-label');
                    maybeSetColor(entity, 'other', label);
                }
            });
        }
    }
    function setHoverAduId(entityId) {
        // get the textarea element that holds the reference adu id
        let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
        // set the value of the input field
        hoverAduIdDiv.value = entityId;
        // trigger an input event to update the state
        var event = new Event('input');
        hoverAduIdDiv.dispatchEvent(event);
    }
    function setReferenceAduIdFromHover() {
        // get the hover adu id
        const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
        // get the value of the input field
        const entityId = hoverAduIdDiv.value;
        // get the textarea element that holds the reference adu id
        let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
        // set the value of the input field
        referenceAduIdDiv.value = entityId;
        // trigger an input event to update the state
        var event = new Event('input');
        referenceAduIdDiv.dispatchEvent(event);
    }

    const entities = document.querySelectorAll('.entity');
    entities.forEach(entity => {
        // make the cursor a pointer
        entity.style.cursor = 'pointer';
        const alreadyHasListener = entity.getAttribute('data-has-listener');
        if (alreadyHasListener) {
            return;
        }
        entity.addEventListener('mouseover', () => {
            const entityId = entity.getAttribute('data-entity-id');
            highlightRelationArguments(entityId);
            setHoverAduId(entityId);
        });
        entity.addEventListener('mouseout', () => {
            highlightRelationArguments(null);
        });
        entity.setAttribute('data-has-listener', 'true');
    });
    const entityContainer = document.querySelector('.entities');
    if (entityContainer) {
        entityContainer.addEventListener('click', () => {
            setReferenceAduIdFromHover();
        });
        // make the cursor a pointer
        // entityContainer.style.cursor = 'pointer';
    }
}
"""
|
|
|
|
|
def render_pretty_table(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    **render_kwargs,
) -> str:
    """Render the binary relations as an HTML table with one row per relation.

    Args:
        text: Unused here; accepted so the signature parallels ``render_displacy``.
        spans: Unused here; accepted for the same reason.
        span_id2idx: Unused here; accepted for the same reason.
        binary_relations: Relations to list. Each one contributes a single row
            ``[str(relation.head), str(relation.tail), relation.label]``.
        **render_kwargs: Accepted and ignored.

    Returns:
        The table HTML wrapped in a size-limited, scrollable ``<div>``.
    """
    # Imported lazily so that prettytable is only required for this render mode.
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    # Exactly one row per relation. The former code iterated over
    # `list(binary_relations) + list(binary_relations)` and thereby emitted
    # every relation twice.
    for relation in binary_relations:
        table.add_row([str(relation.head), str(relation.tail), relation.label])

    table_html = table.get_html_string(format=True)
    return (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>" + table_html + "</div>"
    )
|
|
|
|
|
def render_displacy(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    highlight_span_ids: Optional[List[str]] = None,
    inject_relations=True,
    colors_hover=None,
    entity_options: Optional[Dict[str, Any]] = None,
    **render_kwargs,
):
    """Render the text with its spans as entity-annotated HTML.

    Args:
        text: The text the spans refer to.
        spans: The span annotations; indexed via the values of ``span_id2idx``.
        span_id2idx: Mapping from span id to the index of that span in ``spans``.
            Only spans listed here are rendered.
        binary_relations: Relations between the spans. Only used when
            ``inject_relations`` is true.
        highlight_span_ids: If ``None``, every entity gets highlight mode "fill".
            Otherwise only the listed ids get "fill" and all others get "border".
        inject_relations: Whether to post-process the HTML with
            ``inject_relation_data``.
        colors_hover: Passed to ``inject_relation_data`` as ``additional_colors``.
        entity_options: Options for the ``EntityRenderer``. The mapping is copied,
            never modified; its "template" entry is always replaced by
            ``TPL_ENT_WITH_ID``. ``None`` (the default) means no extra options.
        **render_kwargs: Accepted and ignored.

    Returns:
        The rendered HTML, wrapped in a size-limited, scrollable ``<div>``.

    Raises:
        ValueError: If a span is neither a ``LabeledSpan`` nor a ``LabeledMultiSpan``.
    """
    ents: List[Dict[str, Any]] = []
    for entity_id, idx in span_id2idx.items():
        labeled_span = spans[idx]
        highlight_mode = (
            "fill" if highlight_span_ids is None or entity_id in highlight_span_ids else "border"
        )

        # A plain span is handled as a multi-span with a single slice, so both
        # kinds share the entry construction below.
        if isinstance(labeled_span, LabeledSpan):
            slices = [(labeled_span.start, labeled_span.end)]
        elif isinstance(labeled_span, LabeledMultiSpan):
            slices = labeled_span.slices
        else:
            raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")

        for slice_idx, (start, end) in enumerate(slices):
            ents.append(
                {
                    "start": start,
                    "end": end,
                    "label": labeled_span.label,
                    # The keys match the extra placeholders of TPL_ENT_WITH_ID;
                    # presumably the renderer feeds them into the template.
                    "params": {
                        "entity_id": entity_id,
                        "slice_idx": slice_idx,
                        "highlight_mode": highlight_mode,
                    },
                }
            )

    spacy_doc = {
        "text": text,
        # Entities are handed over ordered by (start, end).
        "ents": sorted(ents, key=lambda ent: (ent["start"], ent["end"])),
        "title": None,
    }

    # Work on a private copy: the caller's mapping must not be modified, and a
    # `None` default avoids sharing one mutable dict object between calls.
    options: Dict[str, Any] = dict(entity_options) if entity_options else {}
    # Always use the template that carries the entity id / slice index.
    options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
    if inject_relations:
        html = inject_relation_data(
            html,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
|
|
|
|
|
def inject_relation_data(
    html: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Annotate every rendered entity element with color and relation metadata.

    Each element with class "entity" in ``html`` receives these attributes:

    * ``data-color-original`` / ``data-border-color-original``: the values of
      ``background:`` and ``border-color:`` taken from its inline style,
    * ``data-color-<key>`` for each entry of ``additional_colors`` (dict values
      are stored JSON-encoded, other values as they are),
    * ``data-label``: the label of the span annotation,
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for the relations in which the span is
      the head / the tail. Relation partners without an id in ``span_id2idx``
      are left out.

    Args:
        html: HTML containing the entity elements; each needs an inline style as
            well as ``data-entity-id`` (and, for multi-spans, ``data-slice-idx``).
        spans: The span annotations; indexed via the values of ``span_id2idx``.
        span_id2idx: Mapping from span id to the index of that span in ``spans``.
        binary_relations: Relations between the spans.
        additional_colors: Optional extra colors, keyed by attribute suffix.

    Returns:
        The modified HTML, serialized back to a string.

    Raises:
        ValueError: If an entity refers to a span that is neither a
            ``LabeledSpan`` nor a ``LabeledMultiSpan``.
    """
    from bs4 import BeautifulSoup

    def css_value(style: str, prop: str) -> str:
        # Text between `prop` and the next ";" of an inline style string.
        return style.split(prop)[1].split(";")[0].strip()

    def linked_entities_json(partners_and_labels) -> str:
        # JSON list of the partners that have a known entity id.
        linked = []
        for partner, label in partners_and_labels:
            if partner in id_by_annotation:
                linked.append({"entity-id": id_by_annotation[partner], "label": label})
        return json.dumps(linked)

    soup = BeautifulSoup(html, "html.parser")

    # Index the relations in both directions.
    tails_by_head = defaultdict(list)
    heads_by_tail = defaultdict(list)
    for rel in binary_relations:
        heads_by_tail[rel.tail].append((rel.head, rel.label))
        tails_by_head[rel.head].append((rel.tail, rel.label))

    # Reverse lookup: span annotation -> span id.
    id_by_annotation = {}
    for span_id, span_idx in span_id2idx.items():
        id_by_annotation[spans[span_idx]] = span_id

    for mark in soup.find_all(class_="entity"):
        # Remember the rendered colors so they can be restored later.
        background = css_value(mark["style"], "background:")
        border = css_value(mark["style"], "border-color:")
        mark["data-color-original"] = background
        mark["data-border-color-original"] = border

        if additional_colors is not None:
            for key, color in additional_colors.items():
                if isinstance(color, dict):
                    color = json.dumps(color)
                mark[f"data-color-{key}"] = color

        annotation = spans[span_id2idx[mark["data-entity-id"]]]

        # Sanity check: the element text should begin with the text of the span
        # (or of the span slice this element stands for).
        if isinstance(annotation, LabeledSpan):
            expected_text = annotation.resolve()[1]
        elif isinstance(annotation, LabeledMultiSpan):
            slice_idx = int(mark["data-slice-idx"])
            expected_text = annotation.resolve()[1][slice_idx]
        else:
            raise ValueError(f"Unsupported entity type: {type(annotation)}")
        # Newlines are dropped before comparing (presumably because the rendered
        # HTML no longer contains them). Only the prefix is compared because the
        # label follows the span text in TPL_ENT_WITH_ID.
        expected_prefix = expected_text.replace("\n", "")
        if not mark.text.startswith(expected_prefix):
            logger.warning(f"Entity text mismatch: {annotation} != {mark.text}")

        mark["data-label"] = annotation.label
        mark["data-relation-tails"] = linked_entities_json(tails_by_head.get(annotation, []))
        mark["data-relation-heads"] = linked_entities_json(heads_by_tail.get(annotation, []))

    return str(soup)
|
|