CarlosMalaga commited on
Commit
9ee75f8
1 Parent(s): 9da9a9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -43
app.py CHANGED
@@ -3,6 +3,12 @@ import re
3
  import time
4
  from pathlib import Path
5
 
 
 
 
 
 
 
6
  import requests
7
  import streamlit as st
8
  from spacy import displacy
@@ -14,7 +20,13 @@ from streamlit_extras.stylable_container import stylable_container
14
  import random
15
 
16
  from relik.inference.annotator import Relik
17
-
 
 
 
 
 
 
18
 
19
  def get_random_color(ents):
20
  colors = {}
@@ -93,44 +105,31 @@ def generate_pastel_colors(n):
93
 
94
 
95
  def set_sidebar(css):
96
- white_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='{}'>{}</a>"
97
  with st.sidebar:
98
  st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
99
  st.image(
100
- "http://nlp.uniroma1.it/static/website/sapienza-nlp-logo-wh.svg",
101
  use_column_width=True,
102
  )
103
- st.markdown("## ReLiK")
104
- st.write(
105
- f"""
106
- - {white_link_wrapper.format("#", "<i class='fa-solid fa-file'></i>&nbsp; Paper")}
107
- - {white_link_wrapper.format("https://github.com/SapienzaNLP/relik", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
108
- - {white_link_wrapper.format("https://hub.docker.com/repository/docker/sapienzanlp/relik", "<i class='fa-brands fa-docker'></i>&nbsp; Docker Hub")}
109
- """,
110
- unsafe_allow_html=True,
111
- )
112
- st.markdown("## Sapienza NLP")
113
- st.write(
114
- f"""
115
- - {white_link_wrapper.format("https://nlp.uniroma1.it", "<i class='fa-solid fa-globe'></i>&nbsp; Webpage")}
116
- - {white_link_wrapper.format("https://github.com/SapienzaNLP", "<i class='fa-brands fa-github'></i>&nbsp; GitHub")}
117
- - {white_link_wrapper.format("https://twitter.com/SapienzaNLP", "<i class='fa-brands fa-twitter'></i>&nbsp; Twitter")}
118
- - {white_link_wrapper.format("https://www.linkedin.com/company/79434450", "<i class='fa-brands fa-linkedin'></i>&nbsp; LinkedIn")}
119
- """,
120
- unsafe_allow_html=True,
121
- )
122
-
123
 
124
  def get_el_annotations(response):
125
- el_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://en.wikipedia.org/wiki/{}' style='color: #414141'><i class='fa-brands fa-wikipedia-w fa-xs'></i> <span style='font-size: 1.0em; font-family: monospace'> {}</span></a>"
 
126
  # swap labels key with ents
127
  ents = [
128
  {
129
  "start": l.start,
130
  "end": l.end,
131
- "label": el_link_wrapper.format(l.label.replace(" ", "_"), l.label),
 
 
 
 
 
132
  }
133
- for l in response.labels
134
  ]
135
  dict_of_ents = {"text": response.text, "ents": ents}
136
  label_in_text = set(l["label"] for l in dict_of_ents["ents"])
@@ -138,24 +137,58 @@ def get_el_annotations(response):
138
  return dict_of_ents, options
139
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  @st.cache_resource()
142
  def load_model():
143
- return Relik(
144
- question_encoder="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
145
- document_index="/home/user/app/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
146
- reader="/home/user/app/models/relik-reader-aida-deberta-small",
147
- top_k=100,
148
- window_size=32,
149
- window_stride=16,
150
- candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
 
 
 
 
 
151
  )
 
152
 
 
 
 
 
 
 
153
 
154
  def set_intro(css):
155
  # intro
156
- st.markdown("# ReLik")
 
 
 
 
157
  st.markdown(
158
- "### Retrieve, Read and LinK: Fast and Accurate Entity Linking and Relation Extraction on an Academic Budget"
159
  )
160
  # st.markdown(
161
  # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
@@ -163,16 +196,13 @@ def set_intro(css):
163
  # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
164
  # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
165
  # )
166
- badge(type="github", name="sapienzanlp/relik")
167
- badge(type="pypi", name="relik")
168
-
169
 
170
  def run_client():
