Spaces:
Sleeping
Sleeping
File size: 16,686 Bytes
d8a3b21 d11a268 9c65bf3 b30f466 3590429 df8a5ee 3590429 6e33bf5 eaac7ee 5a036ce d8a3b21 a3d91c6 2f194e3 d8a3b21 434be3b b18e1b5 04a20cd d8a3b21 19a59f0 d8a3b21 19a59f0 3590429 2fa0dd7 3590429 2fa0dd7 3590429 d5f53fb 3590429 df8a5ee b18e1b5 df8a5ee cacb9ec d9d2b55 df8a5ee da3072e d9d2b55 3590429 d5f53fb c9b64a7 d5f53fb 3590429 179f7b9 3590429 d5f53fb 3590429 d5f53fb 3590429 d5f53fb 3590429 c9b64a7 3590429 c9b64a7 3590429 c9b64a7 3590429 c9b64a7 3590429 c9b64a7 3590429 0e3a4b3 c9b64a7 3590429 9a0a1ca 0e3a4b3 3590429 884d71b 0e2e1d7 884d71b 2e10ae8 3590429 7cee527 3590429 7cee527 3590429 2e10ae8 3590429 7cee527 3590429 2e10ae8 3590429 23f828d 3590429 d9e840a 27c7f11 3590429 f6ccd04 3590429 27c7f11 a513cc4 3590429 27c7f11 d9e840a 27c7f11 3590429 27c7f11 3590429 0e3a4b3 932765f 3590429 c467935 0e3a4b3 3590429 bab98a8 3590429 7cee527 3590429 bab98a8 3590429 7cee527 3590429 bab98a8 3590429 bab98a8 3590429 bab98a8 3590429 bab98a8 6eb481d bab98a8 3590429 bab98a8 3590429 0e3a4b3 f422dcf 3590429 3ef0d7b 0e3a4b3 44feaa5 0f3cccd 44feaa5 0f3cccd 44feaa5 |
|
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from shiny import render
from shiny.express import input, output, ui
from utils import (
filter_and_select,
plot_2d_comparison,
plot_color_square,
wens_method_heatmap,
plot_fcgr,
plot_persistence_homology,
plot_distrobutions
)
import os
import seaborn as sns
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)
############################################################# Virus Dataset ########################################################
#ds = load_dataset('Hack90/virus_tiny')
# Raw sequence table (columns used below include Organism_Name and Sequence).
df = pd.read_parquet('virus_ds.parquet')
# Selectize choices: each organism name maps to itself.
virus = {v: v for v in df['Organism_Name'].unique()}
# Per-organism distribution data. Read the parquet once and derive the
# organism-name list from it rather than reading the file a second time.
MASTER_DF = pd.read_parquet("distro.parquet")
df_new = MASTER_DF['organism_name'].tolist()
virus_new = {v: v for v in df_new}
# Choice lists for the training-log selectors. Parse the CSV once, and only the
# three columns needed, instead of three separate full reads.
_training_choices = pd.read_csv(
    "training_data_5.csv", usecols=["loss_type", "model_type", "param_type"]
)
loss_typesss = _training_choices['loss_type'].unique().tolist()
model_typesss = _training_choices['model_type'].unique().tolist()
param_typesss = _training_choices['param_type'].unique().tolist()
# Keep the module namespace the same as before; the frame is no longer needed.
del _training_choices
############################################################# Filter and Select ########################################################
def filter_and_select(group):
    """Return the first three rows of *group*, or ``None`` when it has fewer.

    Intended for ``DataFrame.groupby(...).apply``: ``pd.concat`` drops ``None``
    results, so groups with fewer than three rows vanish from the output.
    This module-level definition shadows the ``filter_and_select`` imported
    from ``utils``.
    """
    if len(group) < 3:
        return None
    return group.head(3)
