"""Search helpers for a Streamlit front-end over Hugging Face Hub models."""
import itertools
import json
from pprint import pprint

import streamlit as st
from huggingface_hub import DatasetFilter, HfApi, ModelFilter, ModelSearchArguments
from pbr.version import VersionInfo

from hf_search import HFSearch

print("hf_search version:", VersionInfo('hf_search').version_string())

# Module-level search engine shared by the semantic and BM25 back-ends.
hf_search = HFSearch(top_k=200)


def _log_search_args(query, filters, limit, sort):
    """Print the search arguments to stdout (lightweight debug tracing)."""
    print("query", query)
    print("filters", filters)
    print("limit", limit)
    print("sort", sort)


def _project_hit(hit):
    """Keep only the fields the UI uses from an ``HFSearch`` result dict."""
    return {
        "modelId": hit["modelId"],
        "tags": hit["tags"],
        "downloads": hit["downloads"],
        "likes": hit["likes"],
        "readme": hit.get("readme", None),
    }


@st.cache
def hf_api(query, limit=5, sort=None, filters=None):
    """Search models through the Hugging Face Hub API.

    Parameters
    ----------
    query : str
        Free-text search string forwarded to ``HfApi.list_models``.
    limit : int or None
        Maximum number of hits to return; ``None`` means no truncation.
    sort : str or None
        Sort key forwarded to ``HfApi.list_models``.
    filters : dict or None
        Optional filters. When present, the ``"task"`` and ``"library"``
        keys are passed to ``ModelFilter``. A missing key is passed as
        ``None``.

    Returns
    -------
    dict
        ``{"hits": [...], "count": n}``. Each hit has ``modelId``, ``tags``,
        ``downloads`` and ``likes``. ``count`` is the number of models
        fetched before truncation to ``limit``.
    """
    # Use a fresh dict per call instead of a shared mutable default. Read keys
    # with .get so that calling without filters no longer raises KeyError.
    filters = {} if filters is None else filters
    _log_search_args(query, filters, limit, sort)

    api = HfApi()
    filt = ModelFilter(
        task=filters.get("task"),
        library=filters.get("library"),
    )
    models = api.list_models(
        search=query, filter=filt, limit=limit, sort=sort, full=True
    )

    hits = []
    for model in models:
        attrs = model.__dict__
        hits.append(
            {
                "modelId": attrs.get("modelId"),
                "tags": attrs.get("tags"),
                "downloads": attrs.get("downloads"),
                "likes": attrs.get("likes"),
            }
        )
    count = len(hits)
    # Comparing len(hits) > None would raise TypeError, so only truncate
    # when a limit was actually given.
    if limit is not None:
        hits = hits[:limit]
    return {"hits": hits, "count": count}


@st.cache
def semantic_search(query, limit=5, sort=None, filters=None):
    """Search models with ``HFSearch`` using the "retrieve & rerank" method.

    Parameters
    ----------
    query : str
        Free-text search string.
    limit : int
        Forwarded to ``HFSearch.search``.
    sort : str or None
        Forwarded to ``HFSearch.search``.
    filters : dict or None
        Forwarded to ``HFSearch.search``. ``None`` is replaced with an empty
        dict.

    Returns
    -------
    dict
        ``{"hits": [...], "count": len(hits)}``. Each hit has ``modelId``,
        ``tags``, ``downloads``, ``likes`` and ``readme`` (``None`` when
        absent).
    """
    filters = {} if filters is None else filters
    _log_search_args(query, filters, limit, sort)
    results = hf_search.search(
        query=query, method="retrieve & rerank", limit=limit, sort=sort, filters=filters
    )
    hits = [_project_hit(hit) for hit in results]
    return {"hits": hits, "count": len(hits)}


@st.cache
def bm25_search(query, limit=5, sort=None, filters=None):
    """Search models with ``HFSearch`` using the "bm25" method.

    Takes the same parameters and returns the same shape as
    ``semantic_search``. Hits are also deduplicated by ``modelId``, and the
    first occurrence of each model is kept.
    """
    filters = {} if filters is None else filters
    _log_search_args(query, filters, limit, sort)
    # TODO(review): confirm HFSearch actually applies ``filters`` for bm25.
    results = hf_search.search(
        query=query, method="bm25", limit=limit, sort=sort, filters=filters
    )
    hits = [_project_hit(hit) for hit in results]

    # Unique hits in original rank order. A seen-set makes this O(n); the
    # previous version rebuilt the prefix list for every element (O(n^2)).
    seen = set()
    unique_hits = []
    for hit in hits:
        if hit["modelId"] in seen:
            continue
        seen.add(hit["modelId"])
        unique_hits.append(hit)
    return {"hits": unique_hits, "count": len(unique_hits)}


def paginator(label, articles, articles_per_page=10, on_sidebar=True):
    # https://gist.github.com/treuille/2ce0acb6697f205e44e3e0f576e810b7
    """Let the user paginate a set of articles.

    Parameters
    ----------
    label : str
        The label to display over the pagination widget.
    articles : Iterator[Any]
        The articles to display in the paginator.
    articles_per_page: int
        The number of articles to display per page.
    on_sidebar: bool
        Whether to display the paginator widget on the sidebar.

    Returns
    -------
    Iterator[Tuple[int, Any]]
        An iterator over *only the articles on that page*, including each
        item's index. The iterator is empty when there are no articles.
    """
    # Figure out where to display the paginator.
    location = st.sidebar.empty() if on_sidebar else st.empty()

    articles = list(articles)
    if not articles:
        # With no articles there are zero pages. A selectbox with no options
        # returns None, which would break the index arithmetic below.
        return iter(())

    n_pages = (len(articles) - 1) // articles_per_page + 1

    def page_format_func(page):
        # Use the real page size (not a hardcoded 10), and clamp the last
        # index so the final, possibly partial, page is labelled correctly.
        first = page * articles_per_page
        last = min(first + articles_per_page, len(articles)) - 1
        return f"Results {first} to {last}"

    # Display a pagination selectbox in the specified location.
    page_number = location.selectbox(label, range(n_pages), format_func=page_format_func)

    # Iterate over the articles in the page to let the user display them.
    min_index = page_number * articles_per_page
    max_index = min_index + articles_per_page
    return itertools.islice(enumerate(articles), min_index, max_index)