"""Streamlit front-end for retrieving repositories with a query/document encoder pair."""

import os
from typing import Dict, List, Union

import pandas as pd
import datasets
import streamlit as st

import config
from findkit import retrieval_pipeline
from search_utils import (
    get_repos_with_descriptions,
    search_f,
    merge_text_list_cols,
    setup_retrieval_pipeline,
)


class RetrievalApp:
    """Streamlit app that loads a repository dataset and lets the user query it.

    The sidebar selects which encoder pair (``lambdaofgod/query-<model>`` and
    ``lambdaofgod/document-<model>``) is used; :meth:`main` builds the retrieval
    pipeline and renders the page.
    """

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Load the dataset and render the model-selection sidebar.

        Args:
            data_path: Name passed to ``datasets.load_dataset``; its ``"train"``
                split is used.
        """
        print("loading data")
        raw_retrieval_df = (
            datasets.load_dataset(data_path)["train"]
            .to_pandas()
            # keep a single row per repository
            .drop_duplicates(subset=["repo"])
            .reset_index(drop=True)
        )
        self.retrieval_df = merge_text_list_cols(
            raw_retrieval_df, config.text_list_cols
        )

        model_name = st.sidebar.selectbox("model", config.model_names)
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name

        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: "pd.core.groupby.DataFrameGroupBy",
        doc_col: Union[str, List[str]],
    ):
        """Render gold-standard results (if any) and retrieved results for ``query``.

        Args:
            retrieval_pipe: Pipeline passed to ``search_f``; its ``X_df`` is also
                used to describe the gold-standard repositories.
            query: Query text entered by the user.
            k: Number of results requested from ``search_f``.
            all_queries: Task names that have gold-standard repositories.
            description_length: Forwarded to ``search_f``.
            repos_by_query: Dataset grouped by task (``groupby("tasks")`` in
                :meth:`app`); the group named ``query`` holds its gold-standard
                repositories.
            doc_col: Column name(s) forwarded to ``search_f``; :meth:`app` passes
                the list selected in the "additional cols" multiselect.
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            results_df = search_f(retrieval_pipe, query, k, description_length, doc_col)
            # escape=False together with unsafe_allow_html renders cell contents as
            # raw HTML. NOTE(review): presumably search_f emits links here; confirm
            # that no cell can carry unescaped user- or dataset-supplied markup.
            st.write(
                results_df.to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    @staticmethod
    def app(retrieval_pipeline, retrieval_df, doc_col):
        """Render the sidebar controls and the query box, then show results.

        Args:
            retrieval_pipeline: Pipeline forwarded to ``show_retrieval_results``.
                The name shadows the imported ``findkit.retrieval_pipeline``
                module inside this method; it is kept for keyword compatibility.
            retrieval_df: Dataset whose ``"tasks"`` column is exploded into one
                row per task.
            doc_col: Default selection for the "additional cols" multiselect.
        """
        # Lower bounds keep the result count positive and the word count non-negative.
        retrieved_results = st.sidebar.number_input(
            "number of results", min_value=1, value=10
        )
        description_length = st.sidebar.number_input(
            "number of used description words", min_value=0, value=10
        )

        # Number of repositories per task; the task names serve as test-set queries.
        tasks_deduped = retrieval_df["tasks"].explode().value_counts().reset_index()
        # Rename positionally: the column names produced by
        # value_counts().reset_index() differ between pandas versions.
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped)

        additional_shown_cols = st.sidebar.multiselect(
            label="additional cols", options=config.text_cols, default=doc_col
        )
        repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            additional_shown_cols,
        )

    def main(self):
        """Build the retrieval pipeline over the ``"dependencies"`` column and run the app."""
        print("setting up retrieval_pipe")
        doc_col = "dependencies"
        # Named `pipeline` so it does not shadow the imported `retrieval_pipeline` module.
        pipeline = setup_retrieval_pipeline(
            self.query_encoder_name,
            self.document_encoder_name,
            self.retrieval_df[doc_col],
            self.retrieval_df,
        )
        RetrievalApp.app(pipeline, self.retrieval_df, doc_col)