############################################################# UI #################################################################
# Page layout: a single card-tab navset. Each nav_panel below is one tab; its
# inputs are declared first and its @render.plot output(s) follow.
ui.page_opts(fillable=True)
with ui.navset_card_tab(id="tab"):
# --- Tab: sequence-level visualisations of the selected viruses (plot_macro) ---
with ui.nav_panel("Viral Macrostructure"):
ui.panel_title("Do viruses have underlying structure?")
with ui.layout_columns():
with ui.card():
# Choices are the Organism_Name values loaded from virus_ds.parquet.
ui.input_selectize("virus_selector", "Select your viruses:", virus, multiple=True, selected=None)
with ui.card():
# NOTE(review): "Persistant" is misspelled, but plot_macro compares against
# these exact strings, so any rename must be made in both places together.
ui.input_selectize(
"plot_type_macro",
"Select your method:",
["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"],
multiple=False,
selected=None,
)
@render.plot()
def plot_macro():
    """Render the macrostructure plot chosen in ``plot_type_macro``.

    Reads ``virus_ds.parquet``, keeps the sequences of the viruses picked in
    ``virus_selector`` and dispatches to the matching ``utils`` plotting helper.
    Returns ``None`` (nothing drawn) when no virus is selected or the method
    is not one of the known options.
    """
    selected = input.virus_selector()
    if not selected:
        # Nothing to compare yet; avoid feeding empty frames to the helpers.
        return None
    df = pd.read_parquet("virus_ds.parquet")
    df = df[df["Organism_Name"].isin(selected)]
    plot_type = input.plot_type_macro()

    def first_three_per_virus():
        # filter_and_select keeps the first 3 sequences of each virus; viruses
        # with fewer than 3 are dropped (their None result is discarded by
        # pandas). groupby also orders the viruses by name.
        return (
            df.groupby("Organism_Name")
            .apply(filter_and_select)
            .reset_index(drop=True)
        )

    if plot_type == "2D Line":
        grouped = df.groupby("Organism_Name")["Sequence"].apply(list)
        return plot_2d_comparison(grouped, grouped.index)
    if plot_type == "ColorSquare":
        filtered_df = first_three_per_virus()
        return plot_color_square(filtered_df["Sequence"], filtered_df["Organism_Name"].unique())
    if plot_type == "Wens Method":
        return wens_method_heatmap(df, df["Organism_Name"].unique())
    if plot_type == "Chaos Game Representation":
        filtered_df = first_three_per_virus()
        # Take the labels from filtered_df so they line up with the sequences.
        # df may still contain viruses that were dropped above, and its
        # unique() order is file order rather than groupby's sorted order.
        return plot_fcgr(filtered_df["Sequence"], filtered_df["Organism_Name"].unique())
    if plot_type == "Persistant Homology":
        filtered_df = first_three_per_virus()
        return plot_persistence_homology(filtered_df["Sequence"], filtered_df["Organism_Name"])
    return None
# --- Tab: per-organism distribution statistics from distro.parquet (plot_distro_new) ---
with ui.nav_panel("Viral Genome Distributions"):
ui.panel_title("How does sequence distribution vary for a specie?")
with ui.layout_columns():
with ui.card():
# Choices are the organism_name values loaded from distro.parquet.
ui.input_selectize("virus_selector_1", "Select your viruses:", virus_new, multiple=True, selected=None)
# with ui.card():
# NOTE(review): the card wrapper above is commented out. The three view labels
# below (including the "Distrobution" spelling) are matched verbatim in
# plot_distro_new, so they must be renamed together if ever corrected.
ui.input_selectize(
"plot_type_distro",
"Select your distrobution variance view:",
["Variance across bp", "Standard deviation across bp", "Full Genome Distrobution"],
multiple=False,
selected="Full Genome Distrobution",
)
@render.plot()
def plot_distro_new():
    """Render the genome-distribution view chosen in ``plot_type_distro``.

    Uses the ``MASTER_DF`` rows of the organisms picked in ``virus_selector_1``.
    It draws one of two things:

    * a density histogram of the exploded ``charts`` column, or
    * a per-basepair line plot of the ``std`` or ``var`` array column.

    Returns the seaborn Axes, or ``None`` when no organism is selected or the
    view is not recognised.
    """
    organisms = input.virus_selector_1()
    if not organisms:
        # pd.concat([]) below would raise "No objects to concatenate".
        return None
    plot_type = input.plot_type_distro()
    df = MASTER_DF[MASTER_DF["organism_name"].isin(organisms)]

    if plot_type == "Full Genome Distrobution":
        exploded = df.explode("charts")
        ax = sns.histplot(data=exploded, x="charts", hue="organism_name", stat="density")
        ax.set_title("Distribution")
        ax.set_xlabel("Distance from mean")
        ax.set_ylabel("Density")
        return ax

    def per_bp_frame(column):
        # Long format: one row per basepair index (x) per organism, y = statistic.
        frames = []
        for organism in organisms:
            values = df.loc[df["organism_name"] == organism, column].values[0].tolist()
            frames.append(
                pd.DataFrame({"y": values, "x": range(len(values)), "organism": organism})
            )
        return pd.concat(frames).explode(column=["x", "y"])

    # The two per-basepair views share all plotting code; only the source
    # column, the title and the y label differ.
    per_bp_views = {
        "Standard deviation across bp": ("std", "Standard Deviation across basepairs", "Std"),
        "Variance across bp": ("var", "Variance across basepairs", "Variance"),
    }
    if plot_type not in per_bp_views:
        return None
    column, title, ylabel = per_bp_views[plot_type]
    ax = sns.lineplot(data=per_bp_frame(column), x="x", y="y", hue="organism")
    ax.set_title(title)
    ax.set_xlabel("Basepair")
    ax.set_ylabel(ylabel)
    return ax
