File size: 3,193 Bytes
abe7f1c
 
 
 
 
 
 
 
 
 
f37cf2e
 
 
 
 
 
 
 
 
 
abe7f1c
 
 
 
 
 
 
f37cf2e
abe7f1c
 
 
 
 
 
 
 
7fec6b3
f37cf2e
 
 
7fec6b3
abe7f1c
 
f37cf2e
abe7f1c
f37cf2e
 
abe7f1c
f37cf2e
 
 
 
 
 
abe7f1c
f37cf2e
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
## LIBRARIES ###
from cProfile import label
from tkinter import font
from turtle import width
import streamlit as st
import pandas as pd
from datetime import datetime
import plotly.express as px


def select_plot_data(df, quantile_low, qunatile_high):
    """Return the per-model usage table, keeping only models inside a popularity band.

    The input is filled (missing values -> 0), indexed by its ``Model`` column and
    transposed, so each model becomes a column and every other input column becomes
    a row. Row labels are then reformatted with ``date_range``. A model is kept
    when its mean usage lies strictly between the ``quantile_low`` and
    ``qunatile_high`` quantiles of all models' mean usage.

    Args:
        df: DataFrame with a ``Model`` column; the remaining column labels are
            presumably timestamp strings in the format ``date_range`` parses
            (TODO confirm against the CSV files). It is not modified.
        quantile_low: Lower quantile of the models' mean usage, in [0, 1].
        qunatile_high: Upper quantile of the models' mean usage, in [0, 1].
            (The misspelled name is kept so keyword callers keep working.)

    Returns:
        DataFrame indexed by formatted date labels with one column per selected
        model. It has no columns when no model falls strictly inside the band
        (e.g. when ``quantile_low >= qunatile_high``).
    """
    # Fill on a copy so the caller's DataFrame is left untouched.
    df_plot = df.fillna(0).set_index('Model').T
    df_plot.index = date_range(df_plot)

    # Mean usage of every model, computed once and reused for both the
    # quantile thresholds and the per-model comparison.
    model_means = df_plot.mean()
    quantile_lvalue = model_means.quantile(quantile_low)
    quantile_hvalue = model_means.quantile(qunatile_high)

    in_band = (model_means > quantile_lvalue) & (model_means < quantile_hvalue)
    # Index positionally with a plain boolean array so duplicated model names
    # do not trip label alignment.
    return df_plot.loc[:, in_band.to_numpy()]

def read_file_to_df(file):
    """Load a CSV source into a pandas DataFrame.

    ``file`` is handed unchanged to ``pandas.read_csv``, so it may be anything
    that function accepts (for example a path string).
    """
    frame = pd.read_csv(file)
    return frame

def date_range(df, fmt='%Y-%m-%dT%H:%M:%S.%fZ'):
    """Convert a DataFrame's timestamp-string index into short ``M/D/YY`` labels.

    Each index entry is parsed with ``datetime.strptime(entry, fmt)`` and only
    its calendar date is kept. Month and day are not zero-padded, and the year
    is reduced to its last two characters (e.g. ``"2022-01-05T00:00:00.000Z"``
    -> ``"1/5/22"``).

    Args:
        df: DataFrame whose index holds timestamp strings matching ``fmt``.
        fmt: ``strptime`` format of the index entries; defaults to the
            ISO-8601-with-``Z`` format previously hard-coded here.

    Returns:
        List of label strings, one per index entry, in index order.

    Raises:
        ValueError: if an index entry does not match ``fmt``.
    """
    labels = []
    for stamp in df.index.to_list():
        # Parse each timestamp a single time and reuse the resulting date.
        day = datetime.strptime(stamp, fmt).date()
        labels.append(f"{day.month}/{day.day}/{str(day.year)[-2:]}")
    return labels


if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="HF Hub Model Usage Visualization")

    st.header("Model Usage Visualization")
    with st.expander("How to read and interact with the plot:"):
        st.markdown("The plots below visualize weekly usage for HF models categorized by the model creation time.")
        st.markdown("Select the model creation time range you want to visualize using the dropdown menu below.")
        st.markdown("Choose the quantile range to filter out models with high or low usage.")
        st.markdown("The plots are interactive. Hover over the points to see the model name and the number of weekly mean usage. Click on the legend to hide/show the models.")

    model_init_year = st.multiselect("Model creation year", ["before_2021", "2021", "2022"], key = "model_init_year", default = "2022")

    popularity_low = st.slider("Model popularity quantile (lower limit) ",  min_value=0.0, max_value=1.0, step=0.01, value=0.90, key = "popularity_low")
    popularity_high = st.slider("Model popularity quantile (upper limit) ",  min_value=0.0, max_value=1.0, step=0.01, value=0.99, key = "popularity_high")

    # Widgets created with ``key=`` keep their value in st.session_state under
    # that key, so the returned values above already are the session values;
    # no manual session_state seeding is needed.

    # select_plot_data keeps models strictly between the two quantiles, so an
    # inverted or empty range would silently produce blank charts.
    if popularity_low >= popularity_high:
        st.warning("The lower quantile limit must be below the upper limit; no models will be shown.")

    with st.container():
        for year in model_init_year:
            df = read_file_to_df(f"./assets/{year}/model_usage.csv")
            df_plot_data = select_plot_data(df, popularity_low, popularity_high)
            fig = px.line(df_plot_data, title="Models created in "+year, labels={"index": "Weeks", "value": "Usage", "variable": "Model"})
            st.plotly_chart(fig, use_container_width=True)