|
import os |
|
import re |
|
import time |
|
from pathlib import Path |
|
|
|
import requests |
|
import streamlit as st |
|
from spacy import displacy |
|
from streamlit_extras.badges import badge |
|
from streamlit_extras.stylable_container import stylable_container |
|
|
|
|
|
|
|
import random |
|
|
|
from relik.inference.annotator import Relik |
|
|
|
|
|
def get_random_color(ents):
    """Map every entry of *ents* to a distinct pastel colour chosen at random.

    A palette with exactly ``len(ents)`` colours is generated, and each entry
    then takes one colour drawn (and removed) from a random position, so no two
    entries share a colour.

    Args:
        ents: A sized iterable of hashable labels (e.g. a set of entity labels).

    Returns:
        A dict mapping each label to a ``#rrggbb`` colour string.
    """
    palette = generate_pastel_colors(len(ents))
    # Drawing without replacement keeps the colours unique per label.
    return {
        label: palette.pop(random.randint(0, len(palette) - 1)) for label in ents
    }
|
|
|
|
|
def floatrange(start, stop, steps):
    """Return ``steps`` evenly spaced values from *start* to *stop* inclusive.

    Args:
        start: First value of the sequence.
        stop: Last value of the sequence.
        steps: Number of values to produce (an ``int``). When it equals 1 the
            single value returned is *stop*; ``0`` yields an empty list.

    Returns:
        A list of floats (or ``[stop]`` for a single step).
    """
    if int(steps) == 1:
        return [stop]
    points = []
    for index in range(steps):
        offset = float(index) * (stop - start) / (float(steps) - 1)
        points.append(start + offset)
    return points
|
|
|
|
|
def hsl_to_rgb(h, s, l):
    """Convert an HSL colour to an ``(r, g, b)`` tuple of rounded integers.

    Args:
        h: Hue as a fraction of a full turn; values outside ``[0, 1]`` are
            shifted by whole turns back into that interval.
        s: Saturation; ``0.0`` produces a grey whose level depends only on *l*.
        l: Lightness.

    Returns:
        A 3-tuple ``(red, green, blue)``. For *s* and *l* in ``[0, 1]`` every
        channel lies in ``0..255``.
    """

    def _hue_to_channel(low, high, hue):
        # Shift the hue by whole turns until it lies within [0, 1].
        while hue < 0.0:
            hue += 1.0
        while hue > 1.0:
            hue -= 1.0
        # Piecewise-linear ramp: rise during the first sixth, plateau at
        # ``high`` until one half, fall until two thirds, then stay at ``low``.
        if 6 * hue < 1.0:
            return low + (high - low) * 6.0 * hue
        if 2 * hue < 1.0:
            return high
        if 3 * hue < 2.0:
            return low + (high - low) * ((2.0 / 3.0) - hue) * 6.0
        return low

    if s == 0.0:
        # Achromatic colour: all three channels carry the same grey level.
        red = green = blue = l * 255
    else:
        high = l * (1.0 + s) if l < 0.5 else (l + s) - (s * l)
        low = 2.0 * l - high
        third = 1.0 / 3.0
        red = 255 * _hue_to_channel(low, high, h + third)
        green = 255 * _hue_to_channel(low, high, h)
        blue = 255 * _hue_to_channel(low, high, h - third)

    return int(round(red)), int(round(green)), int(round(blue))