# --- Tab: k-mer frequency bar chart from kmers.csv (plot_micro) ---
with ui.nav_panel("Viral Microstructure"):
ui.panel_title("Kmer Distribution")
with ui.layout_columns():
with ui.card():
# k ranges 0..10 (default 4); plot_micro only filters by k when k > 0.
ui.input_slider("kmer", "kmer", 0, 10, 4)
# Number of rows to keep for the selected k.
ui.input_slider("top_k", "top:", 0, 1000, 15)
ui.input_selectize("plot_type", "Select metric:", ["percentage", "count"], multiple=False, selected=None)
@render.plot()
def plot_micro():
    """Bar chart of k-mer frequencies read from ``kmers.csv``.

    For ``k > 0`` it keeps the first ``top_k`` rows whose ``k`` column equals
    the slider value. When ``k`` is 0 no filtering is applied and every row
    is plotted.

    The bars show either the ``count`` column or ``percent`` multiplied by 100,
    depending on the ``plot_type`` input. Returns the Figure.
    """
    df = pd.read_csv("kmers.csv")
    k = input.kmer()
    top_k = input.top_k()
    plot_type = input.plot_type()
    if k > 0:
        # NOTE(review): the title says "most common", which assumes kmers.csv is
        # already sorted by frequency within each k; head() does not sort.
        df = df[df["k"] == k].head(top_k)
    fig, ax = plt.subplots()
    if plot_type == "count":
        ax.bar(df["kmer"], df["count"])
        ax.set_ylabel("Count")
    elif plot_type == "percentage":
        # presumably `percent` stores fractions, hence the x100 — confirm.
        ax.bar(df["kmer"], df["percent"] * 100)
        ax.set_ylabel("Percentage")
    ax.set_title(f"Most common {k}-mers")
    ax.set_xlabel("K-mer")
    # Rotate the categorical tick labels the bars already created. Calling
    # set_xticklabels without a matching set_xticks makes matplotlib emit a
    # UserWarning about a non-fixed tick locator.
    ax.tick_params(axis="x", labelrotation=90)
    return fig
# --- Tab: 14M-model loss curves across context windows from 14m.csv ---
with ui.nav_panel("Viral Model Training"):
ui.panel_title("Does context size matter for a nucleotide model?")
def plot_loss_rates(df, model_type, labels=None):
    """Plot one loss curve per non-``Step`` column of *df*.

    Each column has its NaNs dropped. Its remaining values are placed on a
    normalised [0, 1] axis and linearly resampled onto 1000 points, so runs
    with different numbers of logged steps share one x axis.

    Args:
        df: Loss table with a ``Step`` column plus one column per run.
        model_type: Model-size text used in the title (e.g. ``"14M"``).
        labels: Legend labels in column order. Defaults to the context sizes
            ``"32"`` … ``"1024"``. Columns beyond the supplied labels are
            labelled with their column name instead of raising ``IndexError``.

    Returns:
        The matplotlib Figure.
    """
    if labels is None:
        labels = ["32", "64", "128", "256", "512", "1024"]
    x = np.linspace(0, 1, 1000)
    loss_df = df.drop(columns=["Step"])
    # Resample every curve before creating the figure, so an interpolation
    # error does not leave a half-drawn figure behind.
    curves = []
    for i, col in enumerate(loss_df.columns):
        y = loss_df[col].dropna().astype("float", errors="ignore").values
        f = interp1d(np.linspace(0, 1, len(y)), y)
        label = labels[i] if i < len(labels) else str(col)
        curves.append((label, f(x)))
    fig, ax = plt.subplots()
    for label, curve in curves:
        ax.plot(x, curve, label=label)
    ax.legend()
    ax.set_title(f"Loss rates for a {model_type} parameter model across context windows")
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig
@render.plot()
def plot_context_size_scaling():
    """Render the 14M-parameter context-window loss curves stored in ``14m.csv``."""
    loss_table = pd.read_csv("14m.csv")
    figure = plot_loss_rates(loss_table, "14M")
    return figure if figure else None
