"""Semantic search and recommendation over Hugging Face dataset cards indexed in Qdrant."""

import os
from functools import lru_cache
from typing import Optional

import gradio as gr
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from huggingface_hub import list_models

load_dotenv()
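
# Connection settings for the Qdrant instance holding the pre-built "dataset_cards"
# collection; the embedding model here should match the one used to index the cards.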
URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")

collection_name = "dataset_cards"
client = QdrantClient(
    url=URL,
    api_key=QDRANT_API_KEY,
)
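

# `lastModified` values are ISO 8601-style timestamps ("YYYY-MM-DDTHH:MM:SS...");
# keep only the date part for display.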
def format_time_nicely(time_str):
    return time_str.split("T")[0]
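

# Render a list of Qdrant hits as Markdown: one section per dataset with its
# download count, last-modified date, the matching card text and, optionally,
# a collapsible list of models trained on that dataset.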
def format_results(results, show_associated_models=True):
    markdown = (
        "<h1 style='text-align: center;'> ✨ Dataset Search Results ✨"
        " </h1> \n\n"
    )
    for result in results:
        hub_id = result.payload["id"]
        download_number = result.payload["downloads"]
        last_modified = result.payload["lastModified"]
        url = f"https://huggingface.co/datasets/{hub_id}"
        markdown += f"## [{hub_id}]({url})\n"
        markdown += f"**30 Day Downloads:** {download_number}"
        if last_modified:
            markdown += f" | **Last Modified:** {format_time_nicely(last_modified)} \n\n"
        else:
            markdown += "\n\n"
        markdown += f"{result.payload['section_text']} \n"
        if show_associated_models:
            if linked_models := get_models_for_dataset(hub_id):
                linked_models = [
                    f"[{model}](https://huggingface.co/{model})"
                    for model in linked_models
                ]
                markdown += (
                    "<details><summary>Models trained on this dataset</summary>\n\n"
                )
                markdown += "- " + "\n- ".join(linked_models) + "\n\n"
                markdown += "</details>\n\n"
    return markdown
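

# Find models on the Hub whose metadata lists this dataset; cached so repeated
# lookups of the same dataset id don't hit the Hub API again.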
@lru_cache(maxsize=100_000)
def get_models_for_dataset(hub_id):
    results = list(list_models(filter=f"dataset:{hub_id}"))
    if results:
        results = list({result.id for result in results})
    return results
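

# Semantic search: embed the query (using the query prefix recommended for BGE
# models) and return the closest dataset-card sections from Qdrant as Markdown.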
@lru_cache(maxsize=200_000)
def search(query: str, limit: Optional[int] = 10, show_linked_models: bool = False):
    query_ = sentence_embedding_model.encode(
        f"Represent this sentence for searching relevant passages: {query}"
    )
    results = client.search(
        collection_name=collection_name,
        query_vector=query_,
        limit=limit,
    )
    return format_results(results, show_associated_models=show_linked_models)
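

# Resolve a Hub dataset id (e.g. "imdb") to its Qdrant point id by filtering on
# the `id` payload field; raise a user-facing error if the dataset isn't indexed.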
@lru_cache(maxsize=100_000)
def hub_id_qdrant_id(hub_id):
    matches = client.scroll(
        collection_name=collection_name,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
            ]
        ),
        limit=1,
        with_payload=True,
        with_vectors=False,
    )
    try:
        return matches[0][0].id
    except IndexError as e:
        raise gr.Error(
            f"Hub id {hub_id} not in the database. This could be because it is very new"
            " or because it doesn't have much documentation."
        ) from e
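

# "More like this": use Qdrant's recommendation API with the resolved point as the
# positive example to surface similar dataset cards.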
@lru_cache()
def recommend(hub_id, limit: Optional[int] = 10, show_linked_models=False):
    positive_id = hub_id_qdrant_id(hub_id)
    results = client.recommend(
        collection_name=collection_name, positive=[positive_id], limit=limit
    )
    return format_results(results, show_associated_models=show_linked_models)
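

# Dispatch a UI request to either the recommender (for a Hub dataset id) or the
# semantic search (for a free-text query).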
def query(
    search_term,
    search_type,
    limit: Optional[int] = 10,
    show_linked_models: bool = False,
):
    if search_type == "Recommend similar datasets":
        return recommend(search_term, limit, show_linked_models)
    else:
        return search(search_term, limit, show_linked_models)
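

# Gradio UI: one text box plus controls for search type, number of results, and
# whether to show models trained on each dataset.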
with gr.Blocks() as demo:
    gr.Markdown("## 🤗 Semantic Dataset Search")
    with gr.Row():
        gr.Markdown(
            "This Gradio app lets you search for datasets based on their"
            " descriptions. You can either find datasets similar to a given"
            " dataset or search for datasets matching a free-text query."
            " This is an early proof of concept. Feedback is very welcome!"
        )
    with gr.Row():
        search_term = gr.Textbox(
            value="movie review sentiment",
            label="Hub dataset id (e.g. imdb) or free-text query (e.g. movie review sentiment)",
        )

    with gr.Row():
        with gr.Row():
            find_similar_btn = gr.Button("Search")
            search_type = gr.Radio(
                ["Recommend similar datasets", "Semantic Search"],
                label="Search type",
                value="Semantic Search",
                interactive=True,
            )
        with gr.Column():
            max_results = gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=10,
                label="Maximum number of results",
            )
            show_linked_models = gr.Checkbox(
                label="Show associated models",
                value=False,
            )

    results = gr.Markdown()
    find_similar_btn.click(
        query, [search_term, search_type, max_results, show_linked_models], results
    )


demo.launch()