import altair as alt
import gradio as gr
import pandas as pd
from functools import partial
from datasets import load_dataset

# Hugging Face Hub dataset holding the time series of model-card tag statistics.
DATASET_ID = "ybelkada/model_cards_correct_tag"

# Radio choices; make_plot dispatches on these exact strings.
TOTAL_PLOT = "Total models with missing 'transformers' tag"
RATIO_PLOT = "Proportion of models correctly tagged with 'transformers' tag"


def get_data():
    """Download the tag statistics and derive the two plotting tables.

    Expects the dataset's train split to have the columns ``commit_dates``,
    ``total_transformers_model`` and ``missing_library_name``.

    Returns:
        A ``(ratio_df, melted_df)`` tuple, both ordered by ``commit_dates``:

        * ``ratio_df``: columns ``commit_dates`` and ``ratio``, where
          ``ratio = (1 - missing_library_name / total_transformers_model) * 100``.
          Rows whose total is 0 get NaN instead of an infinite value.
        * ``melted_df``: long-format table with columns ``commit_dates``,
          ``type`` and ``value``. It has one row per date for each of
          ``total_transformers_model`` and ``missing_library_name``.
    """
    # to_pandas() already returns a DataFrame, so no extra pd.DataFrame copy.
    df = load_dataset(DATASET_ID, split="train").to_pandas()
    df["commit_dates"] = pd.to_datetime(df["commit_dates"])
    df = df.sort_values(by="commit_dates")

    melted_df = pd.melt(
        df,
        id_vars=["commit_dates"],
        value_vars=["total_transformers_model", "missing_library_name"],
        var_name="type",
    )

    # Mask zero totals so the division yields NaN rather than +/-inf. An
    # infinite value would otherwise become a bound of the y-axis domain.
    total = df["total_transformers_model"]
    total = total.where(total != 0)
    df["ratio"] = (1 - df["missing_library_name"] / total) * 100
    ratio_df = df[["commit_dates", "ratio"]].copy()
    return ratio_df, melted_df


ratio_df, melted_df = get_data()


def _hover_chart(data, y_field, y_title, highlight_field, color=None):
    """Build a layered scatter + line chart whose lines thicken on mouse-over.

    Args:
        data: DataFrame with a ``commit_dates`` column and a ``y_field`` column.
        y_field: Name of the quantitative column plotted on the y axis.
        y_title: Title shown on the y axis.
        highlight_field: Field the nearest-point mouse-over selection keys on.
        color: Optional color encoding (e.g. ``"type:N"``) used to split the
            data into several series; omitted when ``None``.

    Returns:
        The Altair layer ``points + lines``.
    """
    highlight = alt.selection(
        type="single", on="mouseover", fields=[highlight_field], nearest=True
    )
    encodings = {
        "x": alt.X("commit_dates:T", title="Date"),
        # Fit the y axis to the observed min/max of the plotted column.
        "y": alt.Y(
            f"{y_field}:Q",
            scale=alt.Scale(domain=(data[y_field].min(), data[y_field].max())),
            title=y_title,
        ),
    }
    if color is not None:
        encodings["color"] = color
    base = alt.Chart(data).encode(**encodings)

    points = (
        base.mark_circle()
        .encode(opacity=alt.value(1))
        .add_selection(highlight)
        .properties(width=1200, height=800)
    )
    # Marks outside the hover selection are drawn with size 1; selected ones use 3.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3))
    )
    return points + lines


def make_plot(plot_type, refresh=False):
    """Return the chart for the selected plot type, optionally refetching data.

    Args:
        plot_type: The selected radio choice. ``TOTAL_PLOT`` shows the raw
            counts per type; any other value shows the correctly-tagged ratio.
        refresh: When truthy, re-download the dataset and replace the
            module-level ``ratio_df`` / ``melted_df`` before plotting.

    Returns:
        An Altair chart for ``gr.Plot``.
    """
    global ratio_df, melted_df
    if refresh:
        ratio_df, melted_df = get_data()

    if plot_type == TOTAL_PLOT:
        return _hover_chart(melted_df, "value", "Count", "type", color="type:N")
    return _hover_chart(
        ratio_df,
        "ratio",
        "(1 - missing_library_name / total_transformers_model) * 100 - Higher is better",
        "ratio",
    )


with gr.Blocks() as demo:
    button = gr.Radio(
        label="Plot type",
        choices=[TOTAL_PLOT, RATIO_PLOT],
        value=TOTAL_PLOT,
    )
    refresh_button = gr.Button(value="Fetch latest data")
    plot = gr.Plot(label="Plot")

    button.change(make_plot, inputs=[button], outputs=[plot])
    # Bind `refresh` by keyword. A positional partial(make_plot, True) would put
    # True into `plot_type` and the selected choice into `refresh`, so the
    # refresh button would always draw the ratio plot.
    refresh_button.click(
        partial(make_plot, refresh=True), inputs=[button], outputs=[plot]
    )
    demo.load(make_plot, inputs=[button], outputs=[plot])

if __name__ == "__main__":
    demo.launch()