# --- Tab: interpolated loss curves from training_data_5.csv (plot_model_scaling) ---
with ui.nav_panel("Model loss analysis"):
ui.panel_title("Paper stuff")
with ui.card():
# Choice lists are the unique column values read from training_data_5.csv at
# import time. No default selection, so the plot starts empty.
ui.input_selectize(
"param_type",
"Select Param Type:",
param_typesss,
multiple=True,
)
# NOTE(review): the hard-coded default selections below are assumed to exist
# in the CSV-derived choice lists — confirm against training_data_5.csv.
ui.input_selectize(
"model_type",
"Select Model Type:",
model_typesss,
multiple=True,
selected=["pythia", "denseformer"],
)
ui.input_selectize(
"loss_type",
"Select Loss Type:",
loss_typesss,
multiple=True,
selected=["compliment", "cross_entropy", "headless"],
)
def plot_loss_rates_model(df, param_types, loss_types, model_types):
    """Plot resampled ``loss_interp`` curves per (param, loss, model) combination.

    For every combination present in *df*, the matching ``loss_interp`` values
    are placed on a normalised [0, 1] axis and linearly resampled onto 1000
    points.

    Args:
        df: Training log with ``param_type``, ``loss_type``, ``model_type`` and
            ``loss_interp`` columns.
        param_types: Selected parameter types. Each is converted with
            ``float()`` to match the ``param_type`` column.
        loss_types: Selected loss types.
        model_types: Selected model types.

    Returns:
        The matplotlib Figure. Each curve is labelled
        ``"{param_type}_{loss_type}_{model_type}"``.
    """
    x = np.linspace(0, 1, 1000)
    loss_rates = []
    labels = []
    for param_type in param_types:
        for loss_type in loss_types:
            for model_type in model_types:
                y = df[
                    (df["param_type"] == float(param_type))
                    & (df["loss_type"] == loss_type)
                    & (df["model_type"] == model_type)
                ]["loss_interp"].values
                if len(y) == 0:
                    continue
                if len(y) == 1:
                    # interp1d needs at least two samples; one point is a flat line.
                    loss_rates.append(np.full_like(x, y[0]))
                else:
                    f = interp1d(np.linspace(0, 1, len(y)), y)
                    loss_rates.append(f(x))
                labels.append(f"{param_type}_{loss_type}_{model_type}")
    fig, ax = plt.subplots()
    for label, loss_rate in zip(labels, loss_rates):
        ax.plot(x, loss_rate, label=label)
    if labels:
        # Skip legend() on empty axes; matplotlib warns when nothing is labelled.
        ax.legend()
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig
@render.plot()
def plot_model_scaling():
    """Render loss curves for the param/loss/model types selected in this tab."""
    training = pd.read_csv("training_data_5.csv")
    # Only rows past epoch_interp 0.035 are plotted — presumably to hide the
    # noisy start of training; TODO confirm the intent of this threshold.
    training = training[training["epoch_interp"] > 0.035]
    figure = plot_loss_rates_model(
        training,
        input.param_type(),
        input.loss_type(),
        input.model_type(),
    )
    return figure if figure else None
# --- Tab: best loss against parameter count (plot_big_boy_model) ---
with ui.nav_panel("Scaling Laws"):
ui.panel_title("Params & Losses")
with ui.card():
# Choice lists are the unique column values read from training_data_5.csv.
ui.input_selectize(
"model_type_scale",
"Select Model Type:",
model_typesss,
multiple=True,
selected=["evo", "denseformer"],
)
ui.input_selectize(
"loss_type_scale",
"Select Loss Type:",
loss_typesss,
multiple=True,
selected=["cross_entropy"],
)
def plot_loss_rates_model_scale(df, loss_type, model_types):
    """Plot the minimum ``loss_interp`` reached at each ``num_params`` per model type.

    Args:
        df: Training log with ``loss_type``, ``model_type``, ``num_params`` and
            ``loss_interp`` columns.
        loss_type: Sequence of selected loss types. Every selected loss type
            gets its own curves. With a single loss type the legend labels are
            just the model types; with several they are
            ``"{model_type}_{loss_type}"``. An empty selection yields empty
            axes instead of raising ``IndexError``.
        model_types: Model types to plot, one curve each per loss type.

    Returns:
        The matplotlib Figure.
    """
    loss_types = list(loss_type)
    multi_loss = len(loss_types) > 1
    fig, ax = plt.subplots()
    for loss in loss_types:
        loss_df = df[df["loss_type"] == loss]
        for model_type in model_types:
            model_df = loss_df[loss_df["model_type"] == model_type]
            # Lowest loss at each model size. groupby returns the sizes in
            # ascending order; int() keeps the integer x values used before.
            best = model_df.groupby("num_params")["loss_interp"].min()
            params = [int(p) for p in best.index]
            label = f"{model_type}_{loss}" if multi_loss else model_type
            ax.plot(params, best.to_list(), label=label)
    if ax.get_lines():
        # Avoid matplotlib's "no artists with labels" warning on empty axes.
        ax.legend()
    ax.set_xlabel("Params")
    ax.set_ylabel("Loss")
    return fig
