import streamlit as st
from pages import search_engine_page, document_page
from st_utils import bm25_search, semantic_search, hf_api, paginator
from huggingface_hub import ModelSearchArguments
import webbrowser
from numerize.numerize import numerize
# Configure the browser tab title, icon and overall page layout.
st.set_page_config(
    page_title="HuggingFace Search Engine",
    # The icon literal was mis-encoded in this file (a UTF-8 emoji decoded with
    # a single-byte codec). A magnifying glass is assumed from the search
    # theme -- TODO confirm the intended emoji.
    page_icon="🔎",
    layout="wide",
    initial_sidebar_state="auto",
    # Placeholder menu entries, kept disabled until real URLs exist:
    # menu_items={
    #     "Get Help": "https://www.extremelycoolapp.com/help",
    #     "Report a bug": "https://www.extremelycoolapp.com/bug",
    #     "About": "# This is a header. This is an *extremely* cool app!",
    # },
)
### SIDEBAR
# Display labels for the backend identifiers offered in the selectbox.
_BACKEND_LABELS = {
    "hfapi": "Keyword search",
    "bm25": "BM25 search",
    "semantic": "Semantic Search",
}
search_backend = st.sidebar.selectbox(
    "Search method",
    ["semantic", "bm25", "hfapi"],
    format_func=_BACKEND_LABELS.__getitem__,
)
limit_results = st.sidebar.number_input("Limit results", min_value=0, value=10)

st.sidebar.markdown("# Filters")
args = ModelSearchArguments()

# The widgets store the mapping *values* but display the mapping *keys*.
# Build each value -> key lookup once; format_func is called per option, so
# inverting the dict inside it repeated the same work for every entry.
_library_names = {value: name for name, value in args.library.items()}
_task_names = {value: name for name, value in args.pipeline_tag.items()}

library = st.sidebar.multiselect(
    "Library",
    args.library.values(),
    format_func=_library_names.__getitem__,
)
task = st.sidebar.multiselect(
    "Task",
    args.pipeline_tag.values(),
    format_func=_task_names.__getitem__,
)
### MAIN PAGE
# Page heading, rendered as raw HTML (hence unsafe_allow_html).
# NOTE(review): the original markup around the title was lost and the string
# literal was left unterminated; a centered <h1> and the magnifying-glass emoji
# are assumed here -- TODO confirm against the intended design.
st.markdown(
    "<h1 style='text-align: center;'>🔎🤗 HF Search Engine</h1>",
    unsafe_allow_html=True,
)
# Search bar
search_query = st.text_input("Search for a model in HuggingFace")

# Nothing is searched or rendered until the user has typed a query.
if search_query != "":
    filters = {
        "library": library,
        "task": task,
    }

    # One search callable per sidebar choice; all are called with
    # (query, limit, filters). A lookup table guarantees `res` is always bound
    # (an unknown backend fails loudly with KeyError at the lookup).
    search_backends = {
        "hfapi": hf_api,
        "semantic": semantic_search,
        "bm25": bm25_search,
    }
    res = search_backends[search_backend](search_query, limit_results, filters)
    hit_list, hits_count = res["hits"], res["count"]

    # Keep only the fields the result cards render; "readme" may be absent
    # from a hit, in which case it is stored as None.
    hit_list = [
        {
            "modelId": hit["modelId"],
            "tags": hit["tags"],
            "downloads": hit["downloads"],
            "likes": hit["likes"],
            "readme": hit.get("readme", None),
        }
        for hit in hit_list
    ]

    if hit_list:
        st.write(f"Search results ({hits_count}):")
        # The label never advertises more than 100 shown results.
        shown_results = min(hits_count, 100)

        for _, hit in paginator(
            f"Select results (showing {shown_results} of {hits_count} results)",
            hit_list,
        ):
            col1, col2, col3 = st.columns([5, 1, 1])
            col1.metric("Model", hit["modelId"])
            col2.metric("N° downloads", numerize(hit["downloads"]))
            col3.metric("N° likes", numerize(hit["likes"]))

            # Render a plain link instead of a button calling webbrowser.open:
            # that call executes in the server process, so it would try to open
            # a browser on the host machine rather than in the visitor's
            # browser.
            st.markdown(f"[View model on 🤗](https://huggingface.co/{hit['modelId']})")

            st.write(f"**Tags:** {' • '.join(hit['tags'])}")

            if hit["readme"]:
                with st.expander("See README"):
                    st.write(hit["readme"])

            # TODO: embed huggingface spaces (e.g. with
            # streamlit.components.v1.html(..., height=400)).

            # Horizontal rule between consecutive result cards.
            st.markdown("---")
    else:
        st.write("No Search results, please try again with different keywords")
# Trailing raw-HTML slot. The markup body currently contains only a newline.
_trailing_html = "\n"
st.markdown(_trailing_html, unsafe_allow_html=True)