"""
The SpEL annotation visualization script. You can use this script as a playground to explore the capabilities and
limitations of the SpEL framework.
"""
import torch

from model import SpELAnnotator
from data_loader import dl_sa
from utils import chunk_annotate_and_merge_to_phrase
from candidate_manager import CandidateManager

import streamlit as st
from annotated_text import annotated_text

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@st.cache_resource
def load_model():
    # Configuration flags: load the AIDA-finetuned checkpoint and keep the full entity vocabulary.
    load_aida_finetuned = True
    load_full_vocabulary = True
    # Candidate-set setting: "n" disables candidate filtering, "k" uses the KB+Yago candidate sets,
    # and settings starting with "p" use the PPR-for-NED candidates ("pg" being the context-agnostic variant).
    candidate_setting = "n"
    model = SpELAnnotator()
    model.init_model_from_scratch(device=device)
    candidates_manager_to_use = CandidateManager(dl_sa.mentions_vocab,
                                                 is_kb_yago=candidate_setting == "k",
                                                 is_ppr_for_ned=candidate_setting.startswith("p"),
                                                 is_context_agnostic=candidate_setting == "pg",
                                                 is_indexed_for_spans=True) if candidate_setting != "n" else None
    if load_aida_finetuned and not load_full_vocabulary:
        # AIDA-finetuned checkpoint restricted to the AIDA entity vocabulary.
        model.shrink_classification_head_to_aida(device=device)
        model.load_checkpoint(None, device=device, load_from_torch_hub=True, finetuned_after_step=3)
    elif load_aida_finetuned:
        # AIDA-finetuned checkpoint over the full entity vocabulary.
        model.load_checkpoint(None, device=device, load_from_torch_hub=True, finetuned_after_step=4)
    else:
        # General-purpose checkpoint (not finetuned on AIDA).
        model.load_checkpoint(None, device=device, load_from_torch_hub=True, finetuned_after_step=2)
    return model, candidates_manager_to_use
annotator, candidates_manager = load_model()
st.title("SpEL Prediction Visualization")
mention = st.text_input("Enter the text:")
process_button = st.button("Annotate")
if process_button and mention:
    phrase_annotations = chunk_annotate_and_merge_to_phrase(
        annotator, mention, k_for_top_k_to_keep=5, normalize_for_chinese_characters=True)
    # Apply candidate-set corrections (when a candidate manager is configured) before the spans are
    # extracted below, so that the corrections are reflected in the rendered annotations.
    if candidates_manager:
        for p in phrase_annotations:
            candidates_manager.modify_phrase_annotation_using_candidates(p, mention)
    # Keep only the phrases resolved to an actual entity (annotation id 0 is treated as no annotation),
    # each as [start_char, end_char, (entity_label, subword_annotations)].
    last_step_annotations = [[p.words[0].token_offsets[0][1][0],
                              p.words[-1].token_offsets[-1][1][-1],
                              (dl_sa.mentions_itos[p.resolved_annotation], p.subword_annotations)]
                             for p in phrase_annotations if p.resolved_annotation != 0]
    if last_step_annotations:
        # Sort the predicted spans by start character offset: (start_char, end_char, entity_label).
        anns = sorted([(l_ann[0], l_ann[1], l_ann[2][0]) for l_ann in last_step_annotations], key=lambda x: x[0])
        begin = 0
        last_char = len(mention)
        anns_pointer = 0
        processed_anns = []
        anno_text = []
        # Walk through the input text, alternating between unannotated gaps (labelled "O") and predicted
        # spans, building the mix of plain strings and (text, label) tuples that annotated_text expects.
        while begin < last_char:
            if anns_pointer == len(anns):
                # No predicted spans left; the remainder of the text is unannotated.
                processed_anns.append((begin, last_char, "O"))
                anno_text.append(mention[begin: last_char])
                begin = last_char
                continue
            first_unprocessed_annotation = anns[anns_pointer]
            if first_unprocessed_annotation[0] > begin:
                # There is an unannotated gap before the next predicted span.
                processed_anns.append((begin, first_unprocessed_annotation[0], "O"))
                anno_text.append(mention[begin: first_unprocessed_annotation[0]])
                begin = first_unprocessed_annotation[0]
            else:
                # Emit the predicted span together with its entity label.
                processed_anns.append(first_unprocessed_annotation)
                anns_pointer += 1
                begin = first_unprocessed_annotation[1]
                anno_text.append((mention[first_unprocessed_annotation[0]: first_unprocessed_annotation[1]],
                                  first_unprocessed_annotation[2]))
        annotated_text(anno_text)
    else:
        # Nothing was annotated; render the input text as-is.
        annotated_text(mention)