import os
import sys

sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
)

import pandas as pd
import streamlit as st

from src.rag.pipeline import RAGPipeline
from src.utils.data import (
    build_filter,
    get_filter_values,
    get_meta,
    load_json,
    load_css,
)
from src.utils.writer import typewriter

st.set_page_config(layout="wide")

EMBEDDING_MODEL = "sentence-transformers/distiluse-base-multilingual-cased-v1"
PROMPT_TEMPLATE = os.path.join("src", "rag", "prompt_template.yaml")


@st.cache_data
def load_css_style(path: str) -> None:
    load_css(path)


@st.cache_data
def get_meta_data() -> pd.DataFrame:
    return pd.read_csv(
        os.path.join("database", "meta_data.csv"), dtype={"retriever_id": str}
    )


@st.cache_data
def get_authors_taxonomy() -> dict[str, list[str]]:
    return load_json(os.path.join("data", "authors_filter.json"))


@st.cache_data
def get_draft_cat_taxonomy() -> dict[str, list[str]]:
    return load_json(os.path.join("data", "draftcat_taxonomy_filter.json"))


@st.cache_data
def get_example_prompts() -> list[str]:
    return [
        example["question"]
        for example in load_json(os.path.join("data", "example_prompts.json"))
    ]


@st.cache_resource
def load_pipeline() -> RAGPipeline:
    return RAGPipeline(
        embedding_model=EMBEDDING_MODEL,
        prompt_template=PROMPT_TEMPLATE,
    )
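
# The loaders above use st.cache_data for serializable values (DataFrames,
# lists, dicts), so the metadata CSV and the taxonomy/prompt JSON files are
# not re-read on every rerun, while load_pipeline uses st.cache_resource so
# that a single RAGPipeline instance (embedding model, document store) is
# created once and shared across reruns rather than re-serialized.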


@st.cache_data
def load_app_init() -> None:
    # Define the title of the app
    st.title("INC Plastic Treaty - Q&A")

    # add warning emoji and style
    st.markdown(
        """
⚠️ Remark: The app is a beta version that serves as a basis for further
development. We are aware that the performance is not yet sufficient and that
the underlying data is not yet complete. We are grateful for any feedback that
contributes to the further development and improvement of the app!
        """,
        unsafe_allow_html=True,
    )

    # add explanation to the app
    st.markdown(
        """
The app aims to facilitate the search for information and documents related to
the UN Plastics Treaty negotiations. The database includes all relevant
documents that are available here. Users can query the data through a chatbot.
Please note that, due to technical constraints, only a maximum of 10 documents
can be used to generate an answer, so a comprehensive response cannot be
guaranteed. Filter functions are available to narrow down the data by
country/author, zero draft categories and negotiation rounds; pre-selecting
relevant data improves the accuracy of the generated answers. In addition, all
documents selected via the filter functions can be accessed directly via a
link.
        """,
        unsafe_allow_html=True,
    )


load_css_style("style/style.css")

# Load the data
metadata = get_meta_data()
authors_taxonomy = get_authors_taxonomy()
draft_cat_taxonomy = get_draft_cat_taxonomy()
example_prompts = get_example_prompts()

# Load pipeline
pipeline = load_pipeline()

# Load app init
load_app_init()

filter_col = st.columns(1)

# Filter column
with filter_col[0]:
    st.markdown("## Select Filters")
    author_col, round_col, draft_cat_col = st.columns([1, 1, 1])

    with author_col:
        st.markdown("### Authors")
        selected_author_parent = st.multiselect(
            "Entity Parent", list(authors_taxonomy.keys())
        )
        available_child_items = []
        for category in selected_author_parent:
            available_child_items.extend(authors_taxonomy[category])
        selected_authors = st.multiselect("Entity", available_child_items)

    with round_col:
        st.markdown("### Round")
        negotiation_rounds = get_filter_values(metadata, "round")
        selected_rounds = st.multiselect("Round", negotiation_rounds)

    with draft_cat_col:
        st.markdown("### Draft Categories")
        selected_draft_cats_parent = st.multiselect(
            "Draft Categories Parent", list(draft_cat_taxonomy.keys())
        )
        available_draft_cats_child_items = []
        for category in selected_draft_cats_parent:
            available_draft_cats_child_items.extend(draft_cat_taxonomy[category])
        selected_draft_cats = st.multiselect(
            "Draft Categories", available_draft_cats_child_items
        )

prompt_col, output_col = st.columns([1, 1.5])

# make the buttons text smaller
# GPT column
with prompt_col:
    st.markdown("## Filter documents")
    st.markdown(
        """
* The filter function shows all documents that match the selected filters.
* All documents selected via the filter function can be accessed via a link.
* Alternatively, you can ask the model a question. The model will then answer
  based on the filtered documents.
        """
    )
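
    # build_filter (from src.utils.data) combines the selected authors, draft
    # categories and rounds into the metadata filter that is later passed to
    # pipeline.document_store.get_all_documents; an empty selection is assumed
    # to produce an empty dict ({}), which the "Ask" flow treats as
    # "no filters selected".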
""" ) trigger_filter = st.session_state.setdefault("trigger", False) if st.button("Filter documents"): filter_selection_transformed = build_filter( meta_data=metadata, authors_filter=selected_authors, draft_cats_filter=selected_draft_cats, round_filter=selected_rounds, ) documents = pipeline.document_store.get_all_documents( filters=filter_selection_transformed ) trigger_filter = True st.markdown("## Ask a question") if "prompt" not in st.session_state: prompt = st.text_area("") if ( "prompt" in st.session_state and st.session_state.prompt in example_prompts # noqa: E501 ): # noqa: E501 prompt = st.text_area( "Enter a question", value=st.session_state.prompt ) # noqa: E501 if ( "prompt" in st.session_state and st.session_state.prompt not in example_prompts # noqa: E501 ): # noqa: E501 del st.session_state["prompt"] prompt = st.text_area("Enter a question") trigger_ask = st.session_state.setdefault("trigger", False) if st.button("Ask"): with st.status("Filtering documents...", expanded=False) as status: if filter_selection_transformed == {}: st.warning( "No filters selected. We highly recommend to use filters otherwise the answer might not be accurate. In addition you might experience performance issues since the model has to analyze all available documents." ) filter_selection_transformed = build_filter( meta_data=metadata, authors_filter=selected_authors, draft_cats_filter=selected_draft_cats, round_filter=selected_rounds, ) documents = pipeline.document_store.get_all_documents( filters=filter_selection_transformed ) status.update( label="Filtering documents completed!", state="complete", expanded=False ) with st.status("Answering question...", expanded=True) as status: result = pipeline(prompt=prompt, filters=filter_selection_transformed) trigger_ask = True status.update( label="Answering question completed!", state="complete", expanded=False ) st.markdown("### Examples") st.markdown( """ * These are example prompts that can be used to ask questions to the model * Click on a prompt to use it as a question. You can also type your own question in the text area above. * For questions like "How do country a, b and c [...]" please make sure to select the countries in the filter section. Otherwise the answer will not be accurate. In general we highly recommend to use the filter functions to narrow down the data. 
""" ) for i, prompt in enumerate(example_prompts): # with col[i % 4]: if st.button(prompt): if "key" not in st.session_state: st.session_state["prompt"] = prompt # Define the button if trigger_ask: with output_col: meta_data = get_meta(result=result) answer = result["answers"][0].answer meta_data_cleaned = [] seen_retriever_ids = set() for data in meta_data: retriever_id = data["retriever_id"] content = data["content"] if retriever_id not in seen_retriever_ids: meta_data_cleaned.append( { "retriever_id": retriever_id, "href": data["href"], "content": [content], } ) seen_retriever_ids.add(retriever_id) else: for i, item in enumerate(meta_data_cleaned): if item["retriever_id"] == retriever_id: meta_data_cleaned[i]["content"].append(content) references = ["\n"] for data in meta_data_cleaned: retriever_id = data["retriever_id"] href = data["href"] references.append(f"-[{retriever_id}]: {href} \n") st.write("#### 📌 Answer") typewriter( text=answer, references=references, speed=100, ) with st.expander("Show more information to the documents"): for data in meta_data_cleaned: markdown_text = f"- Document: {data['retriever_id']}\n" markdown_text += " - Text passages\n" for content in data["content"]: content = content.replace("[", "").replace("]", "").replace("'", "") content = " ".join(content.split()) markdown_text += f" - {content}\n" st.write(markdown_text) col4 = st.columns(1) with col4[0]: references = [] for document in documents: authors = document.meta["author"] authors = authors.replace("'", "").replace("[", "").replace("]", "") href = document.meta["href"] source = f"- {authors}: {href}" references.append(source) references = list(set(references)) references = sorted(references) st.markdown("### Overview of all filtered documents") st.markdown( f"
            st.markdown(
                f"""The answer above is based on the most similar text passages
(top 7) from the documents listed under 'References' in the answer block. Below
you will find an overview of all documents that match the filters you have
selected. Please note that the answer above draws only on the references
highlighted there and does not include findings from all of the filtered
documents shown below.\n For your current filtering, {len(references)}
documents were found.""",
                unsafe_allow_html=True,
            )
", unsafe_allow_html=True, ) for reference in references: st.write(reference) trigger = 0 if trigger_filter: with output_col: references = [] for document in documents: authors = document.meta["author"] authors = authors.replace("'", "").replace("[", "").replace("]", "") href = document.meta["href"] round_ = document.meta["round"] draft_labs = document.meta["draft_labs"] references.append( { "author": authors, "href": href, "draft_labs": draft_labs, "round": round_, } ) references = pd.DataFrame(references) references = references.drop_duplicates() st.markdown("### Overview of all filtered documents") # show # make columns author and draft_labs bigger and make href width smaller and round width smaller st.dataframe( references, hide_index=True, column_config={ "author": st.column_config.ListColumn("Authors"), "href": st.column_config.LinkColumn("Link to Document"), "draft_labs": st.column_config.ListColumn("Draft Categories"), "round": st.column_config.NumberColumn("Round"), }, )