Spaces:
Sleeping
Sleeping
File size: 14,316 Bytes
d8a3b21 d11a268 9c65bf3 b30f466 3590429 df8a5ee 3590429 6e33bf5 5a036ce d8a3b21 a3d91c6 2f194e3 d8a3b21 56cda59 cacb9ec 56cda59 cacb9ec 56cda59 cacb9ec 04a20cd d8a3b21 19a59f0 d8a3b21 19a59f0 3590429 2fa0dd7 3590429 2fa0dd7 3590429 d5f53fb 3590429 df8a5ee cacb9ec df8a5ee cacb9ec df8a5ee 56cda59 9ad1347 df8a5ee 3590429 d5f53fb c9b64a7 d5f53fb 3590429 179f7b9 3590429 d5f53fb 3590429 d5f53fb 3590429 d5f53fb 3590429 c9b64a7 3590429 c9b64a7 3590429 c9b64a7 3590429 c9b64a7 3590429 c9b64a7 3590429 0e3a4b3 c9b64a7 3590429 9a0a1ca 0e3a4b3 3590429 884d71b 0e2e1d7 884d71b 2e10ae8 3590429 7cee527 3590429 7cee527 3590429 2e10ae8 3590429 7cee527 3590429 2e10ae8 3590429 23f828d 3590429 d9e840a 27c7f11 3590429 f6ccd04 3590429 27c7f11 a513cc4 3590429 27c7f11 d9e840a 27c7f11 3590429 27c7f11 3590429 0e3a4b3 932765f 3590429 c467935 0e3a4b3 3590429 bab98a8 3590429 7cee527 3590429 bab98a8 3590429 7cee527 3590429 bab98a8 3590429 bab98a8 3590429 bab98a8 3590429 bab98a8 6eb481d bab98a8 3590429 bab98a8 3590429 0e3a4b3 f422dcf 3590429 3ef0d7b 0e3a4b3 44feaa5 0f3cccd 44feaa5 0f3cccd 44feaa5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from shiny import render
from shiny.express import input, output, ui
from utils import (
filter_and_select,
plot_2d_comparison,
plot_color_square,
wens_method_heatmap,
plot_fcgr,
plot_persistence_homology,
plot_distrobutions
)
import os
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)
############################################################# Virus Dataset ########################################################
# Small dataset backing the "Viral Macrostructure" tab; its organism names are
# offered as selector choices (value == label).
df = pd.read_parquet('virus_ds.parquet')
virus = {v: v for v in df['Organism_Name'].unique()}

# Larger dataset backing the "Viral Genome Distributions" tab. Read it once:
# the species filter below needs only the name column, MASTER_DF needs both.
df_old = pd.read_parquet("virus.parquet", columns=['seq', 'organism_name'])

# Organisms with more than 10 sequences contribute at most 100 rows each,
# laid out in sorted-name order. Built without GroupBy.apply so it does not
# rely on apply() handing the grouping column to the callback (deprecated).
df_new = df_old[['organism_name']].sort_values('organism_name', kind='stable')
group_sizes = df_new.groupby('organism_name')['organism_name'].transform('size')
df_new = df_new[group_sizes > 10].groupby('organism_name').head(100).reset_index(drop=True)

# Keep organisms with more than 40 (capped) rows, ordered by descending count.
# NOTE(review): [1:] drops the first entry. Because counts are capped at 100,
# that entry may be just one of several organisms tied at the cap rather than
# the largest overall - confirm this exclusion is intended.
species_counts = df_new['organism_name'].value_counts()
filter_species = species_counts[species_counts > 40].index[1:].tolist()

MASTER_DF = df_old[df_old['organism_name'].isin(filter_species)].copy()
del df_new, df_old, group_sizes, species_counts
virus_new = {v: v for v in filter_species}

# Choice lists for the model-training tabs, taken from a single read of the CSV.
_training_df = pd.read_csv("training_data_5.csv")
loss_typesss = _training_df['loss_type'].unique().tolist()
model_typesss = _training_df['model_type'].unique().tolist()
param_typesss = _training_df['param_type'].unique().tolist()
del _training_df
############################################################# Filter and Select ########################################################
def filter_and_select(group):
    """Return the first three rows of ``group``, or None if it has fewer than three.

    Intended as a ``groupby(...).apply`` callback. Note that this module-level
    definition shadows the ``filter_and_select`` imported from ``utils``.
    """
    too_small = len(group) < 3
    if too_small:
        return None
    return group.iloc[:3]
