Spaces:
Runtime error
Runtime error
import json | |
from typing import Tuple | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
import plotly.figure_factory as ff | |
import plotly.graph_objects as go | |
import streamlit as st | |
from plotly.subplots import make_subplots | |
from exp_utils import MODELS | |
from visualize_utils import viridis_rgb | |
# Page-wide settings: wide layout, sidebar open on first load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="📊",
    page_title="Results Viewer",
)
def _strip_variant(name):
    """Map a result name such as "llama-7b_quantized" to its MODELS key."""
    return name.replace("_quantized", "").replace("_watermarked", "")


# Lookup tables keyed by base model name.
MODELS_SIZE_MAPPING = {k: v["model_size"] for k, v in MODELS.items()}
MODELS_FAMILY_MAPPING = {k: v["model_family"] for k, v in MODELS.items()}
# A sorted list rather than a set: a set of strings can iterate in a different
# order in every interpreter run (hash randomization), which would shuffle the
# family options shown in the sidebar widgets.
MODEL_FAMILES = sorted({model["model_family"] for model in MODELS.values()})
# Base models that also have "_quantized" and "_watermarked" result variants.
_Q_W_BASE_MODELS = [
    "llama-7b",
    "llama-2-7b",
    "llama-13b",
    "llama-2-13b",
    "llama-30b",
    "llama-65b",
    "llama-2-70b",
]
Q_W_MODELS = [f"{model}_quantized" for model in _Q_W_BASE_MODELS] + [
    f"{model}_watermarked" for model in _Q_W_BASE_MODELS
]
# Every result name: base models plus their quantized/watermarked variants.
MODEL_NAMES = list(MODELS.keys()) + Q_W_MODELS
# Variants are ordered by the family/size of their base model.
MODEL_NAMES_SORTED_BY_NAME_AND_SIZE = sorted(
    MODEL_NAMES,
    key=lambda x: (
        MODELS[_strip_variant(x)]["model_family"],
        MODELS[_strip_variant(x)]["model_size"],
    ),
)
MODEL_NAMES_SORTED_BY_SIZE = sorted(
    MODEL_NAMES,
    key=lambda x: (
        MODELS[_strip_variant(x)]["model_size"],
        MODELS[_strip_variant(x)]["model_family"],
    ),
)
# Re-order the size mapping by size, then by model name.
MODELS_SIZE_MAPPING = dict(
    sorted(MODELS_SIZE_MAPPING.items(), key=lambda item: (item[1], item[0]))
)
MODELS_SIZE_MAPPING_LIST = list(MODELS_SIZE_MAPPING.keys())
# Result names whose base model is flagged as a chat model, in family/size order.
CHAT_MODELS = [
    x
    for x in MODEL_NAMES_SORTED_BY_NAME_AND_SIZE
    if MODELS[_strip_variant(x)]["is_chat"]
]
def clean_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Aggregate per-seed evaluation results into per-model mean and std tables.

    Drops every column whose name contains one of ``words_to_remove``
    (case-insensitive), strips the ``_roc_auc`` and ``eval_`` affixes from
    the remaining column names, then groups the rows by ``model_name`` and
    takes the mean and the standard deviation of the numeric/bool columns.

    Args:
        df: raw results with at least ``model_name`` and ``exp_seed`` columns
            (presumably one row per model and seed).

    Returns:
        ``(df_avg, df_std)``, indexed by model name. Only models that also
        appear as a column are kept as rows, in column order. Each table gets
        ``model_family``, ``model_size``, ``is_quantized`` and
        ``is_watermarked`` columns.
    """

    def base_name(name):
        # Quantized/watermarked variants share the metadata of their base model.
        return name.replace("_quantized", "").replace("_watermarked", "")

    words_to_remove = [
        "epoch",
        "loss",
        "runtime",
        "samples_per_second",
        "steps_per_second",
        "samples",
        "results_dir",
    ]
    keep = ~df.columns.str.contains("|".join(words_to_remove), case=False, regex=True)
    # Copy so the column edits below do not operate on a slice of the input.
    df = df.loc[:, keep].copy()
    # Strip metric/prefix affixes from the column names.
    df.columns = df.columns.str.replace("_roc_auc", "", regex=False)
    df.columns = df.columns.str.replace("eval_", "", regex=False)
    df["model_family"] = df["model_name"].apply(
        lambda x: MODELS_FAMILY_MAPPING[base_name(x)]
    )
    model_family_dict = dict(zip(df["model_name"], df["model_family"]))

    # Aggregate over seeds. Only numeric/bool columns are aggregated: pandas
    # >= 2.0 raises TypeError on string columns such as model_family instead
    # of silently dropping them like older versions did.
    numeric_columns = list(df.select_dtypes(include=["number", "bool"]).columns)
    grouped = df.groupby("model_name")[numeric_columns]
    df_avg = grouped.mean().drop(columns=["exp_seed"])
    df_std = grouped.std().drop(columns=["exp_seed"])

    def finalize(table):
        # Attach model metadata and restrict rows to models that are also columns.
        table["model_family"] = table.index.map(model_family_dict)
        table["model_size"] = table.index.map(
            lambda x: MODELS_SIZE_MAPPING[base_name(x)]
        )
        # This reindex fixes the final row order (the column order of the table).
        table = table.reindex([x for x in table.columns if x in table.index])
        table["is_quantized"] = table.index.str.contains("quantized")
        table["is_watermarked"] = table.index.str.contains("watermarked")
        return table

    return finalize(df_avg), finalize(df_std)
def get_data(path) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Read the results CSV at ``path`` and return its per-model (mean, std) tables.

    The first CSV column is used as the row index; the aggregation itself is
    delegated to ``clean_dataframe``.
    """
    raw_results = pd.read_csv(path, index_col=0)
    mean_table, std_table = clean_dataframe(raw_results)
    return mean_table, std_table
