import json
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Union

from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan

from .rendering_utils_displacy import EntityRenderer

logger = logging.getLogger(__name__)

# Identifiers of the supported render modes.
RENDER_WITH_DISPLACY = "displacy"
RENDER_WITH_PRETTY_TABLE = "pretty_table"
AVAILABLE_RENDER_MODES = [RENDER_WITH_DISPLACY, RENDER_WITH_PRETTY_TABLE]

# adjusted from rendering_utils_displacy.TPL_ENT
# NOTE(review): as visible here the template contains only the {text} and {label}
# placeholders and no markup. inject_relation_data() looks for elements with class
# "entity" and reads their "style" (expecting "background:" and "border-color:"),
# "data-entity-id" and "data-slice-idx" attributes, so the template is expected to
# emit those -- confirm against the original file that no markup was lost.
TPL_ENT_WITH_ID = """
{text}
{label}
"""

# JavaScript arrow function (kept as a string). It attaches mouseover/mouseout
# listeners to all elements with class "entity": on hover, the hovered entity, the
# heads and tails of its relations (read from the data-relation-heads /
# data-relation-tails attributes) and all other entities are re-colored using the
# data-color-* attributes, and the hovered entity id is written into the
# "#hover_adu_id textarea" element. A click on the ".entities" container copies that
# value into "#selected_adu_id textarea". Both writes dispatch an "input" event.
HIGHLIGHT_SPANS_JS = """
() => {
    function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
        var color = entity.getAttribute('data-color-' + colorAttributeKey);
        // if color is a json string, parse it and use the value at colorDictKey
        try {
            const colors = JSON.parse(color);
            color = colors[colorDictKey];
        } catch (e) {}
        if (color) {
            //const highlightMode = entity.getAttribute('data-highlight-mode');
            //if (highlightMode === 'fill') {
            entity.style.backgroundColor = color;
            entity.style.color = '#000';
            //}
            entity.style.borderColor = color;
        }
    }

    function highlightRelationArguments(entityId) {
        const entities = document.querySelectorAll('.entity');
        // reset all entities
        entities.forEach(entity => {
            const color = entity.getAttribute('data-color-original');
            entity.style.backgroundColor = color;
            const borderColor = entity.getAttribute('data-border-color-original');
            entity.style.borderColor = borderColor;
            entity.style.color = '';
        });

        if (entityId !== null) {
            var visitedEntities = new Set();
            // highlight selected entity
            // get all elements with attribute data-entity-id==entityId
            const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
            selectedEntityParts.forEach(selectedEntityPart => {
                const label = selectedEntityPart.getAttribute('data-label');
                maybeSetColor(selectedEntityPart, 'selected', label);
                visitedEntities.add(selectedEntityPart);
            }); // <-- Corrected closing parenthesis here
            // if there is at least one part, get the first one and ...
            if (selectedEntityParts.length > 0) {
                const selectedEntity = selectedEntityParts[0];

                // ... highlight tails and ...
                const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
                relationTailsAndLabels.forEach(relationTail => {
                    const tailEntityId = relationTail['entity-id'];
                    const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
                    tailEntityParts.forEach(tailEntity => {
                        const label = relationTail['label'];
                        maybeSetColor(tailEntity, 'tail', label);
                        visitedEntities.add(tailEntity);
                    }); // <-- Corrected closing parenthesis here
                }); // <-- Corrected closing parenthesis here
                // .. highlight heads
                const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
                relationHeadsAndLabels.forEach(relationHead => {
                    const headEntityId = relationHead['entity-id'];
                    const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
                    headEntityParts.forEach(headEntity => {
                        const label = relationHead['label'];
                        maybeSetColor(headEntity, 'head', label);
                        visitedEntities.add(headEntity);
                    }); // <-- Corrected closing parenthesis here
                }); // <-- Corrected closing parenthesis here
            }

            // highlight other entities
            entities.forEach(entity => {
                if (!visitedEntities.has(entity)) {
                    const label = entity.getAttribute('data-label');
                    maybeSetColor(entity, 'other', label);
                }
            });
        }
    }

    function setHoverAduId(entityId) {
        // get the textarea element that holds the reference adu id
        let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
        // set the value of the input field
        hoverAduIdDiv.value = entityId;
        // trigger an input event to update the state
        var event = new Event('input');
        hoverAduIdDiv.dispatchEvent(event);
    }

    function setReferenceAduIdFromHover() {
        // get the hover adu id
        const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
        // get the value of the input field
        const entityId = hoverAduIdDiv.value;
        // get the textarea element that holds the reference adu id
        let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
        // set the value of the input field
        referenceAduIdDiv.value = entityId;
        // trigger an input event to update the state
        var event = new Event('input');
        referenceAduIdDiv.dispatchEvent(event);
    }

    const entities = document.querySelectorAll('.entity');
    entities.forEach(entity => {
        // make the cursor a pointer
        entity.style.cursor = 'pointer';
        const alreadyHasListener = entity.getAttribute('data-has-listener');
        if (alreadyHasListener) {
            return;
        }
        entity.addEventListener('mouseover', () => {
            const entityId = entity.getAttribute('data-entity-id');
            highlightRelationArguments(entityId);
            setHoverAduId(entityId);
        });
        entity.addEventListener('mouseout', () => {
            highlightRelationArguments(null);
        });
        entity.setAttribute('data-has-listener', 'true');
    });
    const entityContainer = document.querySelector('.entities');
    if (entityContainer) {
        entityContainer.addEventListener('click', () => {
            setReferenceAduIdFromHover();
        });
        // make the cursor a pointer
        // entityContainer.style.cursor = 'pointer';
    }
}
"""


