Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import sys | |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) | |
import pandas as pd | |
from src.rag.pipeline import RAGPipeline | |
import streamlit as st | |
from src.utils.data import ( | |
build_filter, | |
get_filter_values, | |
get_meta, | |
load_json, | |
load_css, | |
) | |
from src.utils.writer import typewriter | |
# Configure the page layout before any other Streamlit element is rendered.
st.set_page_config(layout="wide")

# Name of the sentence-transformers model handed to the RAG pipeline.
EMBEDDING_MODEL = "sentence-transformers/distiluse-base-multilingual-cased-v1"

# Prompt template path, relative to the directory the app is launched from.
_PROMPT_TEMPLATE_PARTS = ("src", "rag", "prompt_template.yaml")
PROMPT_TEMPLATE = os.path.join(*_PROMPT_TEMPLATE_PARTS)
def load_css_style(path: str) -> None:
    """Apply the CSS file at *path* to the page using the project helper ``load_css``.

    Args:
        path: Location of the stylesheet, relative to the working directory.
    """
    stylesheet_path = path
    load_css(stylesheet_path)
@st.cache_data
def get_meta_data() -> pd.DataFrame:
    """Load the document metadata table from ``database/meta_data.csv``.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the CSV would be re-read and re-parsed on each rerun.
    ``st.cache_data`` parses it once and hands every caller its own copy,
    so mutating the returned frame cannot corrupt the cached value.
    NOTE(review): edits to the CSV on disk are only picked up after the cache is
    cleared or the app restarts.

    Returns:
        The metadata frame, with ``retriever_id`` read as ``str`` (presumably so
        identifiers are not coerced to numbers).
    """
    return pd.read_csv(
        os.path.join("database", "meta_data.csv"),
        dtype={"retriever_id": str},
    )
def get_authors_taxonomy() -> dict[str, list[str]]:
    """Return the author filter taxonomy read from ``data/authors_filter.json``.

    Keys are the parent entities offered in the first author multiselect; each
    value lists the child entities that become selectable once the parent is picked.
    """
    taxonomy_path = os.path.join("data", "authors_filter.json")
    return load_json(taxonomy_path)
def get_draft_cat_taxonomy() -> dict[str, list[str]]:
    """Return the draft-category filter taxonomy from ``data/draftcat_taxonomy_filter.json``.

    Keys are the parent categories offered in the first multiselect; each value
    lists the child categories that become selectable once the parent is picked.
    """
    taxonomy_path = os.path.join("data", "draftcat_taxonomy_filter.json")
    return load_json(taxonomy_path)
def get_example_prompts() -> list[str]:
    """Return the example questions listed in ``data/example_prompts.json``.

    Each entry of that file is read as a mapping and its ``question`` value is
    collected, preserving file order.
    """
    entries = load_json(os.path.join("data", "example_prompts.json"))
    return [entry["question"] for entry in entries]
@st.cache_resource
def load_pipeline() -> RAGPipeline:
    """Build the RAG pipeline once and reuse it across Streamlit reruns.

    Streamlit re-executes this script on every widget interaction. Without
    caching, the pipeline would be reconstructed each time, and it is handed
    ``EMBEDDING_MODEL`` (presumably loading that model in its constructor).
    ``st.cache_resource`` keeps a single shared instance; it is not copied, which
    suits a heavyweight, non-serialisable object.
    NOTE(review): the instance is shared by all sessions — confirm RAGPipeline
    tolerates concurrent calls.

    Returns:
        The configured ``RAGPipeline``.
    """
    return RAGPipeline(
        embedding_model=EMBEDDING_MODEL,
        prompt_template=PROMPT_TEMPLATE,
    )
def load_app_init() -> None:
    """Render the page header: title, beta-version remark and app description.

    The remark and description are raw HTML paragraphs styled through the
    ``remark`` and ``description`` CSS classes, hence ``unsafe_allow_html=True``.
    The HTML is static text, not user input. Each paragraph is closed explicitly
    and the link attribute is quoted so the markup is well-formed.
    """
    # Define the title of the app
    st.title("INC Plastic Treaty - Q&A")
    # Beta-version warning (warning emoji + "remark" style).
    st.markdown(
        """
<p class="remark"> ⚠️ Remark:
The app is a beta version that serves as a basis for further development. We are aware that the performance is not yet sufficient and that the data basis is not yet complete. We are grateful for any feedback that contributes to the further development and improvement of the app! </p>
""",
        unsafe_allow_html=True,
    )
    # Explanation of the app. The link opens in a new tab, so rel="noopener
    # noreferrer" keeps the opened page from accessing this window.
    st.markdown(
        """
<p class="description">
The app aims to facilitate the search for information and documents related to the UN Plastics Treaty Negotiations. The database includes all relevant documents that are available <a href="https://www.unep.org/inc-plastic-pollution" target="_blank" rel="noopener noreferrer">here</a>. Users can query the data through a chatbot. Please note that, due to technical constraints, only a maximum of 10 documents can be used to generate the answer. A comprehensive response can therefore not be guaranteed. However, all relevant documents can be accessed via a link using the filter functions.
Filter functions are available to narrow down the data by country/author, zero draft categories and negotiation rounds. Pre-selecting relevant data enhances the accuracy of generated answers. Additionally, all documents selected via the filter function can be accessed via a link. </p>
""",
        unsafe_allow_html=True,
    )
