Spaces:
Sleeping
Sleeping
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.interpolate import interp1d | |
from shiny import render | |
from shiny.express import input, output, ui | |
from utils import ( | |
filter_and_select, | |
plot_2d_comparison, | |
plot_color_square, | |
wens_method_heatmap, | |
plot_fcgr, | |
plot_persistence_homology, | |
) | |
import matplotlib as mpl | |
# Reset every matplotlib rcParam to the library default before any plotting.
mpl.rcParams.update(mpl.rcParamsDefault)

# ------------------------------- Virus dataset -------------------------------
# Table read from virus_ds.parquet; its Organism_Name column populates the
# virus selector in the macrostructure tab.
df = pd.read_parquet('virus_ds.parquet')
# Selectize choices: each organism name maps to itself (value == label),
# listed in order of first appearance.
virus = {name: name for name in df['Organism_Name'].unique()}
# ----------------------------- Filter and select -----------------------------
def filter_and_select(group):
    """Return the first three rows of *group*, or ``None`` if it has fewer.

    Intended for ``DataFrame.groupby(...).apply``: groups that yield ``None``
    are omitted when pandas concatenates the per-group results, so organisms
    with fewer than three sequences drop out of the combined frame.

    Note: this definition replaces the ``filter_and_select`` imported from
    ``utils`` at the top of the file.
    """
    if len(group) < 3:
        return None
    return group.head(3)
############################################################# UI ################################################################# | |
# Shiny Express UI: a fillable page holding one tabbed card container
# (id "tab") with one nav_panel per analysis.
# NOTE(review): this copy has lost all indentation and every line carries a
# trailing " | |" (copy/paste residue), so it is not runnable as-is. The
# nesting of the `with` blocks below, and whether each following function sits
# inside a card, a layout_columns, or directly in its nav_panel, cannot be
# recovered from here. Restore it from the original app source.
ui.page_opts(fillable=True) | |
with ui.navset_card_tab(id="tab"): | |
# Tab 1: compares the selected viruses with the method chosen below
# (rendered by plot_macro).
with ui.nav_panel("Viral Macrostructure"): | |
ui.panel_title("Do viruses have underlying structure?") | |
with ui.layout_columns(): | |
with ui.card(): | |
# Multi-select over the organism-name dict `virus` built from virus_ds.parquet.
ui.input_selectize("virus_selector", "Select your viruses:", virus, multiple=True, selected=None) | |
with ui.card(): | |
# Method picker. Each choice string is compared verbatim in plot_macro, so a
# spelling fix (e.g. "Persistant" -> "Persistent") must change both places.
ui.input_selectize( | |
"plot_type_macro", | |
"Select your method:", | |
["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"], | |
multiple=False, | |
selected=None, | |
) | |
def plot_macro():
    """Render the macrostructure plot for the viruses picked in the UI.

    Reads the ``virus_selector`` and ``plot_type_macro`` inputs, restricts the
    module-level ``df`` (loaded from ``virus_ds.parquet``) to the selected
    organisms, and dispatches to the matching ``utils`` plotting helper.
    "ColorSquare", "Chaos Game Representation" and "Persistant Homology"
    first keep only organisms with at least three sequences and take their
    first three rows (see ``filter_and_select``).

    Returns:
        The helper's result, or ``None`` when no virus is selected, when no
        selected organism has three sequences (for the sampled methods), or
        when the method name is not recognised.

    NOTE(review): no ``@render.*`` decorator is visible on this function in
    this copy. Shiny Express only renders decorated functions, so confirm
    whether ``@render.plot`` was lost.
    NOTE(review): relies on ``groupby(...).apply`` passing the grouping
    column through to its result (deprecated behaviour in recent pandas).
    """
    selected = input.virus_selector()
    if not selected:
        # Nothing chosen yet: render nothing rather than plot an empty frame.
        return None
    # Reuse the table loaded at module level instead of re-reading the
    # parquet file on every reactive update.
    data = df[df["Organism_Name"].isin(selected)]
    plot_type = input.plot_type_macro()

    if plot_type == "2D Line":
        grouped = data.groupby("Organism_Name")["Sequence"].apply(list)
        return plot_2d_comparison(grouped, grouped.index)
    if plot_type == "Wens Method":
        return wens_method_heatmap(data, data["Organism_Name"].unique())
    if plot_type not in ("ColorSquare", "Chaos Game Representation", "Persistant Homology"):
        return None

    sampled = data.groupby("Organism_Name").apply(filter_and_select)
    # When every group is dropped (all returned None) pandas returns an empty
    # frame without columns, so there is nothing to plot.
    if "Sequence" not in sampled.columns:
        return None
    sampled = sampled.reset_index(drop=True)
    sequences = sampled["Sequence"]

    if plot_type == "ColorSquare":
        return plot_color_square(sequences, sampled["Organism_Name"].unique())
    if plot_type == "Chaos Game Representation":
        # Label with the organisms actually sampled (same as ColorSquare). The
        # unfiltered names included organisms dropped for having < 3 sequences
        # and were in appearance order rather than the grouped (sorted) order.
        return plot_fcgr(sequences, sampled["Organism_Name"].unique())
    return plot_persistence_homology(sequences, sampled["Organism_Name"])