def _base_model_name(name):
    """Return the MODELS key of ``name`` by dropping "_quantized"/"_watermarked"."""
    return name.replace("_quantized", "").replace("_watermarked", "")


def _model_size(name):
    """Return the size of ``name``'s base model as recorded in MODELS."""
    return MODELS[_base_model_name(name)]["model_size"]


def _debug_write(is_debug, label, df):
    """Write ``label`` followed by ``df`` to the page when ``is_debug`` is set."""
    if is_debug:
        st.write(label)
        st.write(df)


def _keep_columns(df, predicate):
    """Keep the columns of ``df`` for which ``predicate`` holds, sorted by name."""
    return df[sorted({column for column in df.columns if predicate(column)})]


def _move_to_end(df, candidates):
    """Move the ``candidates`` present in ``df`` after all other columns and rows.

    Moved columns/rows are ordered by model size (ties keep ``candidates`` order).
    """
    tail_columns = sorted((x for x in candidates if x in df.columns), key=_model_size)
    df = df[[x for x in df.columns if x not in tail_columns] + tail_columns]
    tail_rows = sorted((x for x in candidates if x in df.index), key=_model_size)
    return df.reindex([x for x in df.index if x not in tail_rows] + tail_rows)


def filter_df(
    df: pd.DataFrame,
    model_family_train: list,
    model_family_test: list,
    model_size_train: tuple,
    model_size_test: tuple,
    is_chat_train: bool,
    is_chat_test: bool,
    is_quantized_train: bool,
    is_quantized_test: bool,
    is_watermarked_train: bool,
    is_watermarked_test: bool,
    sort_by_size: bool,
    split_chat_models: bool,
    split_quantized_models: bool,
    split_watermarked_models: bool,
    filter_empty_col_row: bool,
    is_debug: bool,
) -> pd.DataFrame:
    """Select and order the score matrix: train models as rows, test models as columns.

    Args:
        df: per-model table indexed by model name, with one score column per
            model plus ``model_size``, ``model_family``, ``is_chat``,
            ``is_quantized`` and ``is_watermarked`` columns.
        model_family_train, model_family_test: families to keep.
        model_size_train, model_size_test: inclusive ``(low, high)`` range in
            billions (multiplied by 1e9 before comparing to ``model_size``).
        is_chat_*, is_quantized_*, is_watermarked_*: ``True``/``False`` keep
            only matching models; the string ``"Both"`` disables the filter.
        sort_by_size: order by size then family instead of family then size.
        split_chat_models, split_quantized_models, split_watermarked_models:
            move those models to the end of both axes, ordered by size.
        filter_empty_col_row: drop rows and columns that are entirely NaN.
        is_debug: write every intermediate table to the page.

    Returns:
        The filtered, reordered numeric score matrix.
    """
    _debug_write(is_debug, "No filters", df)

    # ---- Row (train) filters ----
    low, high = model_size_train[0] * 1e9, model_size_train[1] * 1e9
    df = df.loc[(df["model_size"] >= low) & (df["model_size"] <= high)]
    _debug_write(is_debug, "Filter model size train", df)
    df = df.loc[df["model_family"].isin(model_family_train)]
    _debug_write(is_debug, "Filter model family train", df)
    if is_chat_train != "Both":
        df = df.loc[df["is_chat"] == is_chat_train]
        _debug_write(is_debug, "Filter is chat train", df)
    if is_quantized_train != "Both":
        df = df.loc[df["is_quantized"] == is_quantized_train]
        _debug_write(is_debug, "Filter is quantized train", df)
    if is_watermarked_train != "Both":
        df = df.loc[df["is_watermarked"] == is_watermarked_train]
        _debug_write(is_debug, "Filter is watermark train", df)

    # ---- Column (test) filters ----
    _debug_write(is_debug, "Before test column filters", df)

    def in_test_size_range(column):
        # Columns whose base name is not a MODELS key (metadata) are dropped here.
        base = _base_model_name(column)
        if base not in MODELS:
            return False
        size = MODELS[base]["model_size"]
        return model_size_test[0] * 1e9 <= size <= model_size_test[1] * 1e9

    df = _keep_columns(df, in_test_size_range)
    _debug_write(is_debug, "Filter model size test", df)
    df = _keep_columns(
        df,
        lambda c: MODELS[_base_model_name(c)]["model_family"] in model_family_test,
    )
    _debug_write(is_debug, "Filter model family test", df)
    if is_chat_test != "Both":
        df = _keep_columns(
            df, lambda c: MODELS[_base_model_name(c)]["is_chat"] == is_chat_test
        )
        _debug_write(is_debug, "Filter is chat test", df)
    if is_quantized_test != "Both":
        df = _keep_columns(df, lambda c: ("quantized" in c) == bool(is_quantized_test))
        _debug_write(is_debug, "Filter is quantized test", df)
    if is_watermarked_test != "Both":
        df = _keep_columns(
            df, lambda c: ("watermark" in c) == bool(is_watermarked_test)
        )
        _debug_write(is_debug, "Filter is watermark test", df)
    df = df.select_dtypes(include="number")
    _debug_write(is_debug, "Select dtypes to be only numbers", df)

    # ---- Ordering ----
    order = (
        MODEL_NAMES_SORTED_BY_SIZE if sort_by_size else MODEL_NAMES_SORTED_BY_NAME_AND_SIZE
    )
    df = df[[x for x in order if x in df.columns]]
    _debug_write(is_debug, "Sort columns", df)
    df = df.reindex([x for x in order if x in df.index])
    _debug_write(is_debug, "Sort rows", df)
    if split_chat_models:
        # Sizes are looked up by base name, so suffixed chat variants work too.
        df = _move_to_end(df, CHAT_MODELS)
    _debug_write(is_debug, "Split chat models", df)
    if split_quantized_models:
        df = _move_to_end(df, [x for x in Q_W_MODELS if "quantized" in x])
    if split_watermarked_models:
        df = _move_to_end(df, [x for x in Q_W_MODELS if "watermarked" in x])
    _debug_write(is_debug, "Split quantized/watermarked models", df)
    if filter_empty_col_row:
        # Remove rows, then columns, that contain only NaN.
        df = df.dropna(axis=0, how="all").dropna(axis=1, how="all")
    return df
