import json import logging from collections import defaultdict from typing import Any, Dict, List, Optional, Sequence, Union from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan from .rendering_utils_displacy import EntityRenderer logger = logging.getLogger(__name__) RENDER_WITH_DISPLACY = "displacy" RENDER_WITH_PRETTY_TABLE = "pretty_table" AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE] # adjusted from rendering_utils_displacy.TPL_ENT TPL_ENT_WITH_ID = """ {text} {label} """ HIGHLIGHT_SPANS_JS = """ () => { function maybeSetColor(entity, colorAttributeKey, colorDictKey) { var color = entity.getAttribute('data-color-' + colorAttributeKey); // if color is a json string, parse it and use the value at colorDictKey try { const colors = JSON.parse(color); color = colors[colorDictKey]; } catch (e) {} if (color) { //const highlightMode = entity.getAttribute('data-highlight-mode'); //if (highlightMode === 'fill') { entity.style.backgroundColor = color; entity.style.color = '#000'; //} entity.style.borderColor = color; } } function highlightRelationArguments(entityId) { const entities = document.querySelectorAll('.entity'); // reset all entities entities.forEach(entity => { const color = entity.getAttribute('data-color-original'); entity.style.backgroundColor = color; const borderColor = entity.getAttribute('data-border-color-original'); entity.style.borderColor = borderColor; entity.style.color = ''; }); if (entityId !== null) { var visitedEntities = new Set(); // highlight selected entity // get all elements with attribute data-entity-id==entityId const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`); selectedEntityParts.forEach(selectedEntityPart => { const label = selectedEntityPart.getAttribute('data-label'); maybeSetColor(selectedEntityPart, 'selected', label); visitedEntities.add(selectedEntityPart); }); // <-- Corrected closing parenthesis here // if there is at least one part, get the first one and ... if (selectedEntityParts.length > 0) { const selectedEntity = selectedEntityParts[0]; // ... highlight tails and ... const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails')); relationTailsAndLabels.forEach(relationTail => { const tailEntityId = relationTail['entity-id']; const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`); tailEntityParts.forEach(tailEntity => { const label = relationTail['label']; maybeSetColor(tailEntity, 'tail', label); visitedEntities.add(tailEntity); }); // <-- Corrected closing parenthesis here }); // <-- Corrected closing parenthesis here // .. highlight heads const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads')); relationHeadsAndLabels.forEach(relationHead => { const headEntityId = relationHead['entity-id']; const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`); headEntityParts.forEach(headEntity => { const label = relationHead['label']; maybeSetColor(headEntity, 'head', label); visitedEntities.add(headEntity); }); // <-- Corrected closing parenthesis here }); // <-- Corrected closing parenthesis here } // highlight other entities entities.forEach(entity => { if (!visitedEntities.has(entity)) { const label = entity.getAttribute('data-label'); maybeSetColor(entity, 'other', label); } }); } } function setHoverAduId(entityId) { // get the textarea element that holds the reference adu id let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea'); // set the value of the input field hoverAduIdDiv.value = entityId; // trigger an input event to update the state var event = new Event('input'); hoverAduIdDiv.dispatchEvent(event); } function setReferenceAduIdFromHover() { // get the hover adu id const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea'); // get the value of the input field const entityId = hoverAduIdDiv.value; // get the textarea element that holds the reference adu id let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea'); // set the value of the input field referenceAduIdDiv.value = entityId; // trigger an input event to update the state var event = new Event('input'); referenceAduIdDiv.dispatchEvent(event); } const entities = document.querySelectorAll('.entity'); entities.forEach(entity => { // make the cursor a pointer entity.style.cursor = 'pointer'; const alreadyHasListener = entity.getAttribute('data-has-listener'); if (alreadyHasListener) { return; } entity.addEventListener('mouseover', () => { const entityId = entity.getAttribute('data-entity-id'); highlightRelationArguments(entityId); setHoverAduId(entityId); }); entity.addEventListener('mouseout', () => { highlightRelationArguments(null); }); entity.setAttribute('data-has-listener', 'true'); }); const entityContainer = document.querySelector('.entities'); if (entityContainer) { entityContainer.addEventListener('click', () => { setReferenceAduIdFromHover(); }); // make the cursor a pointer // entityContainer.style.cursor = 'pointer'; } } """ def render_pretty_table( text: str, spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]], span_id2idx: Dict[str, int], binary_relations: Sequence[BinaryRelation], **render_kwargs, ): from prettytable import PrettyTable t = PrettyTable() t.field_names = ["head", "tail", "relation"] t.align = "l" for relation in list(binary_relations) + list(binary_relations): t.add_row([str(relation.head), str(relation.tail), relation.label]) html = t.get_html_string(format=True) html = "