# Tab 2: bar chart of frequent k-mers read from kmers.csv (drawn by plot_micro).
# NOTE(review): indentation lost in this copy; the function that follows
# presumably belongs inside this panel. Confirm against the original source.
with ui.nav_panel("Viral Microstructure"): | |
ui.panel_title("Kmer Distribution") | |
with ui.layout_columns(): | |
with ui.card(): | |
# k-mer length, 0-10 (default 4); plot_micro only filters by length when k > 0.
ui.input_slider("kmer", "kmer", 0, 10, 4) | |
# Number of rows plot_micro keeps via head(), 0-1000 (default 15).
ui.input_slider("top_k", "top:", 0, 1000, 15) | |
# Choice strings are compared verbatim in plot_micro.
ui.input_selectize("plot_type", "Select metric:", ["percentage", "count"], multiple=False, selected=None) | |
def plot_micro():
    """Render a bar chart of the leading k-mers listed in ``kmers.csv``.

    Inputs read:
        kmer: k-mer length. When positive, only rows whose ``k`` column equals
            it are kept; 0 keeps every length.
        top_k: number of rows, taken from the top of the (filtered) file, to
            plot. Applied for every ``kmer`` value, including 0.
        plot_type: ``"count"`` plots the ``count`` column; ``"percentage"``
            plots ``percent * 100``.

    Returns:
        The matplotlib ``Figure``.

    NOTE(review): "Most common" relies on ``kmers.csv`` already being sorted by
    frequency, because ``head`` takes rows in file order. Confirm.
    NOTE(review): no ``@render.*`` decorator is visible here. Shiny Express only
    renders decorated functions, so confirm whether ``@render.plot`` was lost.
    """
    kmers = pd.read_csv("kmers.csv")
    k = input.kmer()
    top_k = input.top_k()
    plot_type = input.plot_type()

    if k > 0:
        kmers = kmers[kmers["k"] == k]
    # Apply the "top" limit for every k. Previously k == 0 skipped it and
    # plotted every row in the file.
    kmers = kmers.head(top_k)

    fig, ax = plt.subplots()
    if plot_type == "count":
        ax.bar(kmers["kmer"], kmers["count"])
        ax.set_ylabel("Count")
    elif plot_type == "percentage":
        # Assumes ``percent`` is stored as a fraction in [0, 1] (TODO confirm).
        ax.bar(kmers["kmer"], kmers["percent"] * 100)
        ax.set_ylabel("Percentage")
    ax.set_title(f"Most common {k}-mers")
    ax.set_xlabel("K-mer")
    # Rotate the existing category tick labels in place. set_xticklabels
    # without a matching set_xticks makes matplotlib warn about fixed labels.
    ax.tick_params(axis="x", labelrotation=90)
    return fig