def render_pretty_table(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    **render_kwargs,
) -> str:
    """Render the binary relations as an HTML table with one row per relation.

    Args:
        text: Unused; accepted so that the signature matches render_displacy().
        spans: Unused; accepted so that the signature matches render_displacy().
        span_id2idx: Unused; accepted so that the signature matches render_displacy().
        binary_relations: The relations to list. Each row holds str(head), str(tail)
            and the relation label.
        **render_kwargs: Ignored.

    Returns:
        The HTML string produced by PrettyTable, wrapped in the prefix/suffix below.
    """
    from prettytable import PrettyTable

    t = PrettyTable()
    t.field_names = ["head", "tail", "relation"]
    t.align = "l"
    # Each relation is added exactly once (the sequence must not be concatenated
    # with itself, which would duplicate every row).
    for relation in binary_relations:
        t.add_row([str(relation.head), str(relation.tail), relation.label])

    html = t.get_html_string(format=True)
    # NOTE(review): the wrapper strings show up as bare line breaks in this view;
    # confirm against the original file whether wrapping markup was lost.
    html = "\n" + html + "\n"

    return html


def render_displacy(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    highlight_span_ids: Optional[List[str]] = None,
    inject_relations=True,
    colors_hover=None,
    entity_options: Optional[Dict[str, Any]] = None,
    **render_kwargs,
) -> str:
    """Render the text with its labeled spans as HTML via EntityRenderer.

    Args:
        text: The document text.
        spans: The span annotations, indexed by the values of span_id2idx.
        span_id2idx: Maps each span id to the index of its annotation in spans.
            Only spans listed here are rendered.
        binary_relations: Relations between the spans; only used when
            inject_relations is true.
        highlight_span_ids: If given, spans whose id is in this list get the
            highlight mode "fill" and all others "border". If None, all get "fill".
        inject_relations: Whether to post-process the HTML with
            inject_relation_data().
        colors_hover: Passed to inject_relation_data() as additional_colors.
        entity_options: Options for EntityRenderer. The mapping is copied, so the
            caller's object is never modified. Its "template" entry is always
            replaced by TPL_ENT_WITH_ID.
        **render_kwargs: Ignored.

    Returns:
        The rendered HTML string.

    Raises:
        ValueError: If a span is neither a LabeledSpan nor a LabeledMultiSpan.
    """
    ents: List[Dict[str, Any]] = []
    for entity_id, idx in span_id2idx.items():
        labeled_span = spans[idx]
        highlight_mode = (
            "fill" if highlight_span_ids is None or entity_id in highlight_span_ids else "border"
        )
        # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
        # on hover and to inject the relation data.
        if isinstance(labeled_span, LabeledSpan):
            ents.append(
                {
                    "start": labeled_span.start,
                    "end": labeled_span.end,
                    "label": labeled_span.label,
                    "params": {
                        "entity_id": entity_id,
                        "slice_idx": 0,
                        "highlight_mode": highlight_mode,
                    },
                }
            )
        elif isinstance(labeled_span, LabeledMultiSpan):
            # A multi-span yields one entry per slice; all share the same entity id
            # and are told apart by slice_idx.
            for i, (start, end) in enumerate(labeled_span.slices):
                ents.append(
                    {
                        "start": start,
                        "end": end,
                        "label": labeled_span.label,
                        "params": {
                            "entity_id": entity_id,
                            "slice_idx": i,
                            "highlight_mode": highlight_mode,
                        },
                    }
                )
        else:
            raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")

    ents_sorted = sorted(ents, key=lambda x: (x["start"], x["end"]))
    spacy_doc = {
        "text": text,
        # the ents MUST be sorted by start and end
        "ents": ents_sorted,
        "title": None,
    }

    # copy to avoid modifying the original options (None means "no options")
    entity_options = dict(entity_options) if entity_options is not None else {}
    # use the custom template with the entity ID
    entity_options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    # NOTE(review): the wrapper strings show up as bare line breaks in this view;
    # confirm against the original file whether wrapping markup was lost.
    html = "\n" + html + "\n"
    if inject_relations:
        html = inject_relation_data(
            html,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html


def inject_relation_data(
    html: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Annotate every element with class "entity" in the HTML with data attributes.

    For each such element the following attributes are set:
    data-color-original and data-border-color-original (parsed from the inline
    style), data-color-<key> for each entry of additional_colors (dict values are
    JSON-encoded), data-label, and data-relation-tails / data-relation-heads (JSON
    lists of {"entity-id", "label"} objects for the relations in which the entity is
    the head / the tail, respectively).

    Args:
        html: The HTML to process. Each "entity" element must carry a
            "data-entity-id" attribute and a "style" attribute containing
            "background:" and "border-color:"; elements belonging to a
            LabeledMultiSpan also need "data-slice-idx".
        spans: The span annotations, indexed by the values of span_id2idx.
        span_id2idx: Maps each span id to the index of its annotation in spans.
        binary_relations: The relations to inject. Arguments that are not among
            the spans referenced by span_id2idx are skipped.
        additional_colors: Optional mapping from a color key to either a color
            string or a dict of colors.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If an entity annotation is neither a LabeledSpan nor a
            LabeledMultiSpan.
    """
    from bs4 import BeautifulSoup

    # Parse the HTML using BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")

    # Index the relations by their head (-> tails) and by their tail (-> heads).
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    annotation2id = {spans[span_idx]: span_id for span_id, span_idx in span_id2idx.items()}
    # Add unique IDs to each entity
    entities = soup.find_all(class_="entity")
    for entity in entities:
        # Remember the initial colors so that they can be restored later.
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        original_border_color = entity["style"].split("border-color:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        entity["data-border-color-original"] = original_border_color
        if additional_colors is not None:
            for key, color in additional_colors.items():
                entity[f"data-color-{key}"] = (
                    json.dumps(color) if isinstance(color, dict) else color
                )

        entity_annotation = spans[span_id2idx[entity["data-entity-id"]]]

        # sanity check.
        # The second element of resolve() is used as the annotation text (for a
        # multi-span it is indexed by slice).
        if isinstance(entity_annotation, LabeledSpan):
            annotation_text = entity_annotation.resolve()[1]
        elif isinstance(entity_annotation, LabeledMultiSpan):
            slice_idx = int(entity["data-slice-idx"])
            annotation_text = entity_annotation.resolve()[1][slice_idx]
        else:
            raise ValueError(f"Unsupported entity type: {type(entity_annotation)}")
        annotation_text_without_newline = annotation_text.replace("\n", "")
        # Just check the start, because the text has the label attached to the end
        if not entity.text.startswith(annotation_text_without_newline):
            logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": annotation2id[tail], "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
                if tail in annotation2id
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": annotation2id[head], "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
                if head in annotation2id
            ]
        )

    # Return the modified HTML as a string
    return str(soup)