Spaces:
Runtime error
Runtime error
File size: 3,731 Bytes
568499b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import os
from typing import Dict, List, Union

import datasets
import pandas as pd
import streamlit as st
from findkit import retrieval_pipeline

import config
from search_utils import (
    get_repos_with_descriptions,
    search_f,
    merge_text_list_cols,
    setup_retrieval_pipeline,
)
class RetrievalApp:
    """Streamlit front-end for searching repositories with a retrieval pipeline.

    Constructing the object loads the repository table and reads the model
    choice from the sidebar; ``main`` builds the pipeline and renders the page.
    """

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Load the repository table and choose the query/document encoders.

        Args:
            data_path: dataset identifier passed to ``datasets.load_dataset``;
                only its ``"train"`` split is used.
        """
        print("loading data")
        raw_retrieval_df = (
            datasets.load_dataset(data_path)["train"]
            .to_pandas()
            # keep a single row per repository
            .drop_duplicates(subset=["repo"])
            .reset_index(drop=True)
        )
        self.retrieval_df = merge_text_list_cols(
            raw_retrieval_df, config.text_list_cols
        )
        model_name = st.sidebar.selectbox("model", config.model_names)
        # The query and document encoders are two separate hub models that
        # share the selected name as a suffix.
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name
        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)

    @staticmethod
    def _count_documents_per_task(retrieval_df: pd.DataFrame) -> pd.DataFrame:
        """Return one row per task with the number of rows that list it.

        The ``"tasks"`` column is exploded first, so cells may hold lists of
        task names; missing values are dropped by ``value_counts``. Rows are
        ordered by descending count. Columns are ``"task"`` and
        ``"documents per task"`` regardless of pandas version.
        """
        return (
            retrieval_df["tasks"]
            .explode()
            .value_counts()
            .rename_axis("task")
            .reset_index(name="documents per task")
        )

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: "retrieval_pipeline.RetrievalPipeline",
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: "pd.core.groupby.DataFrameGroupBy",
        doc_col: Union[str, List[str]],
    ):
        """Render search results for ``query``, plus gold-standard hits if known.

        Args:
            retrieval_pipe: pipeline forwarded to ``search_f``; its ``X_df`` is
                used to describe the gold-standard repositories.
            query: the user's search text.
            k: number of results requested from ``search_f``.
            all_queries: queries that have gold-standard results.
            description_length: forwarded to ``search_f`` (labelled in the UI
                as the number of description words).
            repos_by_query: repositories grouped by task; ``get_group(query)``
                must succeed for every query in ``all_queries``.
            doc_col: column(s) forwarded to ``search_f``; ``app`` passes the
                list chosen in the "additional cols" widget.
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            results_df = search_f(
                retrieval_pipe, query, k, description_length, doc_col
            )
            # escape=False keeps HTML inside result cells renderable
            # (presumably links produced by search_f — confirm there).
            st.write(
                results_df.to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    @staticmethod
    def app(retrieval_pipeline, retrieval_df, doc_col):
        """Render the sidebar controls and the query box, then show results.

        Args:
            retrieval_pipeline: pipeline used to answer queries. Inside this
                method the name shadows the imported findkit module; it is
                kept for backward compatibility with keyword callers.
            retrieval_df: repository table with a ``"tasks"`` column.
            doc_col: column preselected in the "additional cols" widget.
        """
        # Bounds keep nonsensical zero or negative sizes away from search_f.
        retrieved_results = st.sidebar.number_input(
            "number of results", value=10, min_value=1
        )
        description_length = st.sidebar.number_input(
            "number of used description words", value=10, min_value=0
        )
        tasks_deduped = RetrievalApp._count_documents_per_task(retrieval_df)
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped)
        additional_shown_cols = st.sidebar.multiselect(
            label="additional cols", options=config.text_cols, default=doc_col
        )
        repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        if not query.strip():
            # Nothing meaningful to retrieve for a blank query.
            st.info("enter a query to see results")
            return
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            additional_shown_cols,
        )

    def main(self):
        """Build the retrieval pipeline over the ``"dependencies"`` column and run the app."""
        print("setting up retrieval_pipe")
        doc_col = "dependencies"
        # Named retrieval_pipe so the imported findkit module is not shadowed.
        retrieval_pipe = setup_retrieval_pipeline(
            self.query_encoder_name,
            self.document_encoder_name,
            self.retrieval_df[doc_col],
            self.retrieval_df,
        )
        RetrievalApp.app(retrieval_pipe, self.retrieval_df, doc_col)
|