Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,12 @@ import re
|
|
3 |
import time
|
4 |
from pathlib import Path
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import requests
|
7 |
import streamlit as st
|
8 |
from spacy import displacy
|
@@ -14,7 +20,13 @@ from streamlit_extras.stylable_container import stylable_container
|
|
14 |
import random
|
15 |
|
16 |
from relik.inference.annotator import Relik
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def get_random_color(ents):
|
20 |
colors = {}
|
@@ -93,44 +105,31 @@ def generate_pastel_colors(n):
|
|
93 |
|
94 |
|
95 |
def set_sidebar(css):
|
96 |
-
white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
|
97 |
with st.sidebar:
|
98 |
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
|
99 |
st.image(
|
100 |
-
"
|
101 |
use_column_width=True,
|
102 |
)
|
103 |
-
st.markdown("
|
104 |
-
st.
|
105 |
-
f"""
|
106 |
-
- {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i> Paper")}
|
107 |
-
- {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i> GitHub")}
|
108 |
-
- {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i> Docker Hub")}
|
109 |
-
""",
|
110 |
-
unsafe_allow_html=True,
|
111 |
-
)
|
112 |
-
st.markdown("## Sapienza NLP")
|
113 |
-
st.write(
|
114 |
-
f"""
|
115 |
-
- {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i> Webpage")}
|
116 |
-
- {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i> GitHub")}
|
117 |
-
- {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i> Twitter")}
|
118 |
-
- {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i> LinkedIn")}
|
119 |
-
""",
|
120 |
-
unsafe_allow_html=True,
|
121 |
-
)
|
122 |
-
|
123 |
|
124 |
def get_el_annotations(response):
|
125 |
-
|
|
|
126 |
# swap labels key with ents
|
127 |
ents = [
|
128 |
{
|
129 |
"start": l.start,
|
130 |
"end": l.end,
|
131 |
-
"label":
|
|
|
|
|
|
|
|
|
|
|
132 |
}
|
133 |
-
for l in response.
|
134 |
]
|
135 |
dict_of_ents = {"text": response.text, "ents": ents}
|
136 |
label_in_text = set(l["label"] for l in dict_of_ents["ents"])
|
@@ -138,24 +137,58 @@ def get_el_annotations(response):
|
|
138 |
return dict_of_ents, options
|
139 |
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
@st.cache_resource()
|
142 |
def load_model():
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
151 |
)
|
|
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
def set_intro(css):
|
155 |
# intro
|
156 |
-
|
|
|
|
|
|
|
|
|
157 |
st.markdown(
|
158 |
-
"###
|
159 |
)
|
160 |
# st.markdown(
|
161 |
# "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
|
@@ -163,16 +196,13 @@ def set_intro(css):
|
|
163 |
# "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
|
164 |
# "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
|
165 |
# )
|
166 |
-
badge(type="github", name="sapienzanlp/relik")
|
167 |
-
badge(type="pypi", name="relik")
|
168 |
-
|
169 |
|
170 |
def run_client():
|
171 |
with open(Path(__file__).parent / "style.css") as f:
|
172 |
css = f.read()
|
173 |
|
174 |
st.set_page_config(
|
175 |
-
page_title="
|
176 |
page_icon="🦮",
|
177 |
layout="wide",
|
178 |
)
|
@@ -182,7 +212,7 @@ def run_client():
|
|
182 |
# text input
|
183 |
text = st.text_area(
|
184 |
"Enter Text Below:",
|
185 |
-
value="
|
186 |
height=200,
|
187 |
max_chars=1500,
|
188 |
)
|
@@ -191,8 +221,8 @@ def run_client():
|
|
191 |
key="annotate_button",
|
192 |
css_styles="""
|
193 |
button {
|
194 |
-
background-color: #
|
195 |
-
color:
|
196 |
border-radius: 25px;
|
197 |
}
|
198 |
""",
|
@@ -212,6 +242,7 @@ def run_client():
|
|
212 |
st.markdown("#### Entity Linking")
|
213 |
with st.spinner(text="In progress"):
|
214 |
response = relik_model(text)
|
|
|
215 |
# response = requests.post(RELIK, json=text)
|
216 |
# if response.status_code != 200:
|
217 |
# st.error("Error: {}".format(response.status_code))
|
@@ -220,14 +251,23 @@ def run_client():
|
|
220 |
|
221 |
# st.markdown("##")
|
222 |
dict_of_ents, options = get_el_annotations(response=response)
|
|
|
|
|
223 |
display = displacy.render(
|
224 |
dict_of_ents, manual=True, style="ent", options=options
|
225 |
)
|
|
|
|
|
226 |
display = display.replace("\n", " ")
|
|
|
227 |
# heurstic, prevents split of annotation decorations
|
228 |
display = display.replace("border-radius: 0.35em;", "border-radius: 0.35em; white-space: nowrap;")
|
|
|
229 |
with st.container():
|
230 |
st.write(display, unsafe_allow_html=True)
|
|
|
|
|
|
|
231 |
|
232 |
else:
|
233 |
st.error("Please enter some text.")
|
|
|
3 |
import time
|
4 |
from pathlib import Path
|
5 |
|
6 |
+
from relik.retriever import GoldenRetriever
|
7 |
+
|
8 |
+
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
|
9 |
+
from relik.retriever.indexers.document import DocumentStore
|
10 |
+
from relik.retriever import GoldenRetriever
|
11 |
+
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
|
12 |
import requests
|
13 |
import streamlit as st
|
14 |
from spacy import displacy
|
|
|
20 |
import random
|
21 |
|
22 |
from relik.inference.annotator import Relik
|
23 |
+
from relik.inference.data.objects import (
|
24 |
+
AnnotationType,
|
25 |
+
RelikOutput,
|
26 |
+
Span,
|
27 |
+
TaskType,
|
28 |
+
Triples,
|
29 |
+
)
|
30 |
|
31 |
def get_random_color(ents):
|
32 |
colors = {}
|
|
|
105 |
|
106 |
|
107 |
def set_sidebar(css):
|
|
|
108 |
with st.sidebar:
|
109 |
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
|
110 |
st.image(
|
111 |
+
"https://upload.wikimedia.org/wikipedia/commons/8/87/The_World_Bank_logo.svg",
|
112 |
use_column_width=True,
|
113 |
)
|
114 |
+
st.markdown("### World Bank")
|
115 |
+
st.markdown("### DIME")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
def get_el_annotations(response):
|
118 |
+
i_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://developmentevidence.3ieimpact.org/taxonomy-search-detail/intervention/disaggregated-intervention/{}' style='color: #414141'> <span style='font-size: 1.0em; font-family: monospace'> Intervention {}</span></a>"
|
119 |
+
o_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://developmentevidence.3ieimpact.org/taxonomy-search-detail/intervention/disaggregated-outcome/{}' style='color: #414141'><span style='font-size: 1.0em; font-family: monospace'> Outcome: {}</span></a>"
|
120 |
# swap labels key with ents
|
121 |
ents = [
|
122 |
{
|
123 |
"start": l.start,
|
124 |
"end": l.end,
|
125 |
+
"label": i_link_wrapper.format(l.label[0].upper() + l.label[1:].replace("/", "%2").replace(" ", "%20").replace("&","%26"), l.label),
|
126 |
+
} if io_map[l.label] == "intervention" else
|
127 |
+
{
|
128 |
+
"start": l.start,
|
129 |
+
"end": l.end,
|
130 |
+
"label": o_link_wrapper.format(l.label[0].upper() + l.label[1:].replace("/", "%2").replace(" ", "%20").replace("&","%26"), l.label),
|
131 |
}
|
132 |
+
for l in response.spans
|
133 |
]
|
134 |
dict_of_ents = {"text": response.text, "ents": ents}
|
135 |
label_in_text = set(l["label"] for l in dict_of_ents["ents"])
|
|
|
137 |
return dict_of_ents, options
|
138 |
|
139 |
|
140 |
+
|
141 |
+
def get_retriever_annotations(response):
|
142 |
+
el_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://en.wikipedia.org/wiki/{}' style='color: #414141'><i class='fa-brands fa-wikipedia-w fa-xs'></i> <span style='font-size: 1.0em; font-family: monospace'> {}</span></a>"
|
143 |
+
# swap labels key with ents
|
144 |
+
ents = [l.text
|
145 |
+
for l in response.candidates[TaskType.SPAN]
|
146 |
+
]
|
147 |
+
dict_of_ents = {"text": response.text, "ents": ents}
|
148 |
+
label_in_text = set(l for l in dict_of_ents["ents"])
|
149 |
+
options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
|
150 |
+
return dict_of_ents, options
|
151 |
+
import json
|
152 |
+
io_map = {}
|
153 |
+
with open("/home/user/app/model/retriever/document_index/documents.jsonl", "r") as r:
|
154 |
+
for line in r:
|
155 |
+
element = json.loads(line)
|
156 |
+
io_map[element["text"]] = element["metadata"]["type"]
|
157 |
+
|
158 |
@st.cache_resource()
|
159 |
def load_model():
|
160 |
+
|
161 |
+
retriever = GoldenRetriever(
|
162 |
+
question_encoder="/home/user/app/model/retriever/question_encoder",
|
163 |
+
document_index=InMemoryDocumentIndex(
|
164 |
+
documents=DocumentStore.from_file(
|
165 |
+
"/home/user/app/model/retriever/document_index/documents.jsonl"
|
166 |
+
),
|
167 |
+
metadata_fields=["definition"],
|
168 |
+
separator=' <def> ',
|
169 |
+
device="cuda"
|
170 |
+
),
|
171 |
+
devide="cuda"
|
172 |
+
|
173 |
)
|
174 |
+
retriever.index()
|
175 |
|
176 |
+
reader = RelikReaderForSpanExtraction("/home/user/app/model/reader",
|
177 |
+
dataset_kwargs={"use_nme": True})
|
178 |
+
|
179 |
+
relik = Relik(reader=reader, retriever=retriever, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu")
|
180 |
+
|
181 |
+
return relik
|
182 |
|
183 |
def set_intro(css):
|
184 |
# intro
|
185 |
+
|
186 |
+
st.markdown("# CausalAI")
|
187 |
+
st.image(
|
188 |
+
"http://35.237.102.64/public/logo.png",
|
189 |
+
)
|
190 |
st.markdown(
|
191 |
+
"### 3ie taxonomy level 4 Intervention/Outcome candidate retriever with Entity Linking"
|
192 |
)
|
193 |
# st.markdown(
|
194 |
# "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
|
|
|
196 |
# "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
|
197 |
# "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
|
198 |
# )
|
|
|
|
|
|
|
199 |
|
200 |
def run_client():
|
201 |
with open(Path(__file__).parent / "style.css") as f:
|
202 |
css = f.read()
|
203 |
|
204 |
st.set_page_config(
|
205 |
+
page_title="CausalAI",
|
206 |
page_icon="🦮",
|
207 |
layout="wide",
|
208 |
)
|
|
|
212 |
# text input
|
213 |
text = st.text_area(
|
214 |
"Enter Text Below:",
|
215 |
+
value="How does unconditional cash transver affect to reduce poverty?",
|
216 |
height=200,
|
217 |
max_chars=1500,
|
218 |
)
|
|
|
221 |
key="annotate_button",
|
222 |
css_styles="""
|
223 |
button {
|
224 |
+
background-color: #a8ebff;
|
225 |
+
color: black;
|
226 |
border-radius: 25px;
|
227 |
}
|
228 |
""",
|
|
|
242 |
st.markdown("#### Entity Linking")
|
243 |
with st.spinner(text="In progress"):
|
244 |
response = relik_model(text)
|
245 |
+
|
246 |
# response = requests.post(RELIK, json=text)
|
247 |
# if response.status_code != 200:
|
248 |
# st.error("Error: {}".format(response.status_code))
|
|
|
251 |
|
252 |
# st.markdown("##")
|
253 |
dict_of_ents, options = get_el_annotations(response=response)
|
254 |
+
dict_of_ents_candidates, options_candidates = get_retriever_annotations(response=response)
|
255 |
+
|
256 |
display = displacy.render(
|
257 |
dict_of_ents, manual=True, style="ent", options=options
|
258 |
)
|
259 |
+
|
260 |
+
|
261 |
display = display.replace("\n", " ")
|
262 |
+
|
263 |
# heurstic, prevents split of annotation decorations
|
264 |
display = display.replace("border-radius: 0.35em;", "border-radius: 0.35em; white-space: nowrap;")
|
265 |
+
|
266 |
with st.container():
|
267 |
st.write(display, unsafe_allow_html=True)
|
268 |
+
|
269 |
+
text = "## Possible Candidates:\n- " + "\n- ".join([candidate for candidate in dict_of_ents_candidates["ents"][:5]])
|
270 |
+
st.markdown(text)
|
271 |
|
272 |
else:
|
273 |
st.error("Please enter some text.")
|