# Source: app.py from a Hugging Face Space, last updated by ybelkada in
# commit 1a3c9e6 (verified) with the message "Update app.py"; file size 2.69 kB.
from datasets import load_dataset
from huggingface_hub import ModelCard
from huggingface_hub import HfApi
import gradio as gr
import pandas as pd
api = HfApi()
repo_id = "librarian-bots/model_cards_with_metadata"

# Map each commit's creation date (MM/DD/YYYY) to a commit id so the UI can
# load historical snapshots of the dataset through get_data(). The commits
# are walked in the order list_repo_commits returns them (newest first per
# huggingface_hub docs); setdefault keeps the first, i.e. newest, commit when
# several commits share a date, whereas a plain dict comprehension would
# silently keep the oldest one.
list_commits = api.list_repo_commits(repo_id, repo_type="dataset")
commits_date_dict = {}
for commit in list_commits:
    commits_date_dict.setdefault(
        commit.created_at.strftime("%m/%d/%Y"), commit.commit_id
    )

# Snapshot label used in the title of the plot rendered at startup.
current_date = "latest"
def get_data(commit_date="latest"):
    """Count transformers model cards and those whose metadata lacks ``library_name``.

    Loads the ``train`` split of ``repo_id``, either at the newest revision or
    at the commit recorded for ``commit_date`` in ``commits_date_dict``. It
    keeps only rows whose ``library_name`` column equals ``'transformers'`` and
    parses each row's ``card`` text with ``ModelCard`` to check whether the
    card metadata declares a ``library_name``.

    Args:
        commit_date: ``"latest"`` for the newest revision, or a key of
            ``commits_date_dict`` (an ``MM/DD/YYYY`` string). ``None`` or an
            empty string (e.g. a cleared dropdown) is treated as ``"latest"``.

    Returns:
        A two-row ``pandas.DataFrame`` with columns ``name`` and ``count``.
        The rows give the total number of transformers models and the number
        whose card metadata has no ``library_name``.

    Raises:
        KeyError: if ``commit_date`` is neither ``"latest"`` nor a known date.
    """
    ds_kwargs = {}
    if commit_date and commit_date != "latest":
        ds_kwargs["revision"] = commits_date_dict[commit_date]
    dataset = load_dataset(repo_id, split="train", **ds_kwargs)
    dataset = dataset.filter(lambda x: x["library_name"] == "transformers")

    def library_name_missing(card):
        """Return True if the metadata of ``card`` has no ``library_name``.

        Cards that fail to parse are counted as *not* missing (best effort),
        so one malformed card cannot abort the whole map.
        """
        try:
            return ModelCard(card).data.library_name is None
        except Exception:
            return False

    ds = dataset.map(
        lambda x: {"missing_library_name": library_name_missing(x["card"])},
        num_proc=4,
    )
    return pd.DataFrame(
        {
            "name": [
                "Total Number of transformers Model",
                "Total number of models with missing 'library_name: transformers' in model card.",
            ],
            "count": [len(ds), sum(ds["missing_library_name"])],
        }
    )
def fetch_fn(commit_date="latest"):
    """Rebuild the bar plot for the dataset snapshot chosen in the dropdown.

    Args:
        commit_date: ``"latest"`` or a commit-date key of ``commits_date_dict``.
            ``None`` or empty (e.g. a cleared dropdown) falls back to ``"latest"``.

    Returns:
        A ``gr.BarPlot`` that replaces the displayed plot. Its title names the
        snapshot shown, in the same format as the plot built at startup.
    """
    commit_date = commit_date or "latest"
    data = get_data(commit_date=commit_date)
    return gr.BarPlot(
        data,
        x="name",
        y="count",
        title=(
            "Count of Model cards with the correct library_name tag "
            f"at the date {commit_date}"
        ),
        height=256,
        width=1024,
        tooltip=["name", "count"],
        vertical=False,
    )
# Statistics for the newest dataset revision, shown when the app starts.
data = get_data()
with gr.Blocks() as bar_plot:
    with gr.Column():
        with gr.Row():
            plot = gr.BarPlot(
                data,
                x="name",
                y="count",
                title=f"Count of Model cards with the correct library_name tag at the date {current_date}",
                height=256,
                width=1024,
                tooltip=["name", "count"],
                vertical=False,
            )
        with gr.Column():
            display = gr.Dropdown(
                # "latest" must be offered as a choice: it is the initial
                # value, and without it users cannot return to the newest
                # snapshot after selecting a historical date.
                choices=["latest", *commits_date_dict],
                value="latest",
                label="Dataset snapshot (commit date)",
            )
        # Re-render the plot whenever a different snapshot is selected.
        display.change(fetch_fn, inputs=display, outputs=plot)
bar_plot.launch()