# Tab 3: 14M-parameter loss curves compared across context-window sizes
# (plot_loss_rates / plot_context_size_scaling).
# NOTE(review): indentation lost in this copy; the functions that follow
# presumably belong inside this panel. Confirm against the original source.
with ui.nav_panel("Viral Model Training"): | |
ui.panel_title("Does context size matter for a nucleotide model?") | |
def plot_loss_rates(df, model_type, labels=None):
    """Plot every loss column of *df* against normalised training progress.

    Each column other than ``Step`` is one loss curve. Missing values are
    dropped and each curve is linearly resampled onto 1000 evenly spaced
    points on [0, 1], so runs logged with different numbers of rows can be
    overlaid. Columns with fewer than two values cannot be interpolated and
    are skipped.

    Args:
        df: Table of loss curves, one column per run plus an optional
            ``Step`` column (ignored).
        model_type: Model-size text inserted into the title, e.g. ``"14M"``.
        labels: Legend labels matched to the loss columns by position.
            Defaults to ``"32"`` ... ``"1024"``. Columns beyond the supplied
            labels fall back to their column name.

    Returns:
        The matplotlib ``Figure``.
    """
    if labels is None:
        labels = ["32", "64", "128", "256", "512", "1024"]
    x = np.linspace(0, 1, 1000)
    runs = df.drop(columns=["Step"], errors="ignore")

    curves = []
    for position, col in enumerate(runs.columns):
        y = runs[col].dropna().astype("float", errors="ignore").values
        if len(y) < 2:
            # interp1d needs at least two samples.
            continue
        label = labels[position] if position < len(labels) else str(col)
        f = interp1d(np.linspace(0, 1, len(y)), y)
        curves.append((f(x), label))

    fig, ax = plt.subplots()
    for loss_rate, label in curves:
        ax.plot(x, loss_rate, label=label)
    if curves:
        # Only build a legend when something was drawn.
        ax.legend()
    ax.set_title(f"Loss rates for a {model_type} parameter model across context windows")
    # x is the fraction of each run's logged points, not an absolute step count.
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig
def plot_context_size_scaling():
    """Render the 14M-parameter context-window comparison as an SVG image.

    Loads ``14m.csv``, draws it with ``plot_loss_rates``, saves the figure to
    a fresh temporary ``.svg`` file and returns an image-spec dict pointing at
    it (``None`` if no figure was produced).

    NOTE(review): the returned dict has the ``src``/``width`` shape that Shiny's
    ``@render.image`` consumes, but no decorator is visible in this copy.
    Confirm whether it was lost. The temp file is never deleted here;
    ``render.image(delete_file=True)`` can handle that.
    """
    import os
    import tempfile

    df = pd.read_csv("14m.csv")
    fig = plot_loss_rates(df, "14M")
    if not fig:
        return None
    fd, path = tempfile.mkstemp(suffix=".svg")
    # mkstemp returns an open OS descriptor and savefig reopens the file by
    # path, so close ours to avoid leaking one descriptor per render.
    os.close(fd)
    fig.savefig(path)
    # Release pyplot's reference so figures do not pile up across renders.
    plt.close(fig)
    return {"src": str(path), "width": "600px", "format": "svg"}
