"""Streamlit demo: link entities in a sentence with ReFinED and highlight them.

The user enters a sentence in a form. On submit, the model selected in the
sidebar links mentions to Wikidata. Each linked mention is rendered inline
with ``annotated_text``, and its description is shown as the label.
"""
import streamlit as st
from annotated_text import annotated_text
from refined.inference.processor import Refined

# Background colour used for every entity highlight.
ANNOTATION_COLOR = "#8ef"

# Sidebar
st.sidebar.image("logo-wordlift.png")

# A tuple (not a set) so the selectbox order and default are stable across
# runs; iteration order of a set of strings varies with hash randomisation.
model_options = ("aida_model", "wikipedia_model_with_numbers")
selected_model_name = st.sidebar.selectbox("Select the Model", list(model_options))


@st.cache_resource  # load each model once per process instead of on every rerun
def load_model(model_name):
    """Load the pretrained ReFinED model ``model_name`` using the Wikipedia entity set."""
    refined_model = Refined.from_pretrained(model_name=model_name, entity_set="wikipedia")
    return refined_model


# Use the cached model
refined_model = load_model(selected_model_name)


# Helper functions
def get_wikidata_id(entity_string):
    """Return a Wikidata URL built from the part after ``=`` in ``entity_string``.

    Expects a ``key=ID`` fragment taken from a span's string form; raises
    IndexError if ``entity_string`` contains no ``=``.
    """
    entity_list = entity_string.split("=")
    return "https://www.wikidata.org/wiki/" + str(entity_list[1])


def get_entity_description(entity_link, combined_entity_info_dictionary):
    """Return the description stored (at index 1) for ``entity_link``."""
    return combined_entity_info_dictionary[entity_link][1]


def _parse_entities(entities):
    """Map Wikidata URL -> [mention text, description] from ReFinED output.

    Fields are recovered by splitting ``str(entity)`` on ``", "``.
    NOTE(review): this depends on the library's repr format; reading the span's
    attributes directly would be sturdier. Confirm the attribute names first.
    Later spans with the same Wikidata URL overwrite earlier ones, and the
    position of the first occurrence is kept (dict insertion order).
    """
    combined = {}
    for entity in entities:
        fields = str(entity).strip("][").replace("'", "").split(", ")
        # Index 2 is read below, so at least three fields are required.
        if len(fields) < 3 or "wikidata" not in fields[1]:
            continue
        link = get_wikidata_id(fields[1]).strip()
        description = fields[2].strip().replace("(", "").replace(")", "")
        combined[link] = [fields[0].strip(), description]
    return combined


def _annotate(text, combined_entity_info_dictionary, color=ANNOTATION_COLOR):
    """Split ``text`` into plain strings and ``(mention, description, color)`` tuples.

    For each linked entity, the first occurrence of its mention text that is
    still plain text is replaced. Text already turned into an annotation is
    never searched again. The list is built directly instead of round-tripping
    through ``"{...}"`` markup and ``eval()``. That approach executed any
    brace-delimited expression a user typed.
    """
    segments = [text]
    for link, (mention, _) in combined_entity_info_dictionary.items():
        if not mention:
            continue
        description = get_entity_description(link, combined_entity_info_dictionary)
        for index, segment in enumerate(segments):
            if isinstance(segment, str) and mention in segment:
                before, _, after = segment.partition(mention)
                segments[index:index + 1] = [before, (mention, description, color), after]
                break  # list was mutated; stop iterating it
    # Drop empty plain-text pieces left by matches at the edges.
    return [segment for segment in segments if not isinstance(segment, str) or segment]


# Create the form
with st.form(key='my_form'):
    text_input = st.text_input(label='Enter a sentence')
    submit_button = st.form_submit_button(label='Submit')

# Run the model only on an explicit submit with non-empty text. Submitting an
# empty form used to raise NameError, and reruns re-processed the text for nothing.
if submit_button and text_input:
    entities = refined_model.process_text(text_input)
    combined_entity_info_dictionary = _parse_entities(entities)
    entities_map = {link: info[0] for link, info in combined_entity_info_dictionary.items()}
    final_text = _annotate(text_input, combined_entity_info_dictionary)
    annotated_text(*final_text)