"""Gradio app for semantic search over Hugging Face dataset cards.

Queries are embedded with a BGE sentence-transformer and matched against a
Qdrant collection ("dataset_cards") of dataset-card sections. Users can
either run a free-text semantic search or ask for datasets similar to a
given hub id.
"""

import os
from functools import lru_cache
from typing import Optional

import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import list_models
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer

load_dotenv()

URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# Query encoder — must match the model used to build the Qdrant collection.
sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")

collection_name = "dataset_cards"

client = QdrantClient(
    url=URL,
    api_key=QDRANT_API_KEY,
)


def format_time_nicely(time_str):
    """Trim an ISO-8601 timestamp to its date part (everything before 'T')."""
    return time_str.split("T")[0]


def format_results(results, show_associated_models=True):
    """Render Qdrant hits as a Markdown document for the results pane.

    Args:
        results: iterable of Qdrant points whose ``payload`` carries ``id``,
            ``downloads``, ``lastModified`` and ``section_text`` fields.
        show_associated_models: when True, append a collapsible list of
            models trained on each dataset.

    Returns:
        A single Markdown string.
    """
    # NOTE(review): the original header string contained HTML markup that was
    # lost in transit; a centered <h1> is the most likely reconstruction.
    markdown = (
        "<h1 style='text-align: center;'> ✨ Dataset Search Results ✨ </h1>"
        "\n\n"
    )
    for result in results:
        hub_id = result.payload["id"]
        download_number = result.payload["downloads"]
        last_modified = result.payload["lastModified"]
        url = f"https://huggingface.co/datasets/{hub_id}"
        markdown += f"## [{hub_id}]({url})" + "\n"
        markdown += f"**30 Day Download:** {download_number}"
        if last_modified:
            markdown += (
                f" | **Last Modified:** {format_time_nicely(last_modified)} \n\n"
            )
        else:
            markdown += "\n\n"
        markdown += f"{result.payload['section_text']} \n"
        if show_associated_models and (
            linked_models := get_models_for_dataset(hub_id)
        ):
            links = [
                f"[{model}](https://huggingface.co/{model})"
                for model in linked_models
            ]
            # Collapsible section so long model lists don't dominate the page.
            markdown += (
                "<details><summary>Models trained on this dataset</summary>\n\n"
            )
            markdown += "- " + "\n- ".join(links) + "\n\n"
            markdown += "</details>\n\n"
    return markdown


@lru_cache(maxsize=100_000)
def get_models_for_dataset(hub_id):
    """Return hub ids of models trained on *hub_id*, deduplicated and sorted.

    Sorting makes the output deterministic (the original `list(set)` order
    varied run to run, which also made the rendered Markdown unstable).
    """
    results = list(list_models(filter=f"dataset:{hub_id}"))
    if results:
        results = sorted({result.id for result in results})
    return results


@lru_cache(maxsize=200_000)
def search(query: str, limit: Optional[int] = 10, show_linked_models: bool = False):
    """Semantic search over the dataset-cards collection.

    Args:
        query: free-text query.
        limit: maximum number of hits to return.
        show_linked_models: pass-through to :func:`format_results`.

    Returns:
        Markdown-formatted results.
    """
    # BGE retrieval models expect queries prefixed with this instruction.
    query_vector = sentence_embedding_model.encode(
        f"Represent this sentence for searching relevant passages:{query}"
    )
    results = client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    return format_results(results, show_associated_models=show_linked_models)


@lru_cache(maxsize=100_000)
def hub_id_qdrant_id(hub_id):
    """Resolve a Hub dataset id to its Qdrant point id.

    Raises:
        gr.Error: when the dataset is not present in the collection (e.g.
            very new or sparsely documented datasets).
    """
    matches = client.scroll(
        collection_name=collection_name,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
            ]
        ),
        limit=1,
        with_payload=True,
        with_vectors=False,
    )
    try:
        # scroll() returns (points, next_page_offset); take the first point.
        return matches[0][0].id
    except IndexError as e:
        raise gr.Error(
            f"Hub id {hub_id} not in the database. This could be because it is very new"
            " or because it doesn't have much documentation."
        ) from e


@lru_cache(maxsize=100_000)
def recommend(hub_id, limit: Optional[int] = 10, show_linked_models=False):
    """Recommend datasets similar to *hub_id* using Qdrant's recommend API."""
    positive_id = hub_id_qdrant_id(hub_id)
    results = client.recommend(
        collection_name=collection_name, positive=[positive_id], limit=limit
    )
    return format_results(results, show_associated_models=show_linked_models)


def query(
    search_term,
    search_type,
    limit: Optional[int] = 10,
    show_linked_models: bool = False,
):
    """Dispatch to recommendation or semantic search based on *search_type*."""
    if search_type == "Recommend similar datasets":
        return recommend(search_term, limit, show_linked_models)
    else:
        return search(search_term, limit, show_linked_models)


with gr.Blocks() as demo:
    gr.Markdown("## 🤗 Semantic Dataset Search")
    with gr.Row():
        gr.Markdown(
            "This Gradio app allows you to search for datasets based on their"
            " descriptions. You can either search for similar datasets to a given"
            " dataset or search for datasets based on a query. This is an early"
            " proof of concept. **Feedback very welcome!**"
        )
    with gr.Row():
        search_term = gr.Textbox(
            value="movie review sentiment",
            label="hub id i.e. IMDB or query i.e. movie review sentiment",
        )
    with gr.Row():
        with gr.Row():
            find_similar_btn = gr.Button("Search")
            search_type = gr.Radio(
                ["Recommend similar datasets", "Semantic Search"],
                label="Search type",
                value="Semantic Search",
                interactive=True,
            )
        with gr.Column():
            max_results = gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=10,
                label="Maximum number of results",
            )
            show_linked_models = gr.Checkbox(
                label="Show associated models",
                value=False,  # `default=` is not a valid Gradio kwarg
            )
    results = gr.Markdown()
    find_similar_btn.click(
        query, [search_term, search_type, max_results, show_linked_models], results
    )

if __name__ == "__main__":
    demo.launch()