|
|
|
|
|
def generate_pastel_colors(n, start_hue=0.6, saturation=1.0, lightness=0.8):
    """Return *n* distinct pastel colours in HTML ``#rrggbb`` notation.

    Hues are spaced evenly over one full turn of the colour wheel, starting at
    *start_hue*. Saturation and lightness are shared by every colour.

    Args:
        n: Number of colours to return. ``0`` or a negative value yields an
            empty list.
        start_hue: Hue of the first colour, as a fraction of a full turn.
        saturation: HSL saturation shared by all colours.
        lightness: HSL lightness shared by all colours.

    Returns:
        A list of *n* colour strings such as ``'#99c2ff'``.

    Example:
        >>> generate_pastel_colors(5)
        ['#99c2ff', '#eb99ff', '#ff9999', '#ebff99', '#99ffc2']
    """
    if n <= 0:
        return []
    # Take n + 1 evenly spaced hues across one full turn and drop the last one.
    # It is start_hue + 1, which hsl_to_rgb wraps back onto start_hue, so
    # keeping it would duplicate the first colour.
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]
    return ["#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) for hue in hues]
|
|
|
|
|
def set_sidebar(css):
    """Render the sidebar: inject the page stylesheet, the lab logo and link lists.

    Args:
        css: Raw stylesheet text. It is wrapped in a ``<style>`` tag and written
            with ``unsafe_allow_html=True``, so it applies to the page.
    """
    # Template for one link: a Font Awesome stylesheet <link> followed by an
    # <a href='{url}'>{label}</a>. The stylesheet tag is repeated for every
    # formatted link. The labels below embed Font Awesome <i class='fa-...'>
    # icon markup.
    white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
    with st.sidebar:
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        # NOTE(review): `use_column_width` is deprecated in recent Streamlit
        # releases in favour of `use_container_width` — confirm against the
        # Streamlit version this app pins before changing it.
        st.image(
            "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
            use_column_width=True,
        )
        st.markdown("## ReLiK")
        # Project links. NOTE(review): the "Paper" entry points at "#", which
        # looks like a placeholder until a paper URL is available.
        st.write(
            f"""
            - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
            """,
            unsafe_allow_html=True,
        )
        # Lab (Sapienza NLP) links.
        st.markdown("## Sapienza NLP")
        st.write(
            f"""
            - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
            - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
            - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
            - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
            """,
            unsafe_allow_html=True,
        )
|
|
|
|
|
def get_el_annotations(response):
    """Turn a model response into displaCy manual-mode input plus render options.

    Args:
        response: Model output exposing ``.text`` and ``.labels``. Each label
            has ``.start``, ``.end`` and ``.label`` (the offsets are presumably
            character positions into ``.text`` — confirm against ReLiK).

    Returns:
        A ``(doc, options)`` pair:

        - ``doc`` is ``{"text": ..., "ents": [{"start", "end", "label"}, ...]}``.
        - ``options`` holds the set of distinct labels under ``"ents"`` and a
          label-to-colour mapping under ``"colors"``.
    """
    spans = []
    for span in response.labels:
        spans.append({"start": span.start, "end": span.end, "label": span.label})
    doc = {"text": response.text, "ents": spans}
    distinct_labels = {entry["label"] for entry in doc["ents"]}
    render_options = {
        "ents": distinct_labels,
        "colors": get_random_color(distinct_labels),
    }
    return doc, render_options
|
|
|
|
|
@st.cache_resource()
def load_model():
    """Build the ReLiK annotator from the locally stored retriever and reader.

    Decorated with ``st.cache_resource`` so Streamlit reuses the constructed
    object instead of rebuilding it on every script run.

    Returns:
        A configured ``Relik`` instance.
    """
    models_root = "/home/user/app/models"
    retriever_root = f"{models_root}/relik-retriever-small-aida-blink-pretrain-omniencoder"
    return Relik(
        question_encoder=f"{retriever_root}/question_encoder",
        document_index=f"{retriever_root}/document_index_filtered",
        reader=f"{models_root}/relik-reader-aida-deberta-small",
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )
|
|
|
|
|
def set_intro(css):
    """Render the page title, the subtitle and the GitHub/PyPI badges.

    Args:
        css: Not used by this function; kept so existing ``set_intro(css)``
            calls continue to work.
    """
    # Spell the project name "ReLiK", consistent with the sidebar heading and
    # with the "LinK" capitalisation in the subtitle below.
    st.markdown("# ReLiK")
    st.markdown(
        "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
    )
    badge(type="github", name="sapienzanlp/relik")
    badge(type="pypi", name="relik")
|
|
|
|
|
def _render_annotations(relik_model, text):
    """Annotate *text* with *relik_model* and render the result sections.

    Args:
        relik_model: The loaded annotator; called as ``relik_model(text)``.
        text: Non-empty, already stripped input text.
    """
    st.markdown("####")
    st.markdown("#### Entity Linking")
    with st.spinner(text="In progress"):
        response = relik_model(text)

    dict_of_ents, options = get_el_annotations(response=response)
    display = displacy.render(dict_of_ents, manual=True, style="ent", options=options)
    # Collapse the newlines in displaCy's HTML into spaces before handing it to
    # st.write. Presumably this stops the markdown renderer from splitting the
    # HTML into separate blocks — confirm.
    display = display.replace("\n", " ")
    with st.container():
        st.write(display, unsafe_allow_html=True)

    st.markdown("####")
    st.markdown("#### Relation Extraction")
    with st.container():
        st.write("Coming :)", unsafe_allow_html=True)


def run_client():
    """Build the ReLiK demo page and annotate the user's text on request.

    The page shows the sidebar, the intro, a text area and an "Annotate"
    button. When the button is pressed, non-empty input is annotated and
    rendered; empty input produces an error message.
    """
    # Decode the stylesheet explicitly as UTF-8 so the result does not depend
    # on the host's locale encoding.
    css = (Path(__file__).parent / "style.css").read_text(encoding="utf-8")

    # NOTE: keep this ahead of the other st.* calls; Streamlit documents
    # set_page_config as needing to run first.
    st.set_page_config(
        page_title="ReLiK",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    text = st.text_area(
        "Enter Text Below:",
        value="Michael Jordan was one of the best players in the NBA.",
        height=200,
        max_chars=500,
    )

    # Custom styling for the Annotate button: burgundy background, white text,
    # rounded corners.
    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #802433;
                color: white;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")

    # load_model() is already cached via st.cache_resource. The session-state
    # entry just keeps a per-session handle to the same object.
    if "relik_model" not in st.session_state:
        st.session_state["relik_model"] = load_model()
    relik_model = st.session_state["relik_model"]

    if not submit:
        return

    text = text.strip()
    if text:
        _render_annotations(relik_model, text)
    else:
        st.error("Please enter some text.")
|
|
|
|
|
# Entry point: build and run the demo page only when this file is executed as
# the main script (not when it is imported).
if __name__ == "__main__":
    run_client()
|
|