# Load the main results and the quantized/watermarked results, each as
# per-model (mean, std) tables.
df, df_std = get_data("./deberta_results.csv")
df_q_w, df_std_q_w = get_data("./results_qantized_watermarked.csv")


def _merge_q_w_columns(main, extra):
    """Outer-join the quantized/watermarked columns of ``extra`` onto ``main``.

    Columns present in both frames come back from the merge as ``<name>_x``
    (from ``main``) and ``<name>_y`` (from ``extra``). The ``_y`` copy is kept
    under the plain name, and ``is_quantized_x`` / ``is_watermarked_x`` are
    dropped.
    """
    q_w_columns = extra.columns[
        extra.columns.str.contains("quantized|watermarked", case=False, regex=True)
    ]
    merged = main.merge(
        extra[q_w_columns], how="outer", left_index=True, right_index=True
    )
    # Anchor on the suffix: an unanchored "_y" would also be removed from the
    # middle of any column name that happens to contain it.
    merged.columns = merged.columns.str.replace(r"_y$", "", regex=True)
    return merged.drop(columns=["is_quantized_x", "is_watermarked_x"])


df = _merge_q_w_columns(df, df_q_w)
# The std table gets the same treatment, so it does not keep stale *_x columns.
df_std = _merge_q_w_columns(df_std, df_std_q_w)
# Overwrite cells with the non-NA values from the quantized/watermarked tables.
df.update(df_q_w)
df_std.update(df_std_q_w)
# Models missing from one source have NaN flags after the outer join; treat
# them as False. Assign the result back: fillna(inplace=True) on a selected
# column is chained assignment and does not modify the frame under pandas
# Copy-on-Write.
for _table in (df, df_std):
    for _flag in ("is_chat", "is_watermarked", "is_quantized"):
        _table[_flag] = _table[_flag].fillna(False)
