"""
The SpEL annotation visualization script. You can use this script as a playground to explore the capabilities and
limitations of the SpEL framework.
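
To try it locally (assuming Streamlit and the SpEL dependencies are installed), run:

    streamlit run app.py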
"""
import torch
from model import SpELAnnotator
from data_loader import dl_sa
from utils import chunk_annotate_and_merge_to_phrase
from candidate_manager import CandidateManager
import streamlit as st
from annotated_text import annotated_text
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
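

# load_model is executed once and its result is cached by Streamlit across reruns of the script.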
@st.cache_resource
def load_model():
    # Configuration: load the AIDA-finetuned checkpoint, keep the full mention vocabulary,
    # and use no hand-crafted candidate set ("n").
    load_aida_finetuned = True
    load_full_vocabulary = True
    candidate_setting = "n"
    model = SpELAnnotator()
    model.init_model_from_scratch(device=device)
    # A CandidateManager is only instantiated when a candidate setting other than "n" is requested.
    candidates_manager_to_use = CandidateManager(dl_sa.mentions_vocab,
                                                 is_kb_yago=candidate_setting == "k",
                                                 is_ppr_for_ned=candidate_setting.startswith("p"),
                                                 is_context_agnostic=candidate_setting == "pg",
                                                 is_indexed_for_spans=True) if candidate_setting != "n" else None
    # Fetch the checkpoint matching the configuration above from torch hub.
    if load_aida_finetuned and not load_full_vocabulary:
        model.shrink_classification_head_to_aida(device=device)
        model.load_checkpoint(None, device=device, load_from_torch_hub=True, finetuned_after_step=3)
    elif load_aida_finetuned:
        model.load_checkpoint(None, device=device, load_from_torch_hub=True, finetuned_after_step=4)
    else:
        model.load_checkpoint(None, device=device, load_from_torch_hub=True, finetuned_after_step=2)
    return model, candidates_manager_to_use

annotator, candidates_manager = load_model()
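
# Streamlit page layout: title, caption, a free-text input box, and an "Annotate" button.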
st.title("SpEL Prediction Visualization")
st.caption('Running the "[SpEL-base-step3-500K.pt](https://vault.sfu.ca/index.php/s/8nw5fFXdz2yBP5z/download)" model without any hand-crafted candidate sets. For more information, please check out [SpEL\'s GitHub repository](https://github.com/shavarani/SpEL).')
mention = st.text_input("Enter the text:")
process_button = st.button("Annotate")
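
# Annotate only when the user clicks the button and has provided some text.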
if process_button and mention:
    phrase_annotations = chunk_annotate_and_merge_to_phrase(
        annotator, mention, k_for_top_k_to_keep=5, normalize_for_chinese_characters=True)
    # Keep (start character offset, end character offset, (entity, subword annotations)) for every
    # phrase that resolved to an entity (resolved_annotation == 0 is the null/"O" label).
    last_step_annotations = [[p.words[0].token_offsets[0][1][0],
                              p.words[-1].token_offsets[-1][1][-1],
                              (dl_sa.mentions_itos[p.resolved_annotation], p.subword_annotations)]
                             for p in phrase_annotations if p.resolved_annotation != 0]
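    # With candidate_setting == "n" the candidate manager is None, so this post-processing step is skipped.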
    if candidates_manager:
        for p in phrase_annotations:
            candidates_manager.modify_phrase_annotation_using_candidates(p, mention)
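    # annotated_text expects a mixed list: plain strings for unannotated spans and
    # (surface text, entity) tuples for annotated spans; the loop below builds that
    # list by walking the input left to right over the sorted character offsets.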
    if last_step_annotations:
        anns = sorted([(l_ann[0], l_ann[1], l_ann[2][0]) for l_ann in last_step_annotations], key=lambda x: x[0])
        begin = 0
        last_char = len(mention)
        anns_pointer = 0
        processed_anns = []
        anno_text = []
        while begin < last_char:
            if anns_pointer == len(anns):
                # No annotations left: emit the rest of the text as an unannotated span.
                processed_anns.append((begin, last_char, "O"))
                anno_text.append(mention[begin: last_char])
                begin = last_char
                continue
            first_unprocessed_annotation = anns[anns_pointer]
            if first_unprocessed_annotation[0] > begin:
                # Emit the unannotated gap before the next annotation.
                processed_anns.append((begin, first_unprocessed_annotation[0], "O"))
                anno_text.append(mention[begin: first_unprocessed_annotation[0]])
                begin = first_unprocessed_annotation[0]
            else:
                # Emit the annotated span as a (text, entity) tuple.
                processed_anns.append(first_unprocessed_annotation)
                anns_pointer += 1
                begin = first_unprocessed_annotation[1]
                anno_text.append((mention[first_unprocessed_annotation[0]: first_unprocessed_annotation[1]],
                                  first_unprocessed_annotation[2]))
        annotated_text(anno_text)
    else:
        annotated_text(mention)