File size: 3,731 Bytes
568499b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
from typing import Dict, List

import pandas as pd
import datasets
import streamlit as st
import config
from findkit import retrieval_pipeline
from search_utils import (
    get_repos_with_descriptions,
    search_f,
    merge_text_list_cols,
    setup_retrieval_pipeline,
)


class RetrievalApp:
    """Streamlit front-end for searching repositories with a query/document encoder pair.

    The constructor loads the repository dataset and lets the user pick a model
    in the sidebar; ``main`` builds the retrieval pipeline and renders the UI.

    NOTE(review): if this class is instantiated at script top level, every
    Streamlit rerun will reload the dataset and rebuild the pipeline — consider
    caching these steps; confirm against the entry-point script.
    """

    def __init__(self, data_path="lambdaofgod/pwc_repositories_with_dependencies"):
        """Load the repository table and choose the encoder models.

        Args:
            data_path: Dataset identifier passed to ``datasets.load_dataset``;
                its ``"train"`` split is used.
        """
        print("loading data")

        # Keep a single row per repository.
        raw_retrieval_df = (
            datasets.load_dataset(data_path)["train"]
            .to_pandas()
            .drop_duplicates(subset=["repo"])
            .reset_index(drop=True)
        )
        # Presumably joins the list-valued text columns named in
        # config.text_list_cols into text — see search_utils to confirm.
        self.retrieval_df = merge_text_list_cols(
            raw_retrieval_df, config.text_list_cols
        )

        model_name = st.sidebar.selectbox("model", config.model_names)
        # Query and document encoders are two model ids that differ only in
        # their "query-" / "document-" prefix.
        self.query_encoder_name = "lambdaofgod/query-" + model_name
        self.document_encoder_name = "lambdaofgod/document-" + model_name

        st.sidebar.text("using models")
        st.sidebar.text("https://huggingface.co/" + self.query_encoder_name)
        st.sidebar.text("https://huggingface.co/" + self.document_encoder_name)

    @staticmethod
    def show_retrieval_results(
        retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
        query: str,
        k: int,
        all_queries: List[str],
        description_length: int,
        repos_by_query: "pd.core.groupby.DataFrameGroupBy",
        doc_col: "str | List[str]",
    ):
        """Render gold-standard results (when available) and retrieved results.

        Args:
            retrieval_pipe: Pipeline used by ``search_f``; its ``X_df`` is also
                used to look up descriptions of gold-standard repositories.
            query: Free-text query entered by the user.
            k: Number of results to retrieve (forwarded to ``search_f``).
            all_queries: Gold-standard queries. When ``query`` is one of them,
                the matching group of ``repos_by_query`` is shown in an expander.
            description_length: Forwarded to ``search_f``; per the sidebar label
                in ``app``, the number of description words used.
            repos_by_query: Rows grouped by task, e.g.
                ``retrieval_df.explode("tasks").groupby("tasks")``. Must support
                ``get_group(query)`` for every entry of ``all_queries``.
            doc_col: Column(s) forwarded to ``search_f``; ``app`` passes the list
                chosen in the "additional cols" multiselect.

        Raises:
            KeyError: from ``get_group`` if ``query`` is in ``all_queries`` but
                is not a group key of ``repos_by_query``.
        """
        print("started retrieval")
        if query in all_queries:
            with st.expander(
                "query is in gold standard set queries. Toggle viewing gold standard results?"
            ):
                st.write("gold standard results")
                # app() derives all_queries and repos_by_query from the same
                # "tasks" column, so the lookup is consistent there.
                task_repos = repos_by_query.get_group(query)
                st.table(get_repos_with_descriptions(retrieval_pipe.X_df, task_repos))
        with st.spinner(text="fetching results"):
            results_df = search_f(retrieval_pipe, query, k, description_length, doc_col)
            # escape=False + unsafe_allow_html renders any HTML in the result
            # cells (presumably links built by search_f). NOTE(review): cell text
            # comes from the dataset; if that is untrusted, escape it in search_f.
            st.write(
                results_df.to_html(escape=False, index=False),
                unsafe_allow_html=True,
            )
        print("finished retrieval")

    @staticmethod
    def app(retrieval_pipeline, retrieval_df, doc_col):
        """Render sidebar controls and the query box, then show results.

        Args:
            retrieval_pipeline: Pipeline forwarded to ``show_retrieval_results``.
                Inside this function the name shadows the imported
                ``retrieval_pipeline`` module; it is kept for caller compatibility.
            retrieval_df: Repository table whose ``"tasks"`` column holds
                list-like values; the distinct tasks are the gold-standard queries.
            doc_col: Default selection of the "additional cols" multiselect.
        """
        # A non-positive result count is meaningless, so the widget enforces >= 1.
        retrieved_results = st.sidebar.number_input(
            "number of results", min_value=1, value=10
        )
        description_length = st.sidebar.number_input(
            "number of used description words", value=10
        )

        # One row per distinct task with the number of rows labelled with it.
        # rename_axis/reset_index(name=...) fixes the column names regardless of
        # how the pandas version names value_counts' index and values.
        tasks_deduped = (
            retrieval_df["tasks"]
            .explode()
            .value_counts()
            .rename_axis("task")
            .reset_index(name="documents per task")
        )
        with st.sidebar.expander("View test set queries"):
            st.table(tasks_deduped)

        additional_shown_cols = st.sidebar.multiselect(
            label="additional cols", options=config.text_cols, default=doc_col
        )

        repos_by_query = retrieval_df.explode("tasks").groupby("tasks")
        query = st.text_input("input query", value="metric learning")
        RetrievalApp.show_retrieval_results(
            retrieval_pipeline,
            query,
            retrieved_results,
            tasks_deduped["task"].to_list(),
            description_length,
            repos_by_query,
            additional_shown_cols,
        )

    def main(self):
        """Build the retrieval pipeline over the ``"dependencies"`` column and run the app."""
        print("setting up retrieval_pipe")
        doc_col = "dependencies"
        # Named retrieval_pipe so the imported findkit `retrieval_pipeline`
        # module is not shadowed in this scope.
        retrieval_pipe = setup_retrieval_pipeline(
            self.query_encoder_name,
            self.document_encoder_name,
            self.retrieval_df[doc_col],
            self.retrieval_df,
        )
        RetrievalApp.app(retrieval_pipe, self.retrieval_df, doc_col)