|
import os |
|
from functools import lru_cache |
|
from typing import Optional |
|
|
|
import gradio as gr |
|
from dotenv import load_dotenv |
|
from qdrant_client import QdrantClient, models |
|
from sentence_transformers import SentenceTransformer |
|
|
|
load_dotenv()

# Qdrant connection settings, read from the environment (load_dotenv() above
# may populate them from a local .env file).
URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# Model used to embed search queries before they are sent to Qdrant.
sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")

print(URL)
# Never echo the API key itself: stdout typically ends up in (shared) logs.
# Only report whether it was provided, which is enough to debug config issues.
print(f"QDRANT_API_KEY set: {QDRANT_API_KEY is not None}")

# Name of the Qdrant collection that search / recommend operate on.
collection_name = "dataset_cards"

client = QdrantClient(
    url=URL,
    api_key=QDRANT_API_KEY,
)
|
|
|
|
|
def format_results(results):
    """Render Qdrant hits as a Markdown document for the results panel.

    Args:
        results: Iterable of points (as returned by ``client.search`` /
            ``client.recommend``) whose ``payload`` dict holds the keys
            ``"id"`` (Hub dataset id), ``"downloads"`` and ``"section_text"``.

    Returns:
        A Markdown string: a centred HTML heading, then for each hit a
        linked ``##`` header to the dataset page, its download count and
        the stored section text.
    """
    parts = [
        "<h1 style='text-align: center;'> ✨ Dataset Search Results ✨"
        " </h1> \n\n"
    ]
    for result in results:
        payload = result.payload
        hub_id = payload["id"]
        url = f"https://huggingface.co/datasets/{hub_id}"
        parts.append(f"## [{hub_id}]({url})\n")
        parts.append(f"**Downloads:** {payload['downloads']}\n\n")
        parts.append(f"{payload['section_text']} \n")
    # Join once at the end instead of repeatedly growing a str with ``+=``.
    return "".join(parts)
|
|
|
|
|
@lru_cache(maxsize=100_000)
def search(query: str, limit: Optional[int] = 10):
    """Run a semantic search over the dataset collection for a text query.

    Results are memoised per ``(query, limit)`` pair, so repeated identical
    searches return the cached Markdown without re-embedding or re-querying.

    Args:
        query: Free-text search query.
        limit: Maximum number of hits to return; ``None`` falls back to 10.

    Returns:
        Markdown produced by ``format_results`` for the matching points.
    """
    # BGE models are queried with this instruction prefix; the official
    # instruction ends with ": " (colon followed by a space).
    query_ = sentence_embedding_model.encode(
        f"Represent this sentence for searching relevant passages: {query}"
    )
    results = client.search(
        collection_name=collection_name,
        query_vector=query_,
        # The UI slider may hand over a float; Qdrant expects an integer.
        limit=10 if limit is None else int(limit),
    )
    return format_results(results)
|
|
|
|
|
@lru_cache(maxsize=100_000)
def hub_id_qdrant_id(hub_id):
    """Look up the Qdrant point id stored for a Hugging Face Hub dataset id.

    Successful lookups are memoised; failed lookups raise and are therefore
    not cached, so newly indexed datasets are picked up on the next call.

    Args:
        hub_id: Hub dataset id, matched exactly against the ``"id"`` payload
            field of the points in the collection.

    Returns:
        The id of the first matching point.

    Raises:
        gr.Error: If no point carries that ``"id"``; Gradio shows the message
            to the user.
    """
    # scroll() returns (points, next_page_offset); only the points are needed.
    points, _next_offset = client.scroll(
        collection_name=collection_name,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
            ]
        ),
        limit=1,
        # Only the point id is used, so skip transferring payload and vectors.
        with_payload=False,
        with_vectors=False,
    )
    if not points:
        raise gr.Error(
            f"Hub id {hub_id} not in our database. This could be because it is very new"
            " or because it doesn't have much documentation."
        )
    return points[0].id
|
|
|
|
|
@lru_cache(maxsize=100_000)
def recommend(hub_id, limit: Optional[int] = 10):
    """Recommend datasets similar to a given Hub dataset.

    Uses the stored vector of ``hub_id`` as the single positive example for
    Qdrant's recommendation API. Results are memoised per ``(hub_id, limit)``
    with the same cache size as the other lookup helpers.

    Args:
        hub_id: Hub dataset id to find neighbours for.
        limit: Maximum number of recommendations; ``None`` falls back to 10.

    Returns:
        Markdown produced by ``format_results`` for the recommended points.

    Raises:
        gr.Error: Propagated from ``hub_id_qdrant_id`` when ``hub_id`` is not
            indexed.
    """
    positive_id = hub_id_qdrant_id(hub_id)
    results = client.recommend(
        collection_name=collection_name,
        positive=[positive_id],
        # The UI slider may hand over a float; Qdrant expects an integer.
        limit=10 if limit is None else int(limit),
    )
    return format_results(results)
|
|
|
|
|
def query(search_term, search_type, limit: Optional[int] = 10):
    """Dispatch a UI request to recommendation or semantic search.

    Args:
        search_term: Hub dataset id (recommend mode) or free-text query
            (semantic search mode), as typed by the user.
        search_type: Value of the search-type radio; ``"Recommend similar
            datasets"`` selects recommendation, anything else semantic search.
        limit: Maximum number of results; ``None`` falls back to 10.

    Returns:
        Markdown describing the results.

    Raises:
        gr.Error: If the search term is empty, or propagated from the lookup.
    """
    # Strip stray whitespace so e.g. "imdb " still matches the exact-id lookup
    # and shares cache entries with "imdb".
    search_term = (search_term or "").strip()
    if not search_term:
        raise gr.Error("Please enter a hub id or a search query.")
    # Normalise once here; the slider value may arrive as a float.
    limit = 10 if limit is None else int(limit)
    if search_type == "Recommend similar datasets":
        return recommend(search_term, limit)
    return search(search_term, limit)
|
|
|
|
|
# Gradio UI: a query box, a search-type selector, a result-count slider and a
# Markdown panel that the "Search" button fills via query().
with gr.Blocks() as demo:
    gr.Markdown("## 🤗 Semantic Dataset Search")
    with gr.Row():
        gr.Markdown(
            "This Gradio app allows you to search for datasets based on their"
            " descriptions. You can either search for similar datasets to a given"
            " dataset or search for datasets based on a query."
        )
    with gr.Row():
        search_term = gr.Textbox(
            value="movie review sentiment",
            label="hub id e.g. IMDB or query e.g. movie review sentiment",
        )

    with gr.Row():
        with gr.Row():
            find_similar_btn = gr.Button("Search")
            search_type = gr.Radio(
                ["Recommend similar datasets", "Semantic Search"],
                label="Search type",
                value="Semantic Search",
                interactive=True,
            )
        with gr.Column():
            # Gradio components take their helper text via `info`; `help` is
            # not a Slider parameter (ignored in 3.x, TypeError in 4.x).
            max_results = gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=10,
                label="Maximum number of results",
                info="This is the maximum number of results that will be returned",
            )

    results = gr.Markdown()
    find_similar_btn.click(query, [search_term, search_type, max_results], results)


demo.launch()
|
|