@render.plot()
def plot_big_boy_model():
    """Render the loss-vs-parameter-count plot for the selected loss/model types."""
    training = pd.read_csv("training_data_5.csv")
    figure = plot_loss_rates_model_scale(
        training,
        input.loss_type_scale(),
        input.model_type_scale(),
    )
    return figure if figure else None
# --- Tab: cumulative logit curves from virus_pythia_*_logit_cumsums.npy files ---
with ui.nav_panel("Logits View"):
# NOTE(review): "et all" presumably means "et al." (user-visible title text).
ui.panel_title("Logits et all")
with ui.card():
# Model sizes and loss types are fixed lists here. plot_logits_representation
# interpolates the selected values into the .npy file name it loads.
ui.input_selectize(
"model_bigness",
"Select Model size:",
["14", "31", "70", "160", "410"],
multiple=True,
selected=["70", "160"],
)
ui.input_selectize(
"loss_loss_loss",
"Select Loss Type:",
["compliment", "cross_entropy", "headless", "2d_representation_GaussianPlusCE", "2d_representation_MSEPlusCE"],
multiple=True,
selected=["cross_entropy"],
)
# 1-based logit numbers; converted to 0-based indices when plotting.
ui.input_selectize(
"logits_select",
"Select logits:",
["1", "2", "3", "4", "5", "6", "7", "8"],
multiple=True,
selected=["6"],
)
def plot_logits_representation(model_bigness, loss_type, logits):
    """Compare generated vs. expected cumulative logit curves.

    Lays out a 2 x len(logits) grid with one column per selected logit.
    Row 0 ("Single") plots the ``*_y_cumsum`` arrays and row 1 ("Full") plots
    the ``*_y_full_cumsum`` arrays. Each (size, loss) pair whose
    ``virus_pythia_{size}_1024_{loss}_logit_cumsums.npy`` file exists adds a
    Generated and an Expected line to every panel. Missing files are reported
    with ``print`` and skipped.

    Args:
        model_bigness: Selected model sizes (strings such as ``"70"``).
        loss_type: Selected loss-type names.
        logits: Selected logit numbers as 1-based strings (``"1"`` … ``"8"``).

    Returns:
        The matplotlib Figure, or ``None`` when no logit is selected, because
        a zero-column subplot grid cannot be created.
    """
    if not logits:
        return None
    num_rows = 2  # row 0: "Single" cumsums, row 1: "Full" cumsums
    num_cols = len(logits)  # one column per selected logit
    # squeeze=False keeps axs 2-D even for a single column, so axs[row, col]
    # works for any number of selected logits.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 10), squeeze=False)
    for size in model_bigness:
        for loss in loss_type:
            file_name = f"virus_pythia_{size}_1024_{loss}_logit_cumsums.npy"
            if not os.path.exists(file_name):
                print(f"File not found: {file_name}")
                continue
            # NOTE(review): allow_pickle=True can execute code embedded in the
            # file. size/loss come from UI inputs; confirm they are restricted to
            # the fixed choice lists before they reach this path.
            data = np.load(file_name, allow_pickle=True).item()
            for col, logit in enumerate(logits):
                logit_index = int(logit) - 1  # UI logit numbers are 1-based
                single_ax, full_ax = axs[0, col], axs[1, col]
                single_ax.plot(data['lm_logits_y_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                single_ax.plot(data['shift_labels_y_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
                full_ax.plot(data['lm_logits_y_full_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                full_ax.plot(data['shift_labels_y_full_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
    for col, logit in enumerate(logits):
        axs[0, col].set_title(f'Logit: {logit}- Single')
        axs[1, col].set_title(f'Logit: {logit} - Full')
        for ax in axs[:, col]:
            if ax.get_lines():
                ax.legend()
    plt.tight_layout()
    return fig
@render.plot()
def plot_logits_representation_ui():
    """Render the logits comparison grid for the sizes, losses and logits selected."""
    figure = plot_logits_representation(
        input.model_bigness(),
        input.loss_loss_loss(),
        input.logits_select(),
    )
    return figure if figure else None
|