# Tab 4: loss curves per parameter size / loss type / model family read from
# training_data_5.csv (plot_loss_rates_model / plot_model_scaling).
# NOTE(review): indentation lost in this copy; the functions that follow
# presumably belong inside this panel. Confirm against the original source.
with ui.nav_panel("Model loss analysis"): | |
# NOTE(review): informal panel title; consider a descriptive one.
ui.panel_title("Neurips stuff") | |
with ui.card(): | |
# Parameter sizes as strings; plot_loss_rates_model converts each with int()
# before matching the `param_type` column.
ui.input_selectize( | |
"param_type", | |
"Select Param Type:", | |
["14", "31", "70", "160", "410"], | |
multiple=True, | |
selected=["14", "70"], | |
) | |
# Matched verbatim against the `model_type` column.
ui.input_selectize( | |
"model_type", | |
"Select Model Type:", | |
["pythia", "denseformer", "evo"], | |
multiple=True, | |
selected=["pythia", "denseformer"], | |
) | |
# Matched verbatim against the `loss_type` column. "compliment" is presumably
# meant as "complement", but it must equal the values in the data, so confirm
# before renaming.
ui.input_selectize( | |
"loss_type", | |
"Select Loss Type:", | |
["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"], | |
multiple=True, | |
selected=["compliment", "cross_entropy", "headless"], | |
) | |
def plot_loss_rates_model(df, param_types, loss_types, model_types):
    """Overlay resampled loss curves for every param/loss/model combination.

    Combinations are visited in ``param_types`` x ``loss_types`` x
    ``model_types`` order. For each, the matching rows' ``loss_interp``
    values (in the order they appear in *df*) are linearly resampled onto
    1000 evenly spaced points on [0, 1] and plotted with the label
    ``"{param_type}_{loss_type}_{model_type}"``. Combinations with fewer than
    two matching rows are skipped, because they cannot be interpolated.

    Args:
        df: Training log with ``param_type``, ``loss_type``, ``model_type``
            and ``loss_interp`` columns.
        param_types: Parameter sizes as strings. Each is converted with
            ``int`` before comparing against ``df["param_type"]``.
        loss_types: Loss-type names, matched verbatim.
        model_types: Model-family names, matched verbatim.

    Returns:
        The matplotlib ``Figure``.
    """
    from itertools import product

    x = np.linspace(0, 1, 1000)
    curves = []
    for param_type, loss_type, model_type in product(param_types, loss_types, model_types):
        mask = (
            (df["param_type"] == int(param_type))
            & (df["loss_type"] == loss_type)
            & (df["model_type"] == model_type)
        )
        y = df.loc[mask, "loss_interp"].values
        # interp1d needs at least two samples. A single matching row used to
        # raise ValueError and abort the whole plot.
        if len(y) < 2:
            continue
        f = interp1d(np.linspace(0, 1, len(y)), y)
        curves.append((f(x), f"{param_type}_{loss_type}_{model_type}"))

    fig, ax = plt.subplots()
    for loss_rate, label in curves:
        ax.plot(x, loss_rate, label=label)
    if curves:
        # Only build a legend when something was drawn.
        ax.legend()
    # x is the fraction of each run's rows, not an absolute step count.
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig
def plot_model_scaling():
    """Render loss curves for the selected param/loss/model combinations as SVG.

    Loads ``training_data_5.csv``, keeps rows with ``epoch_interp > 0.035``,
    plots them with ``plot_loss_rates_model`` using the ``param_type``,
    ``loss_type`` and ``model_type`` inputs, and returns an image-spec dict
    pointing at a temporary ``.svg`` file (``None`` if no figure was
    produced).

    NOTE(review): the returned dict has the ``src``/``width`` shape that Shiny's
    ``@render.image`` consumes, but no decorator is visible in this copy.
    Confirm whether it was lost. The temp file is never deleted here;
    ``render.image(delete_file=True)`` can handle that.
    """
    import os
    import tempfile

    df = pd.read_csv("training_data_5.csv")
    # Magic threshold: drops rows with epoch_interp <= 0.035, presumably the
    # noisy start of training (TODO confirm the intent).
    df = df[df["epoch_interp"] > 0.035]
    fig = plot_loss_rates_model(
        df, input.param_type(), input.loss_type(), input.model_type()
    )
    if not fig:
        return None
    fd, path = tempfile.mkstemp(suffix=".svg")
    # mkstemp returns an open OS descriptor and savefig reopens the file by
    # path, so close ours to avoid leaking one descriptor per render.
    os.close(fd)
    fig.savefig(path)
    # Release pyplot's reference so figures do not pile up across renders.
    plt.close(fig)
    return {"src": str(path), "width": "600px", "format": "svg"}
