# Author: davidberenstein1957
# Commit 54c440a: "Update pipeline explorer" (5.57 kB)
import asyncio
import urllib
from typing import Iterable
import gradio as gr
import markdown as md
import pandas as pd
from distilabel.cli.pipeline.utils import _build_pipeline_panel, get_pipeline
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from gradio_leaderboard import ColumnFilter, Leaderboard, SearchColumns, SelectColumns
from gradio_modal import Modal
from huggingface_hub import HfApi, HfFileSystem, RepoCard
from huggingface_hub.hf_api import DatasetInfo
# Module-level Hugging Face clients shared by the helpers in this file:
# `api` lists datasets, `fs` checks whether files exist inside dataset repos.
api = HfApi()
# NOTE(review): `example` is not referenced in the code visible here — confirm it is unused before removing.
example = HuggingfaceHubSearch().example_value()
fs = HfFileSystem()
def _categorize_dtypes(df):
dtype_mapping = {
'int64': 'number',
'float64': 'number',
'bool': 'bool',
'datetime64[ns]': 'date',
'datetime64[ns, UTC]': 'date',
'object': 'str'
}
categorized_dtypes = []
for column, dtype in df.dtypes.items():
dtype_str = str(dtype)
if dtype_str in dtype_mapping:
categorized_dtypes.append(dtype_mapping[dtype_str])
else:
categorized_dtypes.append('markdown')
return categorized_dtypes
def _get_tag_category(entry: list[str], tag_category: str):
for item in entry:
if tag_category in item:
return item.split(f"{tag_category}:")[-1]
else:
return None
def _has_pipeline(repo_id):
    """Return a text summary of the distilabel pipeline stored in a dataset repo.

    Args:
        repo_id: Dataset repository id on the Hub (``"<owner>/<name>"``).

    Returns:
        ``str()`` of the panel built by ``_build_pipeline_panel`` when the repo
        contains ``pipeline.yaml``, otherwise an empty string.
    """
    # distilabel's ``get_pipeline`` loads a serialized pipeline config (YAML/JSON)
    # or a script, not a log file, so check for ``pipeline.yaml``.
    file_name = "pipeline.yaml"
    file_path = f"datasets/{repo_id}/{file_name}"
    if not fs.exists(file_path):
        return ""
    # Needs the f-prefix, and must point at the raw file (``resolve/main``)
    # rather than the repo's HTML page.
    url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{file_name}"
    pipeline = get_pipeline(url)
    # NOTE(review): ``str()`` of a rich renderable may yield its object repr rather
    # than rendered text — confirm, and render through a rich Console if needed.
    return str(_build_pipeline_panel(pipeline))
async def check_pipelines(repo_ids):
    """Look up pipeline summaries for many dataset repos concurrently.

    ``_has_pipeline`` is a plain (blocking) function taking only a repo id, so
    each call runs in a worker thread and all calls are awaited together.

    Args:
        repo_ids: Iterable of dataset repository ids (generators are accepted).

    Returns:
        Dict mapping each repo id to the string returned by ``_has_pipeline``.
    """
    # Materialize once: the ids are iterated both to build tasks and to zip results.
    repo_ids = list(repo_ids)
    tasks = [asyncio.to_thread(_has_pipeline, repo_id) for repo_id in repo_ids]
    results = await asyncio.gather(*tasks)
    return dict(zip(repo_ids, results))
def _search_distilabel_repos(query: "str | None" = None):
    """List Hub datasets tagged ``library:distilabel`` as a DataFrame for the leaderboard.

    Args:
        query: Optional free-text search, forwarded to ``api.list_datasets`` as
            ``search``. Empty/``None`` means "no search".

    Returns:
        One row per dataset. The columns ``id``, ``likes``, ``downloads``,
        ``size_categories``, ``has_pipeline``, ``last_modified`` and
        ``description`` come first (created empty if the Hub did not return
        them), followed by every other attribute of the listed datasets.
        With no results, an empty DataFrame with just those leading columns.
    """
    subset_columns = ['id', 'likes', 'downloads', "size_categories", 'has_pipeline', 'last_modified', 'description']
    # ``filter`` takes tags only; free text belongs in the dedicated ``search``
    # argument (``urllib.urlencode`` does not exist in Python 3 anyway).
    datasets: Iterable[DatasetInfo] = api.list_datasets(filter="library:distilabel", search=query or None)
    df = pd.DataFrame.from_records([dataset.__dict__ for dataset in datasets])
    if df.empty:
        # No matches: keep the expected layout instead of failing on ``df["tags"]``.
        return pd.DataFrame(columns=subset_columns)
    df["size_categories"] = df["tags"].apply(_get_tag_category, args=("size_categories",))
    # Pipeline detection is left disabled here: it would cost Hub requests per
    # dataset (see ``check_pipelines``).
    df["has_pipeline"] = ""
    new_column_order = subset_columns + [col for col in df.columns if col not in subset_columns]
    # ``reindex`` keeps the order and creates any missing leading column
    # (e.g. ``description``) instead of raising KeyError like ``df[cols]``.
    return df.reindex(columns=new_column_order)