# Adversarial (out-of-distribution) results; presumably one record per
# (model, seed) run, averaged per model below.
# JSON text is UTF-8, so do not rely on the platform's default encoding.
with open("./ood_results.json", "r", encoding="utf-8") as f:
    ood_results = pd.DataFrame(json.load(f))
ood_results = ood_results.set_index("model_name").drop(
    columns=["exp_name", "accuracy", "f1", "precision", "recall"]
)
# Positional rename of the two remaining columns: assumes they are the run
# seed followed by the score — TODO confirm against the JSON schema.
ood_results.columns = ["seed", "Adversarial"]
_ood_by_model = ood_results.groupby(["model_name"])
ood_results_avg = _ood_by_model.mean()
ood_results_std = _ood_by_model.std()
# Page header (Markdown): title, paper, authors, affiliation and paper link.
# The literal is free of the " | |" page-scrape residue, which would
# otherwise be rendered after every heading.
st.write(
    """### Results Viewer 👇
## From Text to Source: Results in Detecting Large Language Model-Generated Content
### Wissam Antoun, Benoît Sagot, Djamé Seddah
##### ALMAnaCH, Inria
##### Paper: [https://arxiv.org/abs/2309.13322](https://arxiv.org/abs/2309.13322)
"""
)
# ---- Sidebar controls ----------------------------------------------------
show_diff = st.sidebar.checkbox("Show Diff", value=False)
sort_by_size = st.sidebar.checkbox("Sort by size", value=True)
split_chat_models = st.sidebar.checkbox("Split chat models", value=True)
split_quantized_models = st.sidebar.checkbox("Split quantized models", value=True)
split_watermarked_models = st.sidebar.checkbox("Split watermarked models", value=True)
add_mean = st.sidebar.checkbox("Add mean", value=False)
show_std = st.sidebar.checkbox("Show std", value=False)
filter_empty_col_row = st.sidebar.checkbox("Filter empty col/row", value=True)
# Size ranges are in billions; filter_df multiplies them by 1e9.
model_size_train = st.sidebar.slider(
    "Train Model Size in Billion", min_value=0, max_value=100, value=(0, 100), step=1
)
model_size_test = st.sidebar.slider(
    "Test Model Size in Billion", min_value=0, max_value=100, value=(0, 100), step=1
)
# Tri-state filters: True / False keep only matching models, "Both" disables the filter.
is_chat_train = st.sidebar.selectbox("(Train) Is Chat?", [True, False, "Both"], index=2)
is_chat_test = st.sidebar.selectbox("(Test) Is Chat?", [True, False, "Both"], index=2)
is_quantized_train = st.sidebar.selectbox(
    "(Train) Is Quantized?", [True, False, "Both"], index=1
)
is_quantized_test = st.sidebar.selectbox(
    "(Test) Is Quantized?", [True, False, "Both"], index=1
)
is_watermarked_train = st.sidebar.selectbox(
    "(Train) Is Watermark?", [True, False, "Both"], index=1
)
is_watermarked_test = st.sidebar.selectbox(
    "(Test) Is Watermark?", [True, False, "Both"], index=1
)
# Sorted so the options appear in a stable order; iterating a set of strings
# can give a different order in every interpreter run.
model_families = sorted(MODEL_FAMILES)
model_family_train = st.sidebar.multiselect(
    "Model Family Train",
    model_families,
    default=model_families,
)
model_family_test = st.sidebar.multiselect(
    "Model Family Test",
    model_families + ["Adversarial"],
    default=model_families,
)
show_values = st.sidebar.checkbox("Show Values", value=False)
# "Adversarial" is not a model family: it toggles the extra adversarial column.
add_adversarial = "Adversarial" in model_family_test
model_family_test = [family for family in model_family_test if family != "Adversarial"]
sort_by_adversarial = False
if add_adversarial:
    sort_by_adversarial = st.sidebar.checkbox("Sort by adversarial", value=False)