# Tab 5: minimum loss vs. parameter count per model family
# (plot_loss_rates_model_scale / plot_big_boy_model).
# NOTE(review): indentation lost in this copy; the functions that follow
# presumably belong inside this panel. Confirm against the original source.
with ui.nav_panel("Scaling Laws"): | |
ui.panel_title("Params & Losses") | |
with ui.card(): | |
# One line per selected model family; matched verbatim against `model_type`.
ui.input_selectize( | |
"model_type_scale", | |
"Select Model Type:", | |
["pythia", "denseformer", "evo"], | |
multiple=True, | |
selected=["evo", "denseformer"], | |
) | |
# NOTE(review): multiple=True, but plot_loss_rates_model_scale uses only the
# first selected loss type (loss_type[0]); extra selections are ignored.
ui.input_selectize( | |
"loss_type_scale", | |
"Select Loss Type:", | |
["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"], | |
multiple=True, | |
selected=["cross_entropy"], | |
) | |
def plot_loss_rates_model_scale(df, loss_type, model_types):
    """Plot the lowest loss reached against parameter count per model family.

    Rows are restricted to one loss type: the first entry of *loss_type*, or
    *loss_type* itself when it is a plain string. For each model in
    *model_types*, the minimum ``loss_interp`` for every distinct
    ``num_params`` value is plotted against ``int(num_params)``, sorted by
    parameter count, as one labelled line.

    Args:
        df: Training log with ``loss_type``, ``model_type``, ``num_params``
            and ``loss_interp`` columns.
        loss_type: Sequence of selected loss types (only the first is used)
            or a single loss-type string.
        model_types: Model-family names, one line each.

    Returns:
        The matplotlib ``Figure``. With an empty *loss_type* it carries only
        axis labels (previously this raised ``IndexError``).
    """
    if isinstance(loss_type, str):
        # A bare string used to be indexed to its first character.
        loss_type = [loss_type]

    curves = []
    if loss_type:
        subset = df[df["loss_type"] == loss_type[0]]
        for model_type in model_types:
            # Minimum loss per model size; groupby skips missing sizes.
            best = (
                subset[subset["model_type"] == model_type]
                .groupby("num_params")["loss_interp"]
                .min()
            )
            curve = pd.DataFrame(
                {"loss": best.to_numpy(), "params": [int(p) for p in best.index]}
            ).sort_values(by="params")
            curves.append((curve["params"].to_list(), curve["loss"].to_list(), model_type))

    fig, ax = plt.subplots()
    for params, losses, label in curves:
        ax.plot(params, losses, label=label)
    if curves:
        # Only build a legend when something was drawn.
        ax.legend()
    ax.set_xlabel("Params")
    ax.set_ylabel("Loss")
    return fig
def plot_big_boy_model():
    """Render the minimum-loss vs. parameter-count plot as an SVG image.

    Loads ``training_data_5.csv``, plots it with
    ``plot_loss_rates_model_scale`` using the ``loss_type_scale`` and
    ``model_type_scale`` inputs, and returns an image-spec dict pointing at a
    temporary ``.svg`` file (``None`` if no figure was produced).

    NOTE(review): the returned dict has the ``src``/``width`` shape that Shiny's
    ``@render.image`` consumes, but no decorator is visible in this copy.
    Confirm whether it was lost. The temp file is never deleted here;
    ``render.image(delete_file=True)`` can handle that.
    """
    import os
    import tempfile

    df = pd.read_csv("training_data_5.csv")
    fig = plot_loss_rates_model_scale(
        df, input.loss_type_scale(), input.model_type_scale()
    )
    if not fig:
        return None
    fd, path = tempfile.mkstemp(suffix=".svg")
    # mkstemp returns an open OS descriptor and savefig reopens the file by
    # path, so close ours to avoid leaking one descriptor per render.
    os.close(fd)
    fig.savefig(path)
    # Release pyplot's reference so figures do not pile up across renders.
    plt.close(fig)
    return {"src": str(path), "width": "600px", "format": "svg"}