# Apply the custom stylesheet; presumably it defines the "remark"/"description"
# classes used by the HTML paragraphs on this page.
load_css_style("style/style.css")

# Reference data: document metadata, both filter taxonomies, example questions.
metadata = get_meta_data()
authors_taxonomy = get_authors_taxonomy()
draft_cat_taxonomy = get_draft_cat_taxonomy()
example_prompts = get_example_prompts()

# Pipeline used both to list filtered documents and to answer questions.
pipeline = load_pipeline()

# Title, beta remark and description.
load_app_init()


def _children_of(taxonomy, parents):
    """Flatten the child lists of every selected parent, in selection order."""
    return [child for parent in parents for child in taxonomy[parent]]


# Filter panel: one full-width container holding three equal-width columns.
filter_col = st.columns(1)
with filter_col[0]:
    st.markdown("## Select Filters")
    author_col, round_col, draft_cat_col = st.columns([1, 1, 1])

    # Two-level author filter: pick parent entities, then individual entities.
    with author_col:
        st.markdown("### Authors")
        selected_author_parent = st.multiselect(
            "Entity Parent", list(authors_taxonomy)
        )
        available_child_items = _children_of(authors_taxonomy, selected_author_parent)
        selected_authors = st.multiselect("Entity", available_child_items)

    # Negotiation rounds come straight from the metadata table.
    with round_col:
        st.markdown("### Round")
        negotiation_rounds = get_filter_values(metadata, "round")
        selected_rounds = st.multiselect("Round", negotiation_rounds)

    # Two-level draft-category filter, same pattern as the author filter.
    with draft_cat_col:
        st.markdown("### Draft Categories")
        selected_draft_cats_parent = st.multiselect(
            "Draft Categories Parent", list(draft_cat_taxonomy)
        )
        available_draft_cats_child_items = _children_of(
            draft_cat_taxonomy, selected_draft_cats_parent
        )
        selected_draft_cats = st.multiselect(
            "Draft Categories", available_draft_cats_child_items
        )
prompt_col, output_col = st.columns([1, 1.5])


def _build_current_filter():
    """Return ``build_filter(...)`` for the authors, draft categories and rounds selected above."""
    return build_filter(
        meta_data=metadata,
        authors_filter=selected_authors,
        draft_cats_filter=selected_draft_cats,
        round_filter=selected_rounds,
    )


def _use_example_prompt(example: str) -> None:
    """``on_click`` callback for the example buttons: store *example* as the question.

    Streamlit runs button callbacks before re-executing the script. The question
    box rendered on that rerun is therefore already pre-filled with the example.
    Assigning inside ``if st.button(...)`` would run after the box was drawn and
    only show up one interaction later.
    """
    st.session_state["prompt"] = example


# Per-run flags: each is True only in the rerun triggered by its own button. That
# is also the only rerun in which ``documents`` (and ``result``) are defined.
trigger_filter = False
trigger_ask = False

# Prompt column: filter button, question box, ask button and example prompts.
with prompt_col:
    st.markdown("## Filter documents")
    st.markdown(
        """
* The filter function allows you to see all documents that match the selected filters.
* Additionally, all documents selected via the filter function can be accessed via a link.
* Alternatively, you can ask a question to the model. The model will then provide you with an answer based on the filtered documents.
"""
    )
    if st.button("Filter documents"):
        filter_selection_transformed = _build_current_filter()
        documents = pipeline.document_store.get_all_documents(
            filters=filter_selection_transformed
        )
        trigger_filter = True

    st.markdown("## Ask a question")
    # A preset question exists after an example button was clicked. Discard it if
    # it is no longer one of the known example prompts.
    preset_prompt = st.session_state.get("prompt")
    if preset_prompt is not None and preset_prompt not in example_prompts:
        del st.session_state["prompt"]
        preset_prompt = None
    if preset_prompt is None:
        prompt = st.text_area("Enter a question")
    else:
        prompt = st.text_area("Enter a question", value=preset_prompt)

    if st.button("Ask"):
        # Build the filter here: on the rerun triggered by "Ask", the "Filter
        # documents" branch above did not run, so its filter does not exist.
        filter_selection_transformed = _build_current_filter()
        if not filter_selection_transformed:
            # Shown outside the collapsed status box below so the user sees it.
            st.warning(
                "No filters selected. We highly recommend to use filters otherwise the answer might not be accurate. In addition you might experience performance issues since the model has to analyze all available documents."
            )
        if not (prompt or "").strip():
            # Avoid running retrieval + generation on an empty question.
            st.warning("Please enter a question or choose one of the examples below.")
        else:
            with st.status("Filtering documents...", expanded=False) as status:
                documents = pipeline.document_store.get_all_documents(
                    filters=filter_selection_transformed
                )
                status.update(
                    label="Filtering documents completed!", state="complete", expanded=False
                )
            with st.status("Answering question...", expanded=True) as status:
                result = pipeline(prompt=prompt, filters=filter_selection_transformed)
                status.update(
                    label="Answering question completed!", state="complete", expanded=False
                )
            trigger_ask = True

    st.markdown("### Examples")
    st.markdown(
        """
* These are example prompts that can be used to ask questions to the model
* Click on a prompt to use it as a question. You can also type your own question in the text area above.
* For questions like "How do country a, b and c [...]" please make sure to select the countries in the filter section. Otherwise the answer will not be accurate. In general we highly recommend to use the filter functions to narrow down the data.
"""
    )
    for example in example_prompts:
        st.button(example, on_click=_use_example_prompt, args=(example,))