if st.sidebar.checkbox("Use default color scale", value=False):
    color_scale = "Viridis_r"
else:
    color_scale = viridis_rgb
is_debug = st.sidebar.checkbox("Debug", value=False)
# ---- Build the matrix shown in the heatmap ----------------------------------
# Rows are filtered with the "Train" settings, columns with the "Test" settings.
selected_df = df_std.copy() if show_std else df.copy()
filtered_df = filter_df(
    selected_df,
    model_family_train,
    model_family_test,
    model_size_train,
    model_size_test,
    is_chat_train,
    is_chat_test,
    is_quantized_train,
    is_quantized_test,
    is_watermarked_train,
    is_watermarked_test,
    sort_by_size,
    split_chat_models,
    split_quantized_models,
    split_watermarked_models,
    filter_empty_col_row,
    is_debug,
)
if show_diff:
    # Subtract from every cell the diagonal entry of its column.
    diag = filtered_df.values.diagonal()
    filtered_df = filtered_df.sub(diag, axis=1)
if add_adversarial:
    # Join only the score: ood_results_avg also carries the averaged "seed"
    # column, which is run metadata and must not appear in the heatmap.
    adversarial = ood_results_avg[["Adversarial"]]
    if show_diff:
        # Align with the heatmap rows (models without adversarial results get
        # NaN instead of raising KeyError), then subtract each row's diagonal.
        adversarial = adversarial.reindex(filtered_df.index).sub(diag, axis=0)
    filtered_df = filtered_df.join(adversarial)
if add_mean:
    # col_mean: mean of each row (shown as the "mean" column);
    # row_mean: mean of each column (shown as the "mean" row).
    col_mean = filtered_df.mean(axis=1)
    row_mean = filtered_df.mean(axis=0)
    diag = filtered_df.values.diagonal()
    filtered_df["mean"] = col_mean
    filtered_df.loc["mean"] = row_mean
# Display scores as rounded percentages.
filtered_df = (filtered_df * 100).round(0)
if sort_by_adversarial:
    filtered_df = filtered_df.sort_values(by=["Adversarial"], ascending=False)
# Nothing to plot when either axis is empty.
if filtered_df.empty:
    st.write("No results found")
    st.stop()