def _create_modal_info(row: dict) -> str:
    """Build the HTML displayed in the modal for one selected leaderboard row.

    The sections are joined with ``<br>`` in this order:
    1. a linked title,
    2. the pipeline summary from ``_has_pipeline``,
    3. the embedded Hub dataset viewer,
    4. the dataset card converted from Markdown to HTML.

    Args:
        row: One leaderboard row as a dict; only its ``"id"`` (dataset repo id) is read.

    Returns:
        The assembled HTML string.
    """
    dataset_id = row["id"]
    sections = []
    # Heading linking to the dataset page on the Hub.
    sections.append(f'<h1> <a href="https://huggingface.co/datasets/{dataset_id}">{dataset_id}</a> </h1>')
    # Pipeline summary; an empty string when the repo has none.
    sections.append(f'pipeline available: {_has_pipeline(repo_id=dataset_id)}')
    # Interactive viewer served by the Hub.
    sections.append(
        f"""<iframe src="https://huggingface.co/datasets/{dataset_id}/embed/viewer" frameborder="0" width="100%" height="560px"></iframe>"""
    )
    # Dataset card (README), rendered from Markdown to HTML.
    card = RepoCard.load(repo_id_or_path=dataset_id, repo_type="dataset")
    sections.append(md.markdown(card.text))
    return "<br>".join(sections)
# Define the Gradio interface: a searchable/filterable leaderboard of distilabel
# datasets, plus a modal that shows details for the clicked row.
with gr.Blocks(delete_cache=[1,1]) as demo:
    gr.Markdown("# ⚗️ Distilabel Synthetic Data Pipeline Finder")
    gr.HTML("Select a dataset to show the pipeline, dataset viewer and dataset card.")
    df: pd.DataFrame = _search_distilabel_repos()
    leader_board = Leaderboard(
        value=df,
        datatype=_categorize_dtypes(df),
        search_columns=SearchColumns(
            primary_column="id",
            secondary_columns=["description", "author"],
            placeholder="Search by id, description or author. To search by description or author, type 'description:<query>', 'author:<query>'",
            label="Search",
        ),
        filter_columns=[
            ColumnFilter("likes", type="slider", min=0, max=df.likes.max(), default=[0, df.likes.max()]),
            ColumnFilter("downloads", type="slider", min=0, max=df.downloads.max(), default=[0, df.downloads.max()]),
            ColumnFilter("size_categories", type="checkboxgroup"),
            ColumnFilter("has_pipeline", type="checkboxgroup"),
        ],
        # Raw/internal DatasetInfo fields that are not useful in the table.
        hide_columns=[
            "_id", "private", "gated", "disabled", "sha", "downloads_all_time", "paperswithcode_id", "tags", "siblings",
            "cardData", "lastModified", "card_data", "key",
        ],
        select_columns=SelectColumns(
            default_selection=["id", "last_modified", "downloads", "likes", "size_categories"],
            cant_deselect=["id"],
            label="Select The Columns",
            info="Helpful information",
        ),
    )
    with Modal() as modal:
        markdown = gr.HTML(value="test")

    def update(leader_board, markdown, evt: gr.SelectData):
        """Fill and open the modal for the selected row.

        Must always return one value per declared output
        (``leader_board``, ``markdown``, ``modal``), otherwise Gradio raises.
        """
        if isinstance(evt.index, int):
            # Not a [row, col] cell selection: leave all outputs unchanged.
            return leader_board, markdown, gr.update()
        row_index = evt.index[0]  # assumes evt.index is [row, col] for table selections
        markdown = _create_modal_info(row=leader_board.iloc[row_index].to_dict())
        return leader_board, markdown, Modal(visible=True)

    leader_board.select(update, [leader_board, markdown], [leader_board, markdown, modal], show_progress="hidden")
if __name__ == "__main__":
    # Launch the Gradio app defined above when run as a script.
    demo.launch()