# paperswithcode_nbow / search_utils.py
# Source: Hugging Face Space by lambdaofgod
# (commit fdce35d, "removed unused streamlit from utils")
import os
from typing import Dict, List
from dataclasses import dataclass
import datasets
import ast
import pandas as pd
import sentence_transformers
from findkit import feature_extractors, indexes, retrieval_pipeline
from toolz import partial
import config
def get_doc_cols(model_name):
    """Extract the document column names encoded in a model name.

    Removes the "query-"/"document-" role prefixes, keeps the segment
    before the first dash, and splits it on underscores into column names.
    """
    stripped = model_name.replace("query-", "").replace("document-", "")
    column_segment = stripped.split("-")[0]
    return column_segment.split("_")
def merge_cols(df, cols):
    """Concatenate the text columns `cols` (in order, space-separated) into a
    new "document" column.

    Mutates `df` in place and also returns it.

    Fix: the original seeded "document" with `cols[0]` and then looped over
    ALL of `cols`, so the first column's text appeared twice; the loop now
    starts at `cols[1:]`.
    """
    df["document"] = df[cols[0]]
    for col in cols[1:]:
        df["document"] = df["document"] + " " + df[col]
    return df
def get_retrieval_df(
    data_path="lambdaofgod/pwc_repositories_with_dependencies", text_list_cols=None
):
    """Load the repositories dataset as a deduplicated pandas DataFrame.

    Downloads the "train" split of `data_path`, drops rows with duplicate
    "repo" values, and resets the index. When `text_list_cols` is provided,
    those stringified-list columns are flattened to plain text via
    `merge_text_list_cols`.
    """
    train_split = datasets.load_dataset(data_path)["train"]
    raw_retrieval_df = train_split.to_pandas()
    raw_retrieval_df = raw_retrieval_df.drop_duplicates(subset=["repo"])
    raw_retrieval_df = raw_retrieval_df.reset_index(drop=True)
    if not text_list_cols:
        return raw_retrieval_df
    return merge_text_list_cols(raw_retrieval_df, text_list_cols)
def truncate_description(description, length=50):
    """Return `description` truncated to its first `length` whitespace-separated words."""
    words = description.split()
    return " ".join(words[:length])
def get_repos_with_descriptions(repos_df, repos):
    """Select from `repos_df` the rows whose index labels appear in `repos`.

    Raises KeyError (via pandas .loc) when a label is missing from the index.
    """
    selected = repos_df.loc[repos]
    return selected
def merge_text_list_cols(retrieval_df, text_list_cols):
    """Return a copy of `retrieval_df` where each column in `text_list_cols`
    — stored as stringified Python lists, e.g. "['a', 'b']" — is replaced by
    a single space-joined string.

    The input frame is not modified.
    """
    merged_df = retrieval_df.copy()
    for text_col in text_list_cols:
        joined = merged_df[text_col].apply(
            lambda serialized: " ".join(ast.literal_eval(serialized))
        )
        merged_df[text_col] = joined
    return merged_df
@dataclass
class RetrievalPipelineWrapper:
    """Convenience wrapper around a findkit RetrievalPipeline.

    Bundles construction (from encoder objects or from sentence-transformers
    model names) with a `search` method that post-processes results into a
    display-ready DataFrame.
    """

    # The underlying findkit pipeline (query encoder + document index).
    pipeline: retrieval_pipeline.RetrievalPipeline

    @classmethod
    def build_from_encoders(cls, query_encoder, document_encoder, documents, metadata):
        """Build a wrapper from already-constructed feature extractors.

        Indexes `documents` with `document_encoder` using a cosine-similarity
        NMSLIB index; `metadata` rows are attached to the indexed documents.
        """
        retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
            feature_extractor=document_encoder,
            query_feature_extractor=query_encoder,
            index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
        )
        pipeline = retrieval_pipe.build(documents, metadata=metadata)
        # Use cls(...) rather than the hard-coded class name so that
        # subclasses of this wrapper construct instances of themselves.
        return cls(pipeline)

    def search(
        self,
        query: str,
        k: int,
        description_length: int,
        additional_shown_cols: List[str],
    ):
        """Return the top-`k` matches for `query` as a DataFrame.

        Adds a GitHub "link" column, truncates each column in
        `additional_shown_cols` to `description_length` words, and returns
        only the display columns (repo, tasks, link, distance + extras).
        NOTE(review): assumes the pipeline metadata contains "repo" and
        "tasks" columns — verify against the indexing call site.
        """
        results = self.pipeline.find_similar(query, k)
        results["link"] = "https://github.com/" + results["repo"]
        for col in additional_shown_cols:
            results[col] = results[col].apply(
                lambda desc: truncate_description(desc, description_length)
            )
        shown_cols = ["repo", "tasks", "link", "distance"] + additional_shown_cols
        return results.reset_index(drop=True)[shown_cols]

    @classmethod
    def setup_from_encoder_names(
        cls, query_encoder_path, document_encoder_path, documents, metadata, device
    ):
        """Build a wrapper by loading sentence-transformers models by name/path.

        `device` is forwarded to SentenceTransformer (e.g. "cpu", "cuda").
        """
        document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(
                document_encoder_path, device=device
            )
        )
        query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(query_encoder_path, device=device)
        )
        return cls.build_from_encoders(
            query_encoder, document_encoder, documents, metadata
        )