# ---- Heatmap of the filtered score matrix ------------------------------------
fig = px.imshow(
    filtered_df.values,
    x=list(filtered_df.columns),
    y=list(filtered_df.index),
    aspect="auto",
    text_auto=show_values,
    contrast_rescaling=None,
    color_continuous_scale=color_scale,
)
# Small fonts; column labels on top, tilted 45 degrees; one tick per row/column.
fig.update_traces(textfont_size=9)
fig.update_layout(
    font=dict(size=10),
    xaxis={"side": "top"},
    yaxis={"side": "left"},
)
fig.update_xaxes(tickangle=45, tickmode="linear")
fig.update_yaxes(tickmode="linear")
st.plotly_chart(fig, use_container_width=True)
# ---- Mean / self score versus model size --------------------------------------
def _base_of(name):
    """Return the MODELS key of a result name (drop the variant suffixes)."""
    return name.replace("_quantized", "").replace("_watermarked", "")


# Drawn only for raw scores; row_mean, col_mean and diag come from "Add mean".
if add_mean and not show_diff:
    # Model columns/rows of the heatmap; "mean" and "Adversarial" are synthetic
    # entries with no MODELS record.
    target_models = [
        x for x in filtered_df.columns if x not in ("mean", "Adversarial")
    ]
    train_models = [x for x in filtered_df.index if x != "mean"]
    if any(x in CHAT_MODELS for x in target_models + train_models):
        st.warning(
            "Chat models are in the filtered df columns or index. "
            "This will cause the mean graph to be skewed."
        )
    # Quantized/watermarked variants take the size and family of their base model.
    target_sizes = [MODELS[_base_of(x)]["model_size"] for x in target_models]
    target_families = [MODELS[_base_of(x)]["model_family"] for x in target_models]
    fig3 = px.scatter(
        # Select by name so the values line up with target_models.
        y=row_mean.loc[target_models].to_numpy(),
        x=target_sizes,
        color=target_families,
        color_discrete_sequence=px.colors.qualitative.Plotly,
        title="",
        labels={
            "x": "Target Model Size",
            "y": "Average ROC AUC",
            "color": "Model Family",
        },
        log_x=True,
        trendline="ols",
    )
    fig4 = px.scatter(
        # diag[j] is cell (j, j); an extra trailing entry can come from the
        # "Adversarial" column, which is not a target model.
        y=diag[: len(target_models)],
        x=target_sizes,
        color=target_families,
        color_discrete_sequence=px.colors.qualitative.Plotly,
        title="",
        labels={
            "x": "Target Model Size",
            "y": "Self ROC AUC",
            "color": "Model Family",
        },
        log_x=True,
        trendline="ols",
    )
    # Put the two plots side by side (self score left, average right).
    fig_subplot = make_subplots(
        rows=1,
        cols=2,
        shared_yaxes=False,
        subplot_titles=("Self Detection ROC AUC", "Average Target ROC AUC"),
    )
    for col, figure in enumerate([fig4, fig3], start=1):
        for trace in figure.data:
            if col == 2:
                # The right panel reuses the left panel's legend entries.
                trace.showlegend = False
            fig_subplot.add_trace(trace, row=1, col=col)
    fig_subplot.update_xaxes(type="log")
    fig_subplot.update_yaxes(range=[0.90, 1])
    fig_subplot.update_layout(
        height=500,
        width=1200,
        # Horizontal legend below the plots.
        legend=dict(orientation="h", yanchor="bottom", y=-0.2, x=0.09),
    )
    st.plotly_chart(fig_subplot, use_container_width=True)
    train_bases = [_base_of(x) for x in train_models]
    fig2 = px.scatter(
        # Select by name: "Sort by adversarial" may have reordered the rows.
        y=col_mean.loc[train_models].to_numpy(),
        x=[MODELS_SIZE_MAPPING[x] for x in train_bases],
        color=[MODELS_FAMILY_MAPPING[x] for x in train_bases],
        color_discrete_sequence=px.colors.qualitative.Plotly,
        title="Mean vs Train Model Size",
        log_x=True,
        trendline="ols",
    )
    fig2.update_layout(
        height=600,
        width=900,
    )
    st.plotly_chart(fig2, use_container_width=False)