davanstrien's picture
davanstrien HF staff
feedback mention
a064f46
raw
history blame
5.62 kB
import os
from functools import lru_cache
from typing import Optional
import gradio as gr
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from huggingface_hub import list_models
load_dotenv()

# Qdrant connection settings come from the environment (optionally populated
# from a local .env file by load_dotenv above).
URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")

# Log which endpoint we talk to, but never echo the API key itself: anything
# printed here lands in the app's logs, which may be visible to others.
print(URL)
print("QDRANT_API_KEY set:", QDRANT_API_KEY is not None)

collection_name = "dataset_cards"
client = QdrantClient(
    url=URL,
    api_key=QDRANT_API_KEY,
)
# def convert_bytes_to_human_readable_size(bytes_size):
#     if bytes_size < 1024**2:
#         return f"{bytes_size / 1024:.2f} KB"
#     elif bytes_size < 1024**3:
#         return f"{bytes_size / (1024 ** 2):.2f} MB"
#     else:
#         return f"{bytes_size / (1024 ** 3):.2f} GB"
def format_time_nicely(time_str):
    """Return the date portion of an ISO-8601-style timestamp string.

    Everything before the first ``"T"`` is kept, so ``"2023-08-01T12:00:00Z"``
    becomes ``"2023-08-01"``; a string containing no ``"T"`` is returned as-is.
    """
    date_part, _separator, _time_part = time_str.partition("T")
    return date_part
def format_results(results, show_associated_models=True):
    """Render Qdrant hits as one Markdown/HTML page for a ``gr.Markdown``.

    Args:
        results: Iterable of Qdrant points. Each ``payload`` dict must hold
            ``id``, ``downloads`` and ``section_text``; ``lastModified`` is
            optional and only shown when present and non-empty.
        show_associated_models: If true, append a collapsible list of Hub
            models linked to each dataset (looked up via the cached
            ``get_models_for_dataset``).

    Returns:
        The rendered page as a single string.
    """
    parts = [
        "<h1 style='text-align: center;'> &#x2728; Dataset Search Results &#x2728;"
        " </h1> \n\n"
    ]
    for result in results:
        payload = result.payload
        hub_id = payload["id"]
        url = f"https://huggingface.co/datasets/{hub_id}"
        parts.append(f"## [{hub_id}]({url})\n")
        parts.append(f"**30 Day Download:** {payload['downloads']}")
        # Not every card carries a modification time; treat a missing key the
        # same as an empty value instead of failing the whole page.
        last_modified = payload.get("lastModified")
        if last_modified:
            parts.append(
                f" | **Last Modified:** {format_time_nicely(last_modified)} \n\n"
            )
        else:
            parts.append("\n\n")
        parts.append(f"{payload['section_text']} \n")
        if show_associated_models:
            if linked_models := get_models_for_dataset(hub_id):
                links = [
                    f"[{model}](https://huggingface.co/{model})"
                    for model in linked_models
                ]
                parts.append(
                    "<details><summary>Models trained on this dataset</summary>\n\n"
                )
                parts.append("- " + "\n- ".join(links) + "\n\n")
                parts.append("</details>\n\n")
    # Join once at the end rather than growing a string inside the loop.
    return "".join(parts)
