Commit a284f57 · 1 parent: 568499b

additional cols and optional device

Files changed:
- app_implementation.py (+48, -36)
- config.py (+7, -7)
- search_utils.py (+74, -44)
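In short: dataset loading moves out of RetrievalApp.__init__ into a cached get_retrieval_df helper in search_utils.py, the encoder device becomes a sidebar option, the free functions search_f and setup_retrieval_pipeline are folded into a RetrievalPipelineWrapper dataclass, and the columns shown with each result become a sidebar multiselect whose default is derived from the chosen model's name.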
app_implementation.py CHANGED

@@ -1,50 +1,61 @@
-import os
 from typing import Dict, List

+import torch
 import pandas as pd
-import datasets
 import streamlit as st
-import config
 from findkit import retrieval_pipeline
+
+import config
 from search_utils import (
+    RetrievalPipelineWrapper,
+    get_doc_cols,
     get_repos_with_descriptions,
-    merge_text_list_cols,
-    search_f,
-    setup_retrieval_pipeline,
+    get_retrieval_df,
+    merge_cols,
 )


 class RetrievalApp:
+    def get_device_options(self):
+        if torch.cuda.is_available:
+            return ["cuda", "cpu"]
+        else:
+            return ["cpu"]
+
+    @st.cache(allow_output_mutation=True)
+    def get_retrieval_df(self):
+        return get_retrieval_df(self.data_path, config.text_list_cols)
+
     def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
+        self.data_path = data_path
+        self.device = st.sidebar.selectbox("device", self.get_device_options())
         print("loading data")

-        raw_retrieval_df = (
-            datasets.load_dataset(data_path)["train"]
-            .to_pandas()
-            .drop_duplicates(subset=["repo"])
-            .reset_index(drop=True)
-        )
-        self.retrieval_df = merge_text_list_cols(
-            raw_retrieval_df, config.text_list_cols
-        )
+        self.retrieval_df = self.get_retrieval_df().copy()

         model_name = st.sidebar.selectbox("model", config.model_names)
         self.query_encoder_name = "lambdaofgod/query-" + model_name
         self.document_encoder_name = "lambdaofgod/document-" + model_name

+        doc_cols = get_doc_cols(model_name)
+
         st.sidebar.text("using models")
         st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
-        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)
+        st.sidebar.text("HTTP://huggingface.co/" + self.document_encoder_name)
+
+        self.additional_shown_cols = st.sidebar.multiselect(
+            label="used text features", options=config.text_cols, default=doc_cols
+        )

     @staticmethod
     def show_retrieval_results(
-        retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
+        retrieval_pipe: RetrievalPipelineWrapper,
         query: str,
         k: int,
         all_queries: List[str],
         description_length: int,
         repos_by_query: Dict[str, pd.DataFrame],
-        doc_col: List[str],
+        additional_shown_cols: List[str],
     ):
         print("started retrieval")
         if query in all_queries:

@@ -56,15 +67,14 @@ class RetrievalApp:
             st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
         with st.spinner(text="fetching results"):
             st.write(
-                search_f(retrieval_pipe, query, k, description_length, doc_col).to_html(
+                retrieval_pipe.search(query, k, description_length, additional_shown_cols).to_html(
                     escape=False, index=False
                 ),
                 unsafe_allow_html=True,
             )
         print("finished retrieval")

-    @staticmethod
-    def app(retrieval_pipeline, retrieval_df, doc_col):
+    def run_app(self, retrieval_pipeline):

         retrieved_results = st.sidebar.number_input("number of results", value=10)
         description_length = st.sidebar.number_input(

@@ -72,17 +82,12 @@ class RetrievalApp:
         )

         tasks_deduped = (
-            retrieval_df["tasks"].explode().value_counts().reset_index()
+            self.retrieval_df["tasks"].explode().value_counts().reset_index()
         )  # drop_duplicates().sort_values().reset_index(drop=True)
         tasks_deduped.columns = ["task", "documents per task"]
         with st.sidebar.expander("View test set queries"):
             st.table(tasks_deduped.explode("task"))
-
-        additional_shown_cols = st.sidebar.multiselect(
-            label="additional cols", options=config.text_cols, default=doc_col
-        )
-
-        repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
+        repos_by_query = self.retrieval_df.explode("tasks").groupby("tasks")
         query = st.text_input("input query", value="metric learning")
         RetrievalApp.show_retrieval_results(
             retrieval_pipeline,

@@ -91,16 +96,23 @@ class RetrievalApp:
             tasks_deduped["task"].to_list(),
             description_length,
             repos_by_query,
-            additional_shown_cols,
+            self.additional_shown_cols,
         )

-    def main(self):
-        print("setting up retrieval_pipe")
-
-        retrieval_pipeline = setup_retrieval_pipeline(
+    @st.cache(allow_output_mutation=True)
+    def get_retrieval_pipeline(self, displayed_retrieval_df):
+        return RetrievalPipelineWrapper.setup_from_encoder_names(
             self.query_encoder_name,
             self.document_encoder_name,
-
-
+            displayed_retrieval_df["document"],
+            displayed_retrieval_df,
+            device=self.device,
+        )
+
+    def main(self):
+        print("setting up retrieval_pipe")
+        displayed_retrieval_df = merge_cols(
+            self.retrieval_df.copy(), self.additional_shown_cols
         )
-        RetrievalApp.app(retrieval_pipeline, self.retrieval_df, doc_col)
+        retrieval_pipeline = self.get_retrieval_pipeline(displayed_retrieval_df)
+        self.run_app(retrieval_pipeline)
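Two details of the new code are worth flagging. The document-encoder link is printed with an "HTTP://" scheme where "https://" was presumably intended. More importantly, get_device_options references torch.cuda.is_available without calling it; a function object is always truthy, so the selector offers "cuda" even on CPU-only machines, where picking it would crash encoding. A minimal corrected sketch (a hypothetical fix, not part of this commit):

import torch

def get_device_options():
    # is_available must be called; the bare attribute is always truthy
    if torch.cuda.is_available():
        return ["cuda", "cpu"]
    return ["cpu"]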
config.py CHANGED

@@ -1,11 +1,11 @@
 model_names = [
-
-
-
-
-
+    "dependencies-nbow-nbow-mnrl",
+    "readme-nbow-nbow-mnrl",
+    "titles-nbow-nbow-mnrl",
+    "titles_dependencies-nbow-nbow-mnrl",
+    "readme_dependencies-nbow-nbow-mnrl",
 ]
-best_tasks_path="assets/best_tasks.csv"
-worst_tasks_path="assets/worst_tasks.csv"
+best_tasks_path = "assets/best_tasks.csv"
+worst_tasks_path = "assets/worst_tasks.csv"
 text_cols = ["dependencies", "readme", "titles"]
 text_list_cols = ["titles"]
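The new model names double as configuration: the prefix before the first "-" lists the text columns the model was trained on, joined with "_". get_doc_cols in search_utils.py (below) parses this prefix to pick the default shown columns, for example:

get_doc_cols("titles_dependencies-nbow-nbow-mnrl")
# -> ["titles", "dependencies"]
get_doc_cols("readme-nbow-nbow-mnrl")
# -> ["readme"]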
search_utils.py CHANGED

@@ -1,7 +1,8 @@
 import os
 from typing import Dict, List
+from dataclasses import dataclass

-
+import datasets
 import ast
 import pandas as pd
 import sentence_transformers

@@ -11,6 +12,33 @@ from toolz import partial
 import config


+def get_doc_cols(model_name):
+    model_name = model_name.replace("query-", "")
+    model_name = model_name.replace("document-", "")
+    return model_name.split("-")[0].split("_")
+
+
+def merge_cols(df, cols):
+    df["document"] = df[cols[0]]
+    for col in cols:
+        df["document"] = df["document"] + " " + df[col]
+    return df
+
+
+def get_retrieval_df(
+    data_path="lambdaofgod/pwc_repositories_with_dependencies", text_list_cols=None
+):
+    raw_retrieval_df = (
+        datasets.load_dataset(data_path)["train"]
+        .to_pandas()
+        .drop_duplicates(subset=["repo"])
+        .reset_index(drop=True)
+    )
+    if text_list_cols:
+        return merge_text_list_cols(raw_retrieval_df, text_list_cols)
+    return raw_retrieval_df
+
+
 def truncate_description(description, length=50):
     return " ".join(description.split()[:length])

@@ -19,25 +47,6 @@ def get_repos_with_descriptions(repos_df, repos):
     return repos_df.loc[repos]


-def search_f(
-    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
-    query: str,
-    k: int,
-    description_length: int,
-    doc_col: List[str],
-):
-    results = retrieval_pipe.find_similar(query, k)
-    # results['repo'] = results.index
-    results["link"] = "https://github.com/" + results["repo"]
-    for col in doc_col:
-        results[col] = results[col].apply(
-            lambda desc: truncate_description(desc, description_length)
-        )
-    shown_cols = ["repo", "tasks", "link", "distance"]
-    shown_cols = shown_cols + doc_col
-    return results.reset_index(drop=True)[shown_cols]
-
-
 def merge_text_list_cols(retrieval_df, text_list_cols):
     retrieval_df = retrieval_df.copy()
     for col in text_list_cols:

@@ -47,29 +56,50 @@ def merge_text_list_cols(retrieval_df, text_list_cols):
     return retrieval_df


-
-
-    documents_df: pd.DataFrame,
-    text_col: str,
-):
-    retrieval_pipeline.RetrievalPipelineFactory.build(
-        documents_df[text_col], metadata=documents_df
-    )
+@dataclass
+class RetrievalPipelineWrapper:

+    pipeline: retrieval_pipeline.RetrievalPipeline

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @classmethod
+    def build_from_encoders(cls, query_encoder, document_encoder, documents, metadata):
+        retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
+            feature_extractor=document_encoder,
+            query_feature_extractor=query_encoder,
+            index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
+        )
+        pipeline = retrieval_pipe.build(documents, metadata=metadata)
+        return RetrievalPipelineWrapper(pipeline)
+
+    def search(
+        self,
+        query: str,
+        k: int,
+        description_length: int,
+        additional_shown_cols: List[str],
+    ):
+        results = self.pipeline.find_similar(query, k)
+        # results['repo'] = results.index
+        results["link"] = "https://github.com/" + results["repo"]
+        for col in additional_shown_cols:
+            results[col] = results[col].apply(
+                lambda desc: truncate_description(desc, description_length)
+            )
+        shown_cols = ["repo", "tasks", "link", "distance"]
+        shown_cols = shown_cols + additional_shown_cols
+        return results.reset_index(drop=True)[shown_cols]
+
+    @classmethod
+    def setup_from_encoder_names(
+        cls, query_encoder_path, document_encoder_path, documents, metadata, device
+    ):
+        document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
+            sentence_transformers.SentenceTransformer(
+                document_encoder_path, device=device
+            )
+        )
+        query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
+            sentence_transformers.SentenceTransformer(query_encoder_path, device=device)
+        )
+        return cls.build_from_encoders(
+            query_encoder, document_encoder, documents, metadata
+        )
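Grouping the pipeline behind a @dataclass separates construction (setup_from_encoder_names, which instantiates the two SentenceTransformer encoders on the chosen device) from querying (search), which is what lets app_implementation.py cache the built pipeline. One wrinkle remains: merge_cols seeds "document" with cols[0] and then concatenates every column in cols, so the first selected column appears twice in each document. A sketch of the presumably intended behavior (a hypothetical fix, not in this commit):

import pandas as pd

def merge_cols(df, cols):
    # concatenate each selected text column once, space-separated
    df["document"] = df[cols[0]]
    for col in cols[1:]:
        df["document"] = df["document"] + " " + df[col]
    return df

merge_cols(pd.DataFrame({"titles": ["a"], "readme": ["b"]}), ["titles", "readme"])
# document column == "a b"; the committed version yields "a a b"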