# paperswithcode_nbow / app_implementation.py
# Commit 655f181 by lambdaofgod: "cuda bug, file renames"
from typing import Dict, List
import torch
import pandas as pd
import streamlit as st
from findkit import retrieval_pipeline
import config
from search_utils import (
RetrievalPipelineWrapper,
get_doc_cols,
get_repos_with_descriptions,
get_retrieval_df,
merge_cols,
)
class RetrievalApp:
    """Streamlit app that retrieves Papers with Code repositories for a text query.

    The sidebar selects the device, the query/document encoder pair and the
    text features used. The main area takes a query and shows the retrieved
    repositories. When the query is one of the known tasks, it also offers the
    gold-standard repositories for that task.
    """

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Render the sidebar configuration widgets and load the retrieval data.

        Args:
            data_path: dataset identifier forwarded to ``search_utils.get_retrieval_df``.
        """
        self.data_path = data_path
        self.device = st.sidebar.selectbox("device", self.get_device_options())
        print("loading data")
        # st.cache(allow_output_mutation=True) hands back the same cached object
        # on every hit, so take a private copy before using it.
        self.retrieval_df = self.get_retrieval_df().copy()
        model_name = st.sidebar.selectbox("model", config.model_names)
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name
        doc_cols = get_doc_cols(model_name)
        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
        self.additional_shown_cols = st.sidebar.multiselect(
            label="used text features", options=config.text_cols, default=doc_cols
        )

    def is_cuda_available(self):
        """Return True if CUDA can actually be initialized in this process.

        CUDA initialization is attempted directly through a private torch API.
        Any exception it raises (e.g. a RuntimeError, or an AttributeError
        should the private function be missing) is treated as "no CUDA".
        """
        try:
            torch._C._cuda_init()
        except Exception:
            # Narrower than a bare ``except:``, which would also swallow
            # KeyboardInterrupt / SystemExit.
            return False
        return True

    def get_device_options(self):
        """Return the device choices for the sidebar, listing "cuda" first when usable."""
        return ["cuda", "cpu"] if self.is_cuda_available() else ["cpu"]

    @st.cache(allow_output_mutation=True)
    def get_retrieval_df(self):
        """Load the retrieval DataFrame for ``self.data_path``, cached by Streamlit.

        Inside this method, the name ``get_retrieval_df`` resolves to the
        module-level function imported from ``search_utils``, not to this method.
        """
        return get_retrieval_df(self.data_path, config.text_list_cols)

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: RetrievalPipelineWrapper,
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: "pd.core.groupby.DataFrameGroupBy",
        additional_shown_cols: List[str],
    ):
        """Render gold-standard results (if available) and retrieved results for ``query``.

        Args:
            retrieval_pipe: pipeline whose ``search`` returns a DataFrame of results.
            query: the user's query text.
            k: number of results to retrieve.
            all_queries: known task names. A query in this list has gold-standard results.
            description_length: forwarded to ``retrieval_pipe.search``.
            repos_by_query: retrieval data grouped by task; ``get_group(query)``
                yields that task's repositories.
            additional_shown_cols: text columns forwarded to ``retrieval_pipe.search``.
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            # NOTE(review): escape=False + unsafe_allow_html renders cell contents
            # as raw HTML (presumably for links); this is only safe while the
            # dataset contents are trusted.
            st.write(
                retrieval_pipe.search(
                    query, k, description_length, additional_shown_cols
                ).to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    def run_app(self, retrieval_pipeline):
        """Render the query widgets and display results for the entered query.

        Args:
            retrieval_pipeline: pipeline passed to ``show_retrieval_results``.
        """
        # min_value=1: zero or negative result counts are not meaningful for search.
        retrieved_results = st.sidebar.number_input(
            "number of results", value=10, min_value=1
        )
        description_length = st.sidebar.number_input(
            "number of used description words", value=10
        )
        # One row per distinct task, with the number of documents tagged with it.
        tasks_deduped = (
            self.retrieval_df["tasks"].explode().value_counts().reset_index()
        )
        tasks_deduped.columns = ["task", "documents per task"]
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped.explode("task"))
        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            self.additional_shown_cols,
        )

    @st.cache(allow_output_mutation=True)
    def get_retrieval_pipeline(self, displayed_retrieval_df):
        """Build the retrieval pipeline for the selected encoders and device (cached).

        Passes ``displayed_retrieval_df["document"]`` and the full frame to
        ``RetrievalPipelineWrapper.setup_from_encoder_names``.
        """
        return RetrievalPipelineWrapper.setup_from_encoder_names(
            self.query_encoder_name,
            self.document_encoder_name,
            displayed_retrieval_df["document"],
            displayed_retrieval_df,
            device=self.device,
        )

    def main(self):
        """Merge the selected text columns, build the pipeline and run the app."""
        print("setting up retrieval_pipe")
        # Presumably merge_cols builds the "document" column read by
        # get_retrieval_pipeline; confirm against search_utils.
        displayed_retrieval_df = merge_cols(
            self.retrieval_df.copy(), self.additional_shown_cols
        )
        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
        self.run_app(retrieval_pipeline)