import json
import logging
import os
import random
import re
import time
from pathlib import Path

import requests
import streamlit as st
from spacy import displacy
from streamlit_extras.badges import badge
from streamlit_extras.stylable_container import stylable_container

from relik.inference.annotator import Relik
from relik.inference.data.objects import (
    AnnotationType,
    RelikOutput,
    Span,
    TaskType,
    Triples,
)
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.retriever import GoldenRetriever
from relik.retriever.indexers.document import DocumentStore
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
logger = logging.getLogger()

# RELIK = os.getenv("RELIK", "localhost:8000/api/entities")


def get_random_color(ents):
    """Map every entity label to a pastel colour picked in random order.

    Args:
        ents: A sized collection of labels (list or set).

    Returns:
        A dict ``{label: "#rrggbb"}``. One colour is generated per element of
        ``ents``; if a label occurs more than once, the colour assigned last
        wins.
    """
    labels = list(ents)
    palette = generate_pastel_colors(len(labels))
    # Shuffling the palette is equivalent to popping a random colour per label.
    random.shuffle(palette)
    return dict(zip(labels, palette))


def floatrange(start, stop, steps):
    """Return ``steps`` evenly spaced floats from ``start`` to ``stop`` inclusive.

    Args:
        start: First value of the range.
        stop: Last value of the range.
        steps: Number of values to produce; coerced with ``int()``.

    Returns:
        A list of ``steps`` values. ``steps == 1`` yields ``[stop]`` and
        ``steps <= 0`` yields an empty list.
    """
    steps = int(steps)
    if steps == 1:
        return [stop]
    span = stop - start
    return [start + float(i) * span / (float(steps) - 1) for i in range(steps)]


def hsl_to_rgb(h, s, l):
    """Convert an HSL colour to an ``(r, g, b)`` tuple of ints in 0..255.

    Args:
        h: Hue as a fraction of the colour circle (0 = red, 1/3 = green,
            2/3 = blue). Values outside [0, 1] are wrapped.
        s: Saturation; callers in this file pass values in [0, 1]. The range
            is not validated.
        l: Lightness; callers in this file pass values in [0, 1]. The range
            is not validated.

    Returns:
        A 3-tuple of rounded integer channel values.
    """

    def hue_2_rgb(v1, v2, v_h):
        # Wrap the hue into [0, 1] before picking the colour-circle segment.
        while v_h < 0.0:
            v_h += 1.0
        while v_h > 1.0:
            v_h -= 1.0
        if 6 * v_h < 1.0:
            return v1 + (v2 - v1) * 6.0 * v_h
        if 2 * v_h < 1.0:
            return v2
        if 3 * v_h < 2.0:
            return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
        return v1

    if s == 0.0:
        # No saturation: every channel equals the lightness (a grey).
        grey = int(round(l * 255))
        return grey, grey, grey

    if l < 0.5:
        var_2 = l * (1.0 + s)
    else:
        var_2 = (l + s) - (s * l)
    var_1 = 2.0 * l - var_2

    r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
    g = 255 * hue_2_rgb(var_1, var_2, h)
    b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))
    return int(round(r)), int(round(g)), int(round(b))