############################################################# UI #################################################################
ui.page_opts(fillable=True)
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been stripped; confirm where each output sits in the layout.


def _resample_unit(y, x):
    """Linearly resample ``y`` (taken as evenly spaced over [0, 1]) at points ``x``.

    Returns None when ``y`` has fewer than two points, because ``interp1d``
    cannot build an interpolant from a single sample.
    """
    if len(y) < 2:
        return None
    return interp1d(np.linspace(0, 1, len(y)), y)(x)


with ui.navset_card_tab(id="tab"):
    with ui.nav_panel("Viral Macrostructure"):
        ui.panel_title("Do viruses have underlying structure?")
        with ui.layout_columns():
            with ui.card():
                ui.input_selectize("virus_selector", "Select your viruses:", virus, multiple=True, selected=None)
            with ui.card():
                ui.input_selectize(
                    "plot_type_macro",
                    "Select your method:",
                    ["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"],
                    multiple=False,
                    selected=None,
                )

        @render.plot()
        def plot_macro():
            """Render the chosen macrostructure method for the selected viruses."""
            selected = input.virus_selector()
            if not selected:
                return None  # nothing to plot until at least one virus is chosen
            # Reuse the module-level virus_ds.parquet frame instead of re-reading it.
            viruses = df[df["Organism_Name"].isin(selected)].copy()
            plot_type = input.plot_type_macro()
            if plot_type == "2D Line":
                grouped = viruses.groupby("Organism_Name")["Sequence"].apply(list)
                return plot_2d_comparison(grouped, grouped.index)
            if plot_type == "Wens Method":
                return wens_method_heatmap(viruses, viruses["Organism_Name"].unique())
            if plot_type not in ("ColorSquare", "Chaos Game Representation", "Persistant Homology"):
                return None
            # The remaining methods use at most three sequences per virus;
            # filter_and_select returns None for viruses with fewer than three,
            # so those may be absent (or the result entirely empty).
            sampled = viruses.groupby("Organism_Name").apply(filter_and_select).reset_index(drop=True)
            if sampled.empty:
                return None
            sequences = sampled["Sequence"]
            names = sampled["Organism_Name"]
            if plot_type == "ColorSquare":
                return plot_color_square(sequences, names.unique())
            if plot_type == "Chaos Game Representation":
                # Label with the sampled viruses so labels line up with the sequences.
                return plot_fcgr(sequences, names.unique())
            return plot_persistence_homology(sequences, names)

    with ui.nav_panel("Viral Genome Distributions"):
        ui.panel_title("How does sequence distribution vary across sequence length?")
        with ui.layout_columns():
            with ui.card():
                ui.input_selectize("virus_selector_1", "Select your viruses:", virus_new, multiple=True, selected=None)
            with ui.card():
                ui.input_slider(
                    "basepair", "Select basepair", 0, 10000, 15
                )

        @render.plot()
        def plot_distro():
            """Plot sequence distributions for the selected species at the chosen basepair."""
            selected = input.virus_selector_1()
            if not selected:
                return None
            subset = MASTER_DF[MASTER_DF["organism_name"].isin(selected)].copy()
            grouped = subset.groupby("organism_name")["seq"].apply(list)
            return plot_distrobutions(grouped, grouped.index, input.basepair())

    with ui.nav_panel("Viral Microstructure"):
        ui.panel_title("Kmer Distribution")
        with ui.layout_columns():
            with ui.card():
                ui.input_slider("kmer", "kmer", 0, 10, 4)
                ui.input_slider("top_k", "top:", 0, 1000, 15)
                ui.input_selectize("plot_type", "Select metric:", ["percentage", "count"], multiple=False, selected=None)

        @render.plot()
        def plot_micro():
            """Bar chart of the first ``top_k`` rows of kmers.csv for the selected k."""
            kmers = pd.read_csv("kmers.csv")
            k = input.kmer()
            top_k = input.top_k()
            plot_type = input.plot_type()
            if k > 0:
                # With k == 0 the table is left unfiltered and top_k is not applied.
                kmers = kmers[kmers["k"] == k].head(top_k)
            fig, ax = plt.subplots()
            if plot_type == "count":
                ax.bar(kmers["kmer"], kmers["count"])
                ax.set_ylabel("Count")
            elif plot_type == "percentage":
                ax.bar(kmers["kmer"], kmers["percent"] * 100)
                ax.set_ylabel("Percentage")
            ax.set_title(f"Most common {k}-mers")
            ax.set_xlabel("K-mer")
            # Rotate the category labels bar() already placed. set_xticklabels()
            # without fixed ticks warns and can mislabel (e.g. duplicate k-mers).
            ax.tick_params(axis="x", labelrotation=90)
            return fig

    with ui.nav_panel("Viral Model Training"):
        ui.panel_title("Does context size matter for a nucleotide model?")

        def plot_loss_rates(df, model_type):
            """Plot each loss column of ``df`` against normalised training progress.

            The ``Step`` column is dropped; the remaining columns are labelled in
            order with the context sizes below (falling back to the column name
            past the sixth). Columns with fewer than two numeric values are skipped.
            """
            x = np.linspace(0, 1, 1000)
            labels = ["32", "64", "128", "256", "512", "1024"]
            losses = df.drop(columns=["Step"])
            fig, ax = plt.subplots()
            for i, (col, series) in enumerate(losses.items()):
                y = pd.to_numeric(series, errors="coerce").dropna().values
                curve = _resample_unit(y, x)
                if curve is None:
                    continue
                ax.plot(x, curve, label=labels[i] if i < len(labels) else str(col))
            if ax.lines:
                ax.legend()
            ax.set_title(f"Loss rates for a {model_type} parameter model across context windows")
            ax.set_xlabel("Training steps")
            ax.set_ylabel("Loss rate")
            return fig

        @render.plot()
        def plot_context_size_scaling():
            return plot_loss_rates(pd.read_csv("14m.csv"), "14M")

    with ui.nav_panel("Model loss analysis"):
        ui.panel_title("Paper stuff")
        with ui.card():
            ui.input_selectize(
                "param_type",
                "Select Param Type:",
                param_typesss,
                multiple=True,
            )
            ui.input_selectize(
                "model_type",
                "Select Model Type:",
                model_typesss,
                multiple=True,
                selected=["pythia", "denseformer"],
            )
            ui.input_selectize(
                "loss_type",
                "Select Loss Type:",
                loss_typesss,
                multiple=True,
                selected=["compliment", "cross_entropy", "headless"],
            )

        def plot_loss_rates_model(df, param_types, loss_types, model_types):
            """Plot one resampled ``loss_interp`` curve per (param, loss, model) combination.

            ``param_types`` arrive as strings and are compared numerically with
            ``float``. Combinations with fewer than two rows are skipped.
            """
            x = np.linspace(0, 1, 1000)
            fig, ax = plt.subplots()
            for param_type in param_types:
                for loss_type in loss_types:
                    for model_type in model_types:
                        mask = (
                            (df["param_type"] == float(param_type))
                            & (df["loss_type"] == loss_type)
                            & (df["model_type"] == model_type)
                        )
                        curve = _resample_unit(df.loc[mask, "loss_interp"].values, x)
                        if curve is not None:
                            ax.plot(x, curve, label=f"{param_type}_{loss_type}_{model_type}")
            if ax.lines:
                ax.legend()
            ax.set_xlabel("Training steps")
            ax.set_ylabel("Loss rate")
            return fig

        @render.plot()
        def plot_model_scaling():
            df = pd.read_csv("training_data_5.csv")
            # Drop rows at or below this epoch_interp threshold.
            df = df[df["epoch_interp"] > 0.035]
            return plot_loss_rates_model(
                df, input.param_type(), input.loss_type(), input.model_type()
            )

    with ui.nav_panel("Scaling Laws"):
        ui.panel_title("Params & Losses")
        with ui.card():
            ui.input_selectize(
                "model_type_scale",
                "Select Model Type:",
                model_typesss,
                multiple=True,
                selected=["evo", "denseformer"],
            )
            ui.input_selectize(
                "loss_type_scale",
                "Select Loss Type:",
                loss_typesss,
                multiple=True,
                selected=["cross_entropy"],
            )

        def plot_loss_rates_model_scale(df, loss_type, model_types):
            """Plot the minimum ``loss_interp`` per ``num_params`` for each model type.

            Only the first selected loss type is used; returns None when no loss
            type is selected. Rows with a missing ``num_params`` are ignored.
            """
            if not loss_type:
                return None
            subset = df[df["loss_type"] == loss_type[0]]
            fig, ax = plt.subplots()
            for model_type in model_types:
                # groupby sorts by parameter count, giving a left-to-right curve.
                best = (
                    subset[subset["model_type"] == model_type]
                    .groupby("num_params")["loss_interp"]
                    .min()
                )
                ax.plot([int(p) for p in best.index], best.to_list(), label=model_type)
            if ax.lines:
                ax.legend()
            ax.set_xlabel("Params")
            ax.set_ylabel("Loss")
            return fig

        @render.plot()
        def plot_big_boy_model():
            df = pd.read_csv("training_data_5.csv")
            return plot_loss_rates_model_scale(
                df, input.loss_type_scale(), input.model_type_scale()
            )

    with ui.nav_panel("Logits View"):
        ui.panel_title("Logits et all")
        with ui.card():
            ui.input_selectize(
                "model_bigness",
                "Select Model size:",
                ["14", "31", "70", "160", "410"],
                multiple=True,
                selected=["70", "160"],
            )
            ui.input_selectize(
                "loss_loss_loss",
                "Select Loss Type:",
                ["compliment", "cross_entropy", "headless", "2d_representation_GaussianPlusCE", "2d_representation_MSEPlusCE"],
                multiple=True,
                selected=["cross_entropy"],
            )
            ui.input_selectize(
                "logits_select",
                "Select logits:",
                ["1", "2", "3", "4", "5", "6", "7", "8"],
                multiple=True,
                selected=["6"],
            )

        def plot_logits_representation(model_bigness, loss_type, logits):
            """Plot generated vs. expected logit cumsums, one column per selected logit.

            Row 0 uses the ``*_y_cumsum`` arrays ("Single"), row 1 the
            ``*_y_full_cumsum`` arrays ("Full"), loaded from
            ``virus_pythia_{size}_1024_{loss}_logit_cumsums.npy`` for every
            size/loss pair; missing files are reported and skipped. Logits are
            1-based strings. Returns None when no logit is selected.
            """
            if not logits:
                return None
            # squeeze=False keeps axs 2-D even for a single column, so axs[row, col]
            # works for any number of selected logits.
            fig, axs = plt.subplots(2, len(logits), figsize=(20, 10), squeeze=False)
            for size in model_bigness:
                for loss in loss_type:
                    file_name = f"virus_pythia_{size}_1024_{loss}_logit_cumsums.npy"
                    if not os.path.exists(file_name):
                        print(f"File not found: {file_name}")
                        continue
                    # Local result file bundled with the app (pickled dict of arrays).
                    data = np.load(file_name, allow_pickle=True).item()
                    for col, logit in enumerate(logits):
                        logit_index = int(logit) - 1
                        single, full = axs[0, col], axs[1, col]
                        single.plot(data['lm_logits_y_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                        single.plot(data['shift_labels_y_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
                        full.plot(data['lm_logits_y_full_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
                        full.plot(data['shift_labels_y_full_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
            for col, logit in enumerate(logits):
                axs[0, col].set_title(f'Logit: {logit}- Single')
                axs[1, col].set_title(f'Logit: {logit} - Full')
                for axis in (axs[0, col], axs[1, col]):
                    if axis.lines:
                        axis.legend()
            fig.tight_layout()
            return fig

        @render.plot()
        def plot_logits_representation_ui():
            return plot_logits_representation(
                input.model_bigness(), input.loss_loss_loss(), input.logits_select()
            )
|