@lru_cache(maxsize=100_000)
def get_models_for_dataset(id):  # noqa: A002 - name kept for caller compatibility
    """Return ids of Hub models whose metadata links them to dataset ``id``.

    Queries ``list_models`` with the ``dataset:<id>`` filter. Results are
    memoised by ``lru_cache``, so the returned list is shared between calls
    and must not be mutated by callers.

    Returns:
        A list of unique model ids (empty if none), in the order the Hub API
        first yields them.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # rendered model list is stable across runs (set order of str is not,
    # because string hashing is randomized per process).
    return list(
        dict.fromkeys(model.id for model in list_models(filter=f"dataset:{id}"))
    )
@lru_cache(maxsize=200_000)
def search(query: str, limit: Optional[int] = 10, show_linked_models: bool = False):
    """Semantic search over the dataset-card collection.

    The query is prefixed with the BGE query instruction before embedding
    (only queries get this prefix, not the indexed passages), then the
    ``limit`` nearest cards are fetched from Qdrant and rendered to Markdown.
    Output is memoised per ``(query, limit, show_linked_models)``.
    """
    query_vector = sentence_embedding_model.encode(
        f"Represent this sentence for searching relevant passages:{query}"
    )
    results = client.search(
        # Use the shared module constant so every Qdrant call targets the
        # same collection.
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    return format_results(results, show_associated_models=show_linked_models)
@lru_cache(maxsize=100_000)
def hub_id_qdrant_id(hub_id):
    """Map a Hub dataset id to the id of its point in the Qdrant collection.

    The match on the payload ``id`` field is exact. Only successful lookups
    are cached, because ``lru_cache`` does not memoise raised exceptions.

    Raises:
        gr.Error: if no point in the collection has that ``id``.
    """
    points, _next_page_offset = client.scroll(
        collection_name=collection_name,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
            ]
        ),
        limit=1,
        # Only the point id is needed, so skip transferring payload and vectors.
        with_payload=False,
        with_vectors=False,
    )
    if not points:
        raise gr.Error(
            f"Hub id {hub_id} not in the database. This could be because it is very new"
            " or because it doesn't have much documentation."
        )
    return points[0].id
@lru_cache()
def recommend(hub_id, limit: Optional[int] = 10, show_linked_models=False):
    """Render datasets that Qdrant recommends as similar to ``hub_id``.

    ``hub_id`` is first resolved to its Qdrant point id (which raises
    ``gr.Error`` for unknown ids). That point is then used as the single
    positive example for Qdrant's recommendation API, asking for ``limit``
    results. Memoised with ``lru_cache``'s default size.
    """
    neighbours = client.recommend(
        collection_name=collection_name,
        positive=[hub_id_qdrant_id(hub_id)],
        limit=limit,
    )
    return format_results(neighbours, show_associated_models=show_linked_models)
def query(
    search_term,
    search_type,
    limit: Optional[int] = 10,
    show_linked_models: bool = False,
):
    """Route a UI request to recommendation or semantic search.

    The "Recommend similar datasets" option treats ``search_term`` as a Hub
    dataset id. Any other ``search_type`` value treats it as free text for
    semantic search.
    """
    handler = recommend if search_type == "Recommend similar datasets" else search
    return handler(search_term, limit, show_linked_models)
# Gradio UI: a search box, a mode selector, result-count and linked-model
# controls, and a Markdown area that `query` fills in on click.
with gr.Blocks() as demo:
    gr.Markdown("## &#129303; Semantic Dataset Search")
    with gr.Row():
        gr.Markdown(
            "This Gradio app allows you to search for datasets based on their"
            " descriptions. You can either search for similar datasets to a given"
            " dataset or search for datasets based on a query. This is an early proof of concept. Feedback very welcome!"
        )
    with gr.Row():
        search_term = gr.Textbox(
            value="movie review sentiment",
            # Hub ids are matched exactly (case-sensitive) when recommending,
            # so the example should be written the way the Hub spells it.
            # NOTE(review): confirm "imdb" is the id stored in the collection.
            label="hub id e.g. imdb or query e.g. movie review sentiment",
        )
    with gr.Row():
        with gr.Row():
            find_similar_btn = gr.Button("Search")
            search_type = gr.Radio(
                ["Recommend similar datasets", "Semantic Search"],
                label="Search type",
                value="Semantic Search",
                interactive=True,
            )
        with gr.Column():
            max_results = gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=10,
                label="Maximum number of results",
            )
            show_linked_models = gr.Checkbox(
                label="Show associated models",
                # Components take their initial state via `value`; `default`
                # is not a Checkbox parameter.
                value=False,
            )
    results = gr.Markdown()
    find_similar_btn.click(
        query, [search_term, search_type, max_results, show_linked_models], results
    )
demo.launch()