def generate_pastel_colors(n):
    """Return ``n`` different pastel colours in HTML notation.

    The colours are taken at evenly spaced hues around the HSL colour circle
    (see http://en.wikipedia.org/wiki/HSL_color_space) with full saturation
    and a lightness of 0.9.

    Args:
        n: The number of colours to return.

    Returns:
        A list of ``n`` strings such as ``'#ffcccc'``; empty when ``n == 0``.

    Example:
        >>> generate_pastel_colors(5)
        ['#ffcccc', '#f5ffcc', '#ccffe0', '#cce0ff', '#f5ccff']
    """
    if n == 0:
        return []

    start_hue = 0.0  # 0=red 1/3=0.333=green 2/3=0.666=blue
    saturation = 1.0
    lightness = 0.9
    # n + 1 hues are generated and the last one is dropped because it equals
    # the first one (hue 0 == hue 1).
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]
    return ["#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) for hue in hues]


def set_sidebar(css):
    """Render the sidebar: page stylesheet, World Bank logo and headings.

    Args:
        css: Stylesheet text to inject into the page.
    """
    with st.sidebar:
        # The stylesheet must be wrapped in a <style> element to take effect;
        # previously an empty string was rendered and ``css`` was ignored.
        st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
        st.image(
            "https://upload.wikimedia.org/wikipedia/commons/8/87/The_World_Bank_logo.svg",
            use_column_width=True,
        )
        st.markdown("### World Bank")
        st.markdown("### DIME")


def get_el_annotations(response):
    """Build the displaCy "manual" payload for the linked spans of a response.

    Args:
        response: Object exposing ``text`` and ``spans``; every span provides
            ``start``, ``end`` and ``label``.

    Returns:
        A tuple ``(dict_of_ents, options)``: ``dict_of_ents`` holds the text
        and the entity dicts, ``options`` holds the set of rendered labels and
        a random colour per label.

    Raises:
        KeyError: If a span label is not present in ``io_map``.
    """
    i_link_wrapper = " Intervention {}"
    o_link_wrapper = " Outcome: {}"
    ents = []
    for span in response.spans:
        label = span.label
        wrapper = i_link_wrapper if io_map[label] == "intervention" else o_link_wrapper
        # Only the characters after the first one are escaped; the first
        # character is upper-cased as is.
        # NOTE(review): "/" is mapped to "%2" (not "%2F") - confirm this is
        # what the link target expects.
        slug = label[0].upper() + (
            label[1:].replace("/", "%2").replace(" ", "%20").replace("&", "%26")
        )
        # NOTE(review): the wrappers contain a single "{}", so only ``slug``
        # is rendered and ``label`` is ignored by format() - confirm whether
        # link markup with two placeholders was intended here.
        ents.append(
            {
                "start": span.start,
                "end": span.end,
                "label": wrapper.format(slug, label),
            }
        )
    dict_of_ents = {"text": response.text, "ents": ents}
    label_in_text = {ent["label"] for ent in ents}
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options


def get_retriever_annotations(response):
    """Collect the texts of the span candidates retrieved for a response.

    Args:
        response: Object exposing ``text`` and ``candidates``; the entries
            under ``TaskType.SPAN`` provide a ``text`` attribute.

    Returns:
        A tuple ``(dict_of_ents, options)`` where ``dict_of_ents["ents"]`` is
        the list of candidate texts and ``options`` holds the distinct texts
        with a random colour each.
    """
    ents = [candidate.text for candidate in response.candidates[TaskType.SPAN]]
    dict_of_ents = {"text": response.text, "ents": ents}
    label_in_text = set(ents)
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options


def get_retriever_annotations_candidates(text, ents):
    """Wrap already retrieved candidate texts in the annotation payload.

    Args:
        text: The text the candidates were retrieved for.
        ents: Iterable of candidate strings; stored unchanged.

    Returns:
        A tuple ``(dict_of_ents, options)`` with the same layout as
        ``get_retriever_annotations``.
    """
    dict_of_ents = {"text": text, "ents": ents}
    label_in_text = set(ents)
    options = {"ents": label_in_text, "colors": get_random_color(label_in_text)}
    return dict_of_ents, options


def _iter_jsonl(path):
    """Yield one decoded JSON object per non-blank line of ``path``."""
    with open(path, "r", encoding="utf-8") as handle:
        for line in handle:
            if line.strip():
                yield json.loads(line)


# Document text -> value of its "metadata.type" field; the rest of this file
# compares that value against "intervention".
io_map = {
    element["text"]: element["metadata"]["type"]
    for element in _iter_jsonl(
        "/home/user/app/models/retriever/document_index/documents.jsonl"
    )
}

# Texts of all documents contained in the intervention and outcome DB indexes.
db_set = set()
for _db_documents in (
    "models/retriever/intervention/gpt/db/document_index/documents.jsonl",
    "models/retriever/outcome/gpt/db/document_index/documents.jsonl",
):
    db_set.update(element["text"] for element in _iter_jsonl(_db_documents))


@st.cache_resource()
def load_model():
    """Load the entity-linking pipeline and the four standalone retrievers.

    Returns:
        A list ordered as: ``[relik_question, retriever_intervention_gpt_db,
        retriever_outcome_gpt_db, retriever_intervention_gpt_taxonomy,
        retriever_outcome_gpt_taxonomy]``. Callers index into this list, so
        the order must not change.
    """
    retriever_question = GoldenRetriever(
        question_encoder="/home/user/app/models/retriever/question_encoder",
        document_index="/home/user/app/models/retriever/document_index/questions",
    )

    retriever_intervention_gpt_taxonomy = GoldenRetriever(
        question_encoder="models/retriever/intervention/gpt+llama/taxonomy/question_encoder",
        document_index="models/retriever/intervention/gpt+llama/taxonomy/document_index",
    )

    retriever_intervention_gpt_db = GoldenRetriever(
        question_encoder="models/retriever/intervention/gpt+llama/db/question_encoder",
        document_index="models/retriever/intervention/gpt+llama/db/document_index",
    )

    retriever_outcome_gpt_taxonomy = GoldenRetriever(
        question_encoder="models/retriever/outcome/gpt+llama/taxonomy/question_encoder",
        document_index="models/retriever/outcome/gpt+llama/taxonomy/document_index",
    )

    retriever_outcome_gpt_db = GoldenRetriever(
        question_encoder="models/retriever/outcome/gpt+llama/db/question_encoder",
        document_index="models/retriever/outcome/gpt+llama/db/document_index",
    )

    reader = RelikReaderForSpanExtraction(
        "/home/user/app/models/small-extended-large-batch",
        dataset_kwargs={"use_nme": True},
    )

    relik_question = Relik(
        reader=reader,
        retriever=retriever_question,
        window_size="none",
        top_k=100,
        task="span",
        device="cpu",
        document_index_device="cpu",
    )

    return [
        relik_question,
        retriever_intervention_gpt_db,
        retriever_outcome_gpt_db,
        retriever_intervention_gpt_taxonomy,
        retriever_outcome_gpt_taxonomy,
    ]


def set_intro(css):
    """Render the page title, logo and subtitle.

    Args:
        css: Stylesheet text; accepted for symmetry with ``set_sidebar`` but
            not used by this function.
    """
    # intro
    st.markdown("# ImpactAI")
    st.image(
        "http://35.237.102.64/public/logo.png",
    )
    st.markdown(
        "### 3ie taxonomy level 4 Intervention/Outcome candidate retriever with Entity Linking"
    )
    # st.markdown(
    #     "This is a front-end for the paper [Universal Semantic Annotator: the First Unified API "
    #     "for WSD, SRL and Semantic Parsing](https://www.researchgate.net/publication/360671045_Universal_Semantic_Annotator_the_First_Unified_API_for_WSD_SRL_and_Semantic_Parsing), which will be presented at LREC 2022 by "
    #     "[Riccardo Orlando](https://riccorl.github.io), [Simone Conia](https://c-simone.github.io/), "
    #     "[Stefano Faralli](https://corsidilaurea.uniroma1.it/it/users/stefanofaralliuniroma1it), and [Roberto Navigli](https://www.diag.uniroma1.it/navigli/)."
# )  (end of the commented-out ``st.markdown(`` call that precedes this line)

from datetime import datetime
from pathlib import Path
from uuid import uuid4

from huggingface_hub import CommitScheduler, HfApi

# Access token from environment variable
# api = HfApi()
# api.set_access_token(os.getenv("HF_TOKEN"))

JSON_DATASET_DIR = Path("json_demo_selected_io")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# A fresh file name is generated every time this module is executed.
JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{uuid4()}.json"

# NOTE(review): Streamlit re-executes the script on every interaction, so a new
# CommitScheduler (and a new JSON_DATASET_PATH) is presumably created on each
# rerun - consider caching both with st.cache_resource; confirm before changing.
scheduler = CommitScheduler(
    repo_id="demo-retriever",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    token=os.getenv("HF_TOKEN"),
)

# Labels of the retriever modes that query a taxonomy index and then keep only
# the hits that also occur in ``db_set``.
_TOPK_DB_INTERVENTION = "Top-k DB in Taxonomy Intervention"
_TOPK_DB_OUTCOME = "Top-k DB in Taxonomy Outcome"
_TOPK_DB_OPTIONS = (_TOPK_DB_INTERVENTION, _TOPK_DB_OUTCOME)

# Retriever mode -> position of the model in the list returned by load_model().
# The key order is also the order shown in the select box.
_RETRIEVER_MODEL_INDEX = {
    "DB Intervention": 1,
    "DB Outcome": 2,
    "Taxonomy Intervention": 3,
    "Taxonomy Outcome": 4,
    _TOPK_DB_INTERVENTION: 3,
    _TOPK_DB_OUTCOME: 4,
}

# NOTE(review): the markup of the HTML snippets below could not be recovered
# from the reviewed copy of this file (only their visible text was present);
# plain heading/list tags are used - confirm against the intended styling.
_CANDIDATES_HEADING_HTML = "<h3>Possible Candidates:</h3>"
_NO_CANDIDATES_HTML = "<h3>No Candidates Found</h3>"


def write_candidates_to_file(text, candidates, selected_candidates):
    """Append one feedback record to the local JSON-lines dataset file.

    The record is logged and then written while holding ``scheduler.lock``.

    Args:
        text: The query text the candidates were retrieved for.
        candidates: Iterable of all candidates shown to the user.
        selected_candidates: Iterable of the candidates the user ticked.

    Returns:
        ``True`` once the record has been written, so callers can report
        success (previously ``None`` was returned and the caller's success
        message was unreachable).
    """
    logger.info(
        "Text: %s\tCandidates: %s\tSelected Candidates: %s\n",
        text,
        candidates,
        selected_candidates,
    )
    record = {
        "text": text,
        "Candidates": list(candidates),
        "Selected Candidates": list(selected_candidates),
        "datetime": datetime.now().isoformat(),
    }
    with scheduler.lock:
        with JSON_DATASET_PATH.open("a", encoding="utf-8") as f:
            json.dump(record, f)
            f.write("\n")
    return True


def _render_entity_linking(relik_model, text):
    """Run entity linking on ``text`` and render spans plus top-10 candidates."""
    response = relik_model(text)
    dict_of_ents, options = get_el_annotations(response=response)
    dict_of_ents_candidates, _ = get_retriever_annotations(response=response)

    st.markdown("#### Entity Linking")
    display = displacy.render(dict_of_ents, manual=True, style="ent", options=options)
    display = display.replace("\n", " ")
    # heuristic, prevents split of annotation decorations
    display = display.replace(
        "border-radius: 0.35em;",
        "border-radius: 0.35em; white-space: nowrap;",
    )
    with st.container():
        st.write(display, unsafe_allow_html=True)

    candidate_items = "".join(
        f"<li>Intervention: {candidate}</li>"
        if io_map[candidate] == "intervention"
        else f"<li>Outcome: {candidate}</li>"
        for candidate in dict_of_ents_candidates["ents"][0:10]
    )
    st.markdown(
        f"{_CANDIDATES_HEADING_HTML}<ul>{candidate_items}</ul>",
        unsafe_allow_html=True,
    )


def _retrieve_candidates(retriever, text, selection):
    """Return the candidate texts retrieved for ``text`` in the given mode."""
    if selection in _TOPK_DB_OPTIONS:
        # Query more hits, keep those that also exist in the DB indexes and
        # show at most ten of them.
        response = retriever.retrieve(text, k=50, batch_size=400, progress_bar=False)
        in_db = [
            pred.document.text
            for pred in response[0]
            if pred.document.text in db_set
        ]
        return in_db[:10]
    response = retriever.retrieve(text, k=20, batch_size=400, progress_bar=False)
    return [pred.document.text for pred in response[0]]


def _render_candidate_picker(text):
    """Show one checkbox per stored candidate and the save button."""
    dict_of_ents_candidates, _ = get_retriever_annotations_candidates(
        text, st.session_state.candidates
    )
    st.markdown(_CANDIDATES_HEADING_HTML, unsafe_allow_html=True)
    selected = st.session_state.selected_candidates
    for candidate in dict_of_ents_candidates["ents"]:
        ticked = st.checkbox(candidate, key=candidate, value=candidate in selected)
        # Keep the stored selection in sync with the checkbox state.
        if ticked and candidate not in selected:
            selected.append(candidate)
        elif not ticked and candidate in selected:
            selected.remove(candidate)
    if st.button("Save Selected Candidates"):
        if write_candidates_to_file(text, dict_of_ents_candidates["ents"], selected):
            st.success("Selected candidates have been saved to file.")


def run_client():
    """Render the ImpactAI page and handle one Streamlit script run."""
    with open(Path(__file__).parent / "style.css") as f:
        css = f.read()

    st.set_page_config(
        page_title="ImpactAI",
        page_icon="🦮",
        layout="wide",
    )
    set_sidebar(css)
    set_intro(css)

    # Radio button selection
    analysis_type = st.radio(
        "Choose analysis type:",
        options=["Retriever", "Entity Linking"],
        index=0,  # Default to 'Retriever'
    )
    entity_linking = analysis_type == "Entity Linking"

    selection_list = None
    if analysis_type == "Retriever":
        # Selection list using selectbox
        selection_list = st.selectbox(
            "Select an option:",
            options=list(_RETRIEVER_MODEL_INDEX),
        )

    # text input
    text = st.text_area(
        "Enter Text Below:",
        value="",
        height=200,
        max_chars=1500,
    )

    with stylable_container(
        key="annotate_button",
        css_styles="""
            button {
                background-color: #a8ebff;
                color: black;
                border-radius: 25px;
            }
            """,
    ):
        submit = st.button("Annotate")

    if "relik_model" not in st.session_state:
        st.session_state["relik_model"] = load_model()
    models = st.session_state["relik_model"]

    if "candidates" not in st.session_state:
        st.session_state["candidates"] = []
    if "selected_candidates" not in st.session_state:
        st.session_state["selected_candidates"] = []

    # ReLik API call
    if submit:
        text = text.strip()
        if text:
            st.markdown("####")
            with st.spinner(text="In progress"):
                if entity_linking:
                    _render_entity_linking(models[0], text)
                else:
                    retriever = models[_RETRIEVER_MODEL_INDEX[selection_list]]
                    candidates_text = _retrieve_candidates(
                        retriever, text, selection_list
                    )
                    st.session_state.candidates = candidates_text
                    # Start every retrieval with an empty selection so that
                    # ticks made for an earlier query are not saved with the
                    # new candidate list.
                    st.session_state.selected_candidates = []
                    if not candidates_text:
                        st.markdown(_NO_CANDIDATES_HTML, unsafe_allow_html=True)
        else:
            st.error("Please enter some text.")

    # Ensure the candidates list is displayed even after interactions
    if st.session_state.candidates and not entity_linking:
        _render_candidate_picker(text)


if __name__ == "__main__":
    run_client()