File size: 3,871 Bytes
0e05863
 
 
d362bcf
0e05863
 
 
48bcb8a
 
 
 
d362bcf
0e05863
48bcb8a
0e05863
 
 
48bcb8a
0e05863
 
 
 
 
 
48bcb8a
0e05863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48bcb8a
0e05863
 
 
48bcb8a
0e05863
48bcb8a
0e05863
 
 
 
 
 
 
 
 
 
 
 
 
 
48bcb8a
0e05863
 
 
48bcb8a
0e05863
 
48bcb8a
0e05863
 
 
 
7ba815c
 
0e05863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
from huggingface_hub import HfApi, ModelFilter, DatasetFilter, ModelSearchArguments
from pprint import pprint
from hf_search import HFSearch
import streamlit as st
import itertools

from pbr.version import VersionInfo
_hf_search_version = VersionInfo('hf_search').version_string()
print("hf_search version:", _hf_search_version)

# Single search engine instance used by semantic_search / bm25_search.
hf_search = HFSearch(top_k=1000)

@st.cache
def hf_api(query, limit=5, sort=None, filters=None):
    """Search models through the Hugging Face Hub API.

    Parameters
    ----------
    query : str
        Free-text search string forwarded to ``HfApi.list_models(search=...)``.
    limit : int or None
        Maximum number of hits to fetch and return.
    sort : str or None
        Sort key forwarded to ``HfApi.list_models``.
    filters : dict or None
        Optional ``"task"`` and ``"library"`` entries used to build a
        ``ModelFilter``. Missing keys are passed as ``None``, which is
        ``ModelFilter``'s own default. ``None`` means no filters.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}``. Each hit holds ``modelId``,
        ``tags``, ``downloads`` and ``likes``. ``count`` is the number of
        hits fetched before truncation to ``limit``.
    """
    # Use a fresh dict per call. The old ``filters={}`` default was shared
    # between calls, and indexing it with ["task"] raised KeyError.
    if filters is None:
        filters = {}

    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    api = HfApi()
    filt = ModelFilter(
        task=filters.get("task"),
        library=filters.get("library"),
    )
    models = api.list_models(search=query, filter=filt, limit=limit, sort=sort, full=True)

    hits = [
        {
            "modelId": info.get("modelId"),
            "tags": info.get("tags"),
            "downloads": info.get("downloads"),
            "likes": info.get("likes"),
        }
        for info in (vars(model) for model in models)
    ]
    count = len(hits)
    # Defensive cap. ``limit=None`` means "no cap" rather than a TypeError.
    if limit is not None:
        hits = hits[:limit]
    return {"hits": hits, "count": count}


@st.cache
def semantic_search(query, limit=5, sort=None, filters=None):
    """Run the "retrieve & rerank" search of the module-level ``hf_search``.

    Parameters
    ----------
    query : str
        Free-text query.
    limit : int
        Maximum number of hits requested from ``hf_search.search``.
    sort : str or None
        Sort key forwarded to ``hf_search.search``.
    filters : dict or None
        Filters forwarded to ``hf_search.search``. ``None`` means no filters.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}``. Each hit holds ``modelId``,
        ``tags``, ``downloads``, ``likes`` and ``readme``; ``readme`` is
        ``None`` when absent.
    """
    # Fresh dict per call. A shared ``{}`` default would leak any mutation
    # made downstream into later calls.
    if filters is None:
        filters = {}

    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    raw_hits = hf_search.search(query=query, method="retrieve & rerank", limit=limit, sort=sort, filters=filters)
    hits = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme"),
        }
        for hit in raw_hits
    ]
    return {"hits": hits, "count": len(hits)}


@st.cache
def bm25_search(query, limit=5, sort=None, filters=None):
    """Run the BM25 search of the module-level ``hf_search`` and drop duplicates.

    Parameters
    ----------
    query : str
        Free-text query.
    limit : int
        Maximum number of hits requested from ``hf_search.search``.
    sort : str or None
        Sort key forwarded to ``hf_search.search``.
    filters : dict or None
        Filters forwarded to ``hf_search.search``. ``None`` means no filters.
        NOTE(review): whether the bm25 path actually applies them depends on
        HFSearch; confirm there.

    Returns
    -------
    dict
        ``{"hits": [...], "count": int}``. Only the first hit for each
        ``modelId`` is kept, in original order, and ``count`` is the number
        of unique hits.
    """
    # Fresh dict per call instead of a shared mutable default.
    if filters is None:
        filters = {}

    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)

    raw_hits = hf_search.search(query=query, method="bm25", limit=limit, sort=sort, filters=filters)

    # Keep the first occurrence of each modelId. A seen-set makes this O(n)
    # instead of rescanning all earlier hits for every element.
    seen_ids = set()
    hits = []
    for hit in raw_hits:
        model_id = hit["modelId"]
        if model_id in seen_ids:
            continue
        seen_ids.add(model_id)
        hits.append(
            {
                "modelId": model_id,
                "tags": hit["tags"],
                "downloads": hit["downloads"],
                "likes": hit["likes"],
                "readme": hit.get("readme"),
            }
        )
    return {"hits": hits, "count": len(hits)}


def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    """Let the user paginate a set of articles.

    Adapted from https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7

    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterable[Any]
        The articles to paginate. They are consumed into a list.
    articles_per_page : int
        The number of articles to display per page.
    on_sidebar : bool
        Whether to display the paginator widget on the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over only the articles on the selected page, each paired
        with its index in ``articles``. It is empty when there are no articles.
    """
    # Reserve a placeholder in the requested area for the page selector.
    location = st.sidebar.empty() if on_sidebar else st.empty()

    articles = list(articles)
    n_articles = len(articles)
    if n_articles == 0:
        # With zero pages the selectbox has no options and returns None,
        # which would break the page arithmetic below.
        return iter(())

    n_pages = (n_articles - 1) // articles_per_page + 1

    def page_format_func(i):
        # Show the 0-based index range each page covers. This uses the actual
        # page size (not a hard-coded 10) and is clamped to the last article.
        first = i * articles_per_page
        last = min(first + articles_per_page, n_articles) - 1
        return f"Results {first} to {last}"

    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)

    # Yield only the (index, article) pairs that fall on the selected page.
    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page
    return itertools.islice(enumerate(articles), min_index, max_index)