171
  with open(Path(__file__).parent / "style.css") as f:
172
  css = f.read()
173
 
174
  st.set_page_config(
175
- page_title="ReLik",
176
  page_icon="🦮",
177
  layout="wide",
178
  )
@@ -182,7 +212,7 @@ def run_client():
182
  # text input
183
  text = st.text_area(
184
  "Enter Text Below:",
185
- value="Michael Jordan was one of the best players in the NBA.",
186
  height=200,
187
  max_chars=1500,
188
  )
@@ -191,8 +221,8 @@ def run_client():
191
  key="annotate_button",
192
  css_styles="""
193
  button {
194
- background-color: #802433;
195
- color: white;
196
  border-radius: 25px;
197
  }
198
  """,
@@ -212,6 +242,7 @@ def run_client():
212
  st.markdown("#### Entity Linking")
213
  with st.spinner(text="In progress"):
214
  response = relik_model(text)
 
215
  # response = requests.post(RELIK, json=text)
216
  # if response.status_code != 200:
217
  # st.error("Error: {}".format(response.status_code))
@@ -220,14 +251,23 @@ def run_client():
220
 
221
  # st.markdown("##")
222
  dict_of_ents, options = get_el_annotations(response=response)
 
 
223
  display = displacy.render(
224
  dict_of_ents, manual=True, style="ent", options=options
225
  )
 
 
226
  display = display.replace("\n", " ")
 
227
  # heurstic, prevents split of annotation decorations
228
  display = display.replace("border-radius: 0.35em;", "border-radius: 0.35em; white-space: nowrap;")
 
229
  with st.container():
230
  st.write(display, unsafe_allow_html=True)
 
 
 
231
 
232
  else:
233
  st.error("Please enter some text.")
 
3
  import time
4
  from pathlib import Path
5
 
6
+ from relik.retriever import GoldenRetriever
7
+
8
+ from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
9
+ from relik.retriever.indexers.document import DocumentStore
10
+ from relik.retriever import GoldenRetriever
11
+ from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
12
  import requests
13
  import streamlit as st
14
  from spacy import displacy
 
20
  import random
21
 
22
  from relik.inference.annotator import Relik
23
+ from relik.inference.data.objects import (
24
+ AnnotationType,
25
+ RelikOutput,
26
+ Span,
27
+ TaskType,
28
+ Triples,
29
+ )
30
 
31
  def get_random_color(ents):
32
  colors = {}
 
105
 
106
 
107
  def set_sidebar(css):
 
108
  with st.sidebar:
109
  st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
110
  st.image(
111
+ "https://upload.wikimedia.org/wikipedia/commons/8/87/The_World_Bank_logo.svg",
112
  use_column_width=True,
113
  )
114
+ st.markdown("### World Bank")
115
+ st.markdown("### DIME")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  def get_el_annotations(response):
118
+ i_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://developmentevidence.3ieimpact.org/taxonomy-search-detail/intervention/disaggregated-intervention/{}' style='color: #414141'> <span style='font-size: 1.0em; font-family: monospace'> Intervention {}</span></a>"
119
+ o_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://developmentevidence.3ieimpact.org/taxonomy-search-detail/intervention/disaggregated-outcome/{}' style='color: #414141'><span style='font-size: 1.0em; font-family: monospace'> Outcome: {}</span></a>"
120
  # swap labels key with ents
121
  ents = [
122
  {
123
  "start": l.start,
124
  "end": l.end,
125
+ "label": i_link_wrapper.format(l.label[0].upper() + l.label[1:].replace("/", "%2").replace(" ", "%20").replace("&","%26"), l.label),
126
+ } if io_map[l.label] == "intervention" else
127
+ {
128
+ "start": l.start,
129
+ "end": l.end,
130
+ "label": o_link_wrapper.format(l.label[0].upper() + l.label[1:].replace("/", "%2").replace(" ", "%20").replace("&","%26"), l.label),
131
  }
132
+ for l in response.spans
133
  ]
134
  dict_of_ents = {"text": response.text, "ents": ents}
135
  label_in_text = set(l["label"] for l in dict_of_ents["ents"])
 
137
  return dict_of_ents, options
138
 
139
 