def _group_passages_by_document(meta_data):
    """Merge retrieved passages that belong to the same document.

    Args:
        meta_data: Iterable of dicts with ``retriever_id``, ``href`` and
            ``content`` keys (as produced by ``get_meta``).

    Returns:
        One dict per distinct ``retriever_id``, in first-seen order. Each holds
        the ``href`` of that document's first passage and a ``content`` list with
        all of its passages in retrieval order. A dict keyed by id keeps this
        O(n) instead of rescanning the output list for every duplicate.
    """
    grouped = {}
    for data in meta_data:
        retriever_id = data["retriever_id"]
        content = data["content"]
        if retriever_id not in grouped:
            grouped[retriever_id] = {
                "retriever_id": retriever_id,
                "href": data["href"],
                "content": [],
            }
        grouped[retriever_id]["content"].append(content)
    return list(grouped.values())


# Answer panel: shown only on the rerun in which "Ask" produced a result.
if trigger_ask:
    with output_col:
        meta_data_cleaned = _group_passages_by_document(get_meta(result=result))
        answers = result["answers"]
        if answers:
            references = ["\n"]
            for data in meta_data_cleaned:
                references.append(f"-[{data['retriever_id']}]: {data['href']} \n")
            st.write("#### 📌 Answer")
            typewriter(
                text=answers[0].answer,
                references=references,
                speed=100,
            )
        else:
            # Guard against indexing an empty answer list.
            st.warning(
                "No answer could be generated for this question. Try rephrasing it or adjusting the filters."
            )
        with st.expander("Show more information to the documents"):
            for data in meta_data_cleaned:
                markdown_text = f"- Document: {data['retriever_id']}\n"
                markdown_text += " - Text passages\n"
                for content in data["content"]:
                    # Strip list-literal punctuation and collapse whitespace.
                    content = content.replace("[", "").replace("]", "").replace("'", "")
                    content = " ".join(content.split())
                    markdown_text += f" - {content}\n"
                st.write(markdown_text)
        col4 = st.columns(1)
        with col4[0]:
            # One "- authors: link" line per distinct document, sorted.
            unique_sources = set()
            for document in documents:
                authors = document.meta["author"]
                authors = authors.replace("'", "").replace("[", "").replace("]", "")
                unique_sources.add(f"- {authors}: {document.meta['href']}")
            references = sorted(unique_sources)
            st.markdown("### Overview of all filtered documents")
            st.markdown(
                f"<p class='description'> The answer above results from the most similar text passages (top 7) from the documents that you can find under 'References' in the answer block. Below you will find an overview of all documents that match the filters you have selected. Please note that the above answer is based specifically on the highlighted references above and does not include the findings from all the filtered documents shown below. \n For your current filtering, {len(references)} documents were found. </p>",
                unsafe_allow_html=True,
            )
            for reference in references:
                st.write(reference)
# Deletion table for the list-literal characters ' [ ] found in author strings.
_AUTHOR_PUNCTUATION = str.maketrans("", "", "'[]")

# Filter output: tabular overview of every document matching the selected filters.
if trigger_filter:
    with output_col:
        records = []
        for document in documents:
            meta = document.meta
            author_text = meta["author"].translate(_AUTHOR_PUNCTUATION)
            link = meta["href"]
            negotiation_round = meta["round"]
            categories = meta["draft_labs"]
            records.append(
                {
                    "author": author_text,
                    "href": link,
                    "draft_labs": categories,
                    "round": negotiation_round,
                }
            )
        # Several document-store entries can share the same metadata; show each once.
        references = pd.DataFrame(records).drop_duplicates()
        st.markdown("### Overview of all filtered documents")
        # TODO: the intent is wider "author"/"draft_labs" columns and narrower
        # "href"/"round" columns; no column width is configured yet.
        st.dataframe(
            references,
            hide_index=True,
            column_config={
                "author": st.column_config.ListColumn("Authors"),
                "href": st.column_config.LinkColumn("Link to Document"),
                "draft_labs": st.column_config.ListColumn("Draft Categories"),
                "round": st.column_config.NumberColumn("Round"),
            },
        )