# app.py -- page header captured with the file:
# ybelkada | "Update app.py" | commit 97db185 (verified) | 3.25 kB
import altair as alt
import gradio as gr
import pandas as pd
from functools import partial
from datasets import load_dataset
def get_data():
    """Load the tagging statistics and shape them for the two plots.

    Reads the ``train`` split of ``ybelkada/model_cards_correct_tag``,
    parses ``commit_dates`` as datetimes and orders the rows by that column.

    Returns:
        A ``(ratio_df, melted_df)`` tuple:

        * ``ratio_df`` holds ``commit_dates`` and ``ratio``, where ``ratio``
          is ``(1 - missing_library_name / total_transformers_model) * 100``.
        * ``melted_df`` is the long-format view of the two count columns,
          keyed by ``commit_dates`` with the source column name in ``type``.
    """
    repo_id = "ybelkada/model_cards_correct_tag"
    raw = load_dataset(repo_id, split="train").to_pandas()

    # Parse the commit dates and order the rows chronologically.
    frame = pd.DataFrame(raw)
    frame["commit_dates"] = pd.to_datetime(frame["commit_dates"])
    frame = frame.sort_values(by="commit_dates")

    # Long format: one row per (date, count column) pair. This is taken
    # before the ratio column is attached to the wide frame.
    count_columns = ['total_transformers_model', 'missing_library_name']
    long_counts = frame.melt(
        id_vars=['commit_dates'],
        value_vars=count_columns,
        var_name='type',
    )

    # Percentage of models that are NOT missing the library name.
    missing_share = frame['missing_library_name'] / frame['total_transformers_model']
    frame['ratio'] = (1 - missing_share) * 100
    tagged_ratio = frame[['commit_dates', 'ratio']].copy()

    return tagged_ratio, long_counts
# Module-level cache of the two plotting frames, populated once when the
# module is loaded; make_plot rebinds these globals when refresh=True.
ratio_df, melted_df = get_data()
def _build_highlight_chart(data, y_field, y_title, selection_field, color=None):
    """Build the points-plus-lines chart shared by both plot types.

    Args:
        data: DataFrame with a ``commit_dates`` column and a ``y_field`` column.
        y_field: Name of the quantitative column plotted on the y axis.
        y_title: Axis title shown for the y axis.
        selection_field: Field the mouseover selection is keyed on.
        color: Optional Altair colour encoding (e.g. ``'type:N'``); omitted
            from the encoding entirely when ``None``.

    Returns:
        The chart produced by ``points + lines``.
    """
    # NOTE(review): `alt.selection(type='single')` / `add_selection` appear to
    # be deprecated in Altair 5 in favour of `selection_point` / `add_params`;
    # kept as-is so the installed Altair version keeps working -- confirm
    # before upgrading.
    highlight = alt.selection(type='single', on='mouseover',
                              fields=[selection_field], nearest=True)

    # Pin the y-axis domain to the observed min/max of the plotted column.
    y_domain = (data[y_field].min(), data[y_field].max())
    encodings = {
        'x': alt.X('commit_dates:T', title='Date'),
        'y': alt.Y(f'{y_field}:Q', scale=alt.Scale(domain=y_domain), title=y_title),
    }
    if color is not None:
        encodings['color'] = color
    base = alt.Chart(data).encode(**encodings)

    # The circle layer carries the selection and fixes the chart size.
    points = base.mark_circle().encode(
        opacity=alt.value(1),
    ).add_selection(
        highlight
    ).properties(
        width=1200,
        height=800,
    )

    # Lines are drawn at size 1 unless selected, in which case size 3.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3))
    )

    return points + lines


def make_plot(plot_type, refresh=False):
    """Return the Altair chart for the requested plot type.

    Args:
        plot_type: Label of the selected plot. The label
            ``"Total models with missing 'transformers' tag"`` yields the
            counts chart; any other value yields the ratio chart.
        refresh: When true, call ``get_data()`` again and rebind the
            module-level ``ratio_df`` / ``melted_df`` before plotting.

    Returns:
        The chart built from the cached frame matching ``plot_type``.
    """
    global ratio_df, melted_df

    if refresh:
        ratio_df, melted_df = get_data()

    if plot_type == "Total models with missing 'transformers' tag":
        return _build_highlight_chart(
            melted_df,
            y_field='value',
            y_title="Count",
            selection_field='type',
            color='type:N',
        )

    return _build_highlight_chart(
        ratio_df,
        y_field='ratio',
        y_title="(1 - missing_library_name / total_transformers_model) * 100 - Higher is better",
        selection_field='ratio',
    )
with gr.Blocks() as demo:
    # The first entry doubles as the initially selected plot type.
    plot_choices = [
        "Total models with missing 'transformers' tag",
        "Proportion of models correctly tagged with 'transformers' tag",
    ]
    button = gr.Radio(
        choices=plot_choices,
        value=plot_choices[0],
        label="Plot type",
    )
    refresh_button = gr.Button(value="Fetch latest data")
    plot = gr.Plot(label="Plot")

    # Same callback as the other events, but with refresh=True so the data
    # is fetched again before plotting.
    refresh_plot = partial(make_plot, refresh=True)

    # Every event reads the current radio value and writes to the plot.
    button.change(make_plot, inputs=[button], outputs=[plot])
    refresh_button.click(refresh_plot, inputs=[button], outputs=[plot])
    demo.load(make_plot, inputs=[button], outputs=[plot])


if __name__ == "__main__":
    demo.launch()