140
+
141
+ def get_retriever_annotations(response):
142
+ el_link_wrapper = "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css'><a href='https://en.wikipedia.org/wiki/{}' style='color: #414141'><i class='fa-brands fa-wikipedia-w fa-xs'></i> <span style='font-size: 1.0em; font-family: monospace'> {}</span></a>"
143
+ # swap labels key with ents
144
+ ents = [l.text
145
+ for l in response.candidates[TaskType.SPAN]
146
+ ]
147
+ dict_of_ents = {"text": response.text, "ents": ents}
148
+ label_in_text = set(l for l in dict_of_ents["ents"])
149
+ options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
150
+ return dict_of_ents, options
151
+ import json
152
+ io_map = {}
153
+ with open("/home/user/app/model/retriever/document_index/documents.jsonl", "r") as r:
154
+ for line in r:
155
+ element = json.loads(line)
156
+ io_map[element["text"]] = element["metadata"]["type"]
157
+
158
  @st.cache_resource()
159
  def load_model():
160
+
161
+ retriever = GoldenRetriever(
162
+ question_encoder="/home/user/app/model/retriever/question_encoder",
163
+ document_index=InMemoryDocumentIndex(
164
+ documents=DocumentStore.from_file(
165
+ "/home/user/app/model/retriever/document_index/documents.jsonl"
166
+ ),
167
+ metadata_fields=["definition"],
168
+ separator=' <def> ',
169
+ device="cuda"
170
+ ),
171
+ devide="cuda"
172
+
173
  )
174
+ retriever.index()
175
 
176
+ reader = RelikReaderForSpanExtraction("/home/user/app/model/reader",
177
+ dataset_kwargs={"use_nme": True})
178
+
179
+ relik = Relik(reader=reader, retriever=retriever, window_size="none", top_k=100, task="span", device="cuda", document_index_device="cpu")
180
+
181
+ return relik
182
 
183
  def set_intro(css):
184
  # intro
185
+
186
+ st.markdown("# CausalAI")
187
+ st.image(
188
+ "http://35.237.102.64/public/logo.png",
189
+ )
190
  st.markdown(
191
+ "### 3ie taxonomy level 4 Intervention/Outcome candidate retriever with Entity Linking"
192
  )
193
  # st.markdown(
194
  # "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
 
196
  # "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
197
  # "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
198
  # )
 
 
 
199
 
200
  def run_client():
201
  with open(Path(__file__).parent / "style.css") as f:
202
  css = f.read()
203
 
204
  st.set_page_config(
205
+ page_title="CausalAI",
206
  page_icon="🦮",
207
  layout="wide",
208
  )
 
212
  # text input
213
  text = st.text_area(
214
  "Enter Text Below:",
215
+ value="How does unconditional cash transver affect to reduce poverty?",
216
  height=200,
217
  max_chars=1500,
218
  )
 
221
  key="annotate_button",
222
  css_styles="""
223
  button {
224
+ background-color: #a8ebff;
225
+ color: black;
226
  border-radius: 25px;
227
  }
228
  """,
 
242
  st.markdown("#### Entity Linking")
243
  with st.spinner(text="In progress"):
244
  response = relik_model(text)
245
+
246
  # response = requests.post(RELIK, json=text)
247
  # if response.status_code != 200:
248
  # st.error("Error: {}".format(response.status_code))
 
251
 
252
  # st.markdown("##")
253
  dict_of_ents, options = get_el_annotations(response=response)
254
+ dict_of_ents_candidates, options_candidates = get_retriever_annotations(response=response)
255
+
256
  display = displacy.render(
257
  dict_of_ents, manual=True, style="ent", options=options
258
  )
259
+
260
+
261
  display = display.replace("\n", " ")
262
+
263
  # heurstic, prevents split of annotation decorations
264
  display = display.replace("border-radius: 0.35em;", "border-radius: 0.35em; white-space: nowrap;")
265
+
266
  with st.container():
267
  st.write(display, unsafe_allow_html=True)
268
+
269
+ text = "## Possible Candidates:\n- " + "\n- ".join([candidate for candidate in dict_of_ents_candidates["ents"][:5]])
270
+ st.markdown(text)
271
 
272
  else:
273
  st.error("Please enter some text.")