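"""Streamlit demo: entity linking with ReFinED.

Loads a pretrained ReFinED model, extracts entities from a user-supplied sentence,
links them to Wikidata, and renders the annotated sentence with annotated_text.
"""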
import ast

import streamlit as st
from annotated_text import annotated_text
from refined.inference.processor import Refined

# Sidebar
st.sidebar.image("logo-wordlift.png")

# Model selection
model_options = ["aida_model", "wikipedia_model_with_numbers"]  # a list keeps the selectbox order stable
selected_model_name = st.sidebar.selectbox("Select the Model", model_options)

@st.cache_resource  # cache the loaded model so it is built only once per model name
def load_model(model_name):
    # Load the pretrained model
    refined_model = Refined.from_pretrained(model_name=model_name, entity_set="wikipedia")
    return refined_model

# Use the cached model
refined_model = load_model(selected_model_name)

# Helper functions
def get_wikidata_id(entity_string):
    # entity_string has the form "<key>=<Wikidata ID>"; keep the part after "="
    entity_list = entity_string.split("=")
    return "https://www.wikidata.org/wiki/" + str(entity_list[1])

# Create the form
with st.form(key='my_form'):
    text_input = st.text_input(label='Enter a sentence')
    submit_button = st.form_submit_button(label='Submit')

# Process the text and extract the entities
if text_input:
    entities = refined_model.process_text(text_input)
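    # `entities` is the list of entity spans ReFinED detected in the sentence;
    # build two lookups from it: Wikidata URL -> mention text, and Wikidata URL -> description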

    entities_map = {}
    entities_link_descriptions = {}
    for entity in entities:
        # Parse the span's string representation into its comma-separated parts:
        # [mention text, wikidata id, (description), ...]
        single_entity_list = str(entity).strip('][').replace("\'", "").split(', ')
        if len(single_entity_list) >= 3 and "wikidata" in single_entity_list[1]:
            wikidata_url = get_wikidata_id(single_entity_list[1]).strip()
            entities_map[wikidata_url] = single_entity_list[0].strip()
            entities_link_descriptions[wikidata_url] = single_entity_list[2].strip().replace("(", "").replace(")", "")

    # Map each Wikidata URL to [mention text, description]
    combined_entity_info_dictionary = {k: [entities_map[k], entities_link_descriptions[k]] for k in entities_map}

    def get_entity_description(entity_link, combined_entity_info_dictionary):
        return combined_entity_info_dictionary[entity_link][1]

    if submit_button:
        # Prepare a list to hold the final output
        final_text = []

        # Replace each entity in the text with its annotated version
        for wikidata_link, entity in entities_map.items():
            description = get_entity_description(wikidata_link, combined_entity_info_dictionary)
            entity_annotation = (entity, description, "#8ef")
            text_input = text_input.replace(entity, f'{{{str(entity_annotation)}}}', 1)
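        # After the loop, text_input contains placeholders like
        # {('mention', 'description', '#8ef')} wherever an entity was recognized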

        # Split the modified text_input into a list
        text_list = text_input.split("{")

        for item in text_list:
            if "}" in item:
                item_list = item.split("}")
                # The part before "}" is the stringified (text, description, color) tuple;
                # parse it back into a tuple for annotated_text
                final_text.append(ast.literal_eval(item_list[0]))
                if len(item_list[1]) > 0:
                    final_text.append(item_list[1])
            else:
                final_text.append(item)

        # Pass the final_text to the annotated_text function
        annotated_text(*final_text)
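
# To run locally (assumes streamlit, st-annotated-text and ReFinED are installed):
#   streamlit run <this_file>.py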