import ast
import os
from dataclasses import dataclass
from typing import Dict, List

import datasets
import pandas as pd
import sentence_transformers
from findkit import feature_extractors, indexes, retrieval_pipeline
from toolz import partial

import config


def get_doc_cols(model_name):
    """Extract the document column names encoded in a model name."""
    model_name = model_name.replace("query-", "")
    model_name = model_name.replace("document-", "")
    return model_name.split("-")[0].split("_")


def merge_cols(df, cols):
    """Concatenate the given columns into a single 'document' column."""
    df["document"] = df[cols[0]]
    for col in cols[1:]:  # skip cols[0], which already seeds the column
        df["document"] = df["document"] + " " + df[col]
    return df


def get_retrieval_df(
    data_path="lambdaofgod/pwc_repositories_with_dependencies", text_list_cols=None
):
    """Load the repositories dataset and optionally flatten stringified list columns."""
    raw_retrieval_df = (
        datasets.load_dataset(data_path)["train"]
        .to_pandas()
        .drop_duplicates(subset=["repo"])
        .reset_index(drop=True)
    )
    if text_list_cols:
        return merge_text_list_cols(raw_retrieval_df, text_list_cols)
    return raw_retrieval_df


def truncate_description(description, length=50):
    """Keep only the first `length` words of a description."""
    return " ".join(description.split()[:length])


def get_repos_with_descriptions(repos_df, repos):
    return repos_df.loc[repos]


def merge_text_list_cols(retrieval_df, text_list_cols):
    """Parse stringified lists (e.g. "['a', 'b']") and join them into plain text."""
    retrieval_df = retrieval_df.copy()
    for col in text_list_cols:
        retrieval_df[col] = retrieval_df[col].apply(
            lambda t: " ".join(ast.literal_eval(t))
        )
    return retrieval_df


@dataclass
class RetrievalPipelineWrapper:
    pipeline: retrieval_pipeline.RetrievalPipeline

    @classmethod
    def build_from_encoders(cls, query_encoder, document_encoder, documents, metadata):
        """Build a retrieval pipeline from already-constructed encoders."""
        retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
            feature_extractor=document_encoder,
            query_feature_extractor=query_encoder,
            index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
        )
        pipeline = retrieval_pipe.build(documents, metadata=metadata)
        return cls(pipeline)

    def search(
        self,
        query: str,
        k: int,
        description_length: int,
        additional_shown_cols: List[str],
    ):
        """Return the top-k results with GitHub links and truncated text columns."""
        results = self.pipeline.find_similar(query, k)
        # results['repo'] = results.index
        results["link"] = "https://github.com/" + results["repo"]
        for col in additional_shown_cols:
            results[col] = results[col].apply(
                lambda desc: truncate_description(desc, description_length)
            )
        shown_cols = ["repo", "tasks", "link", "distance"] + additional_shown_cols
        return results.reset_index(drop=True)[shown_cols]

    @classmethod
    def setup_from_encoder_names(
        cls, query_encoder_path, document_encoder_path, documents, metadata, device
    ):
        """Load sentence-transformers models by name and build the pipeline."""
        document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(
                document_encoder_path, device=device
            )
        )
        query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
            sentence_transformers.SentenceTransformer(query_encoder_path, device=device)
        )
        return cls.build_from_encoders(
            query_encoder, document_encoder, documents, metadata
        )
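

# --- Usage sketch -----------------------------------------------------------
# Minimal, illustrative example of wiring the pieces above together.
# The encoder name, device, and the "readme"/"tasks" column choices are
# assumptions made for illustration only; substitute the models and columns
# from your own configuration.
if __name__ == "__main__":
    # assumes "tasks" is stored as stringified lists in the dataset
    retrieval_df = get_retrieval_df(text_list_cols=["tasks"])
    documents = retrieval_df["readme"]  # hypothetical text column to index
    wrapper = RetrievalPipelineWrapper.setup_from_encoder_names(
        query_encoder_path="sentence-transformers/all-MiniLM-L6-v2",  # assumed encoder
        document_encoder_path="sentence-transformers/all-MiniLM-L6-v2",  # assumed encoder
        documents=documents,
        metadata=retrieval_df,
        device="cpu",
    )
    print(
        wrapper.search(
            "image segmentation",
            k=5,
            description_length=30,
            additional_shown_cols=["readme"],
        )
    )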