File size: 10,378 Bytes
d8a3b21
d11a268
9c65bf3
 
b30f466
 
3590429
 
 
 
 
 
 
 
d8a3b21
5a036ce
 
d8a3b21
 
 
a3d91c6
2f194e3
d8a3b21
 
 
 
 
 
 
 
 
19a59f0
d8a3b21
19a59f0
3590429
2fa0dd7
 
 
 
3590429
2fa0dd7
3590429
 
 
 
 
 
 
d5f53fb
3590429
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5f53fb
 
 
c9b64a7
d5f53fb
3590429
 
179f7b9
3590429
 
d5f53fb
 
3590429
 
 
 
d5f53fb
3590429
 
 
 
 
 
d5f53fb
 
3590429
 
 
 
c9b64a7
3590429
 
c9b64a7
 
3590429
 
c9b64a7
3590429
c9b64a7
 
 
 
 
 
3590429
 
 
c9b64a7
3590429
4206a2c
c9b64a7
3590429
 
9a0a1ca
3590429
 
 
9a0a1ca
3590429
 
884d71b
 
 
2e10ae8
3590429
 
 
 
 
 
2e10ae8
3590429
 
 
 
 
 
2e10ae8
3590429
 
 
 
 
 
 
d9e840a
27c7f11
 
 
 
 
 
3590429
 
 
 
 
27c7f11
a513cc4
 
3590429
27c7f11
 
d9e840a
27c7f11
3590429
 
27c7f11
3590429
c467935
932765f
3590429
 
 
 
 
c467935
3590429
 
 
c467935
3590429
 
bab98a8
 
 
 
3590429
 
 
 
 
 
bab98a8
3590429
 
 
 
 
 
 
bab98a8
3590429
bab98a8
 
 
 
3590429
bab98a8
 
3590429
 
 
 
 
 
 
 
 
bab98a8
 
 
6eb481d
bab98a8
3590429
 
bab98a8
3590429
d079bda
f422dcf
3590429
 
 
 
3ef0d7b
3590429
6bbcf95
3590429
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from shiny import render
from shiny.express import input, output, ui
from utils import (
    filter_and_select,
    plot_2d_comparison,
    plot_color_square,
    wens_method_heatmap,
    plot_fcgr,
    plot_persistence_homology,
)

import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)


############################################################# Virus Dataset ########################################################
#ds = load_dataset('Hack90/virus_tiny')
# Load the virus table at import time and map every distinct organism name to
# itself; the resulting dict is passed as the choices of the `virus_selector`
# input further down.
df = pd.read_parquet('virus_ds.parquet')
virus = {name: name for name in df['Organism_Name'].unique()}

############################################################# Filter and Select ########################################################
def filter_and_select(group):
    """Return the first three rows of *group*, or ``None`` for smaller groups.

    Note: this definition rebinds the ``filter_and_select`` name that is
    imported from ``utils`` at the top of the file.
    """
    if len(group) < 3:
        return None
    return group.head(3)
    
############################################################# UI #################################################################

# Page-level option: request a fillable page layout.
ui.page_opts(fillable=True)

# Top-level tab set; each `ui.nav_panel` nested below is one tab of the app.
with ui.navset_card_tab(id="tab"):
    with ui.nav_panel("Viral Macrostructure"):
        ui.panel_title("Do viruses have underlying structure?")
        with ui.layout_columns():
            with ui.card():
                # Multi-select of organism names (choices come from `virus`);
                # read by plot_macro via input.virus_selector().
                ui.input_selectize("virus_selector", "Select your viruses:", virus, multiple=True, selected=None)
            with ui.card():
                # Single-choice plot method. These strings are compared
                # verbatim in plot_macro, so they must stay in sync with it.
                ui.input_selectize(
                    "plot_type_macro",
                    "Select your method:",
                    ["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"],
                    multiple=False,
                    selected=None,
                )

        @render.plot()
        def plot_macro():
            """Render the macrostructure plot for the selected viruses and method.

            Reads ``virus_ds.parquet``, keeps the rows whose ``Organism_Name`` is
            in ``input.virus_selector()`` and dispatches on
            ``input.plot_type_macro()``.

            Returns:
                Whatever the selected plotting helper returns, or ``None`` when
                no known method is selected or when no selected organism has at
                least three sequences for the methods that require them.
            """
            df = pd.read_parquet("virus_ds.parquet")
            df = df[df["Organism_Name"].isin(input.virus_selector())]
            plot_type = input.plot_type_macro()

            if plot_type == "2D Line":
                # One list of sequences per organism, indexed by organism name.
                grouped = df.groupby("Organism_Name")["Sequence"].apply(list)
                return plot_2d_comparison(grouped, grouped.index)
            if plot_type == "Wens Method":
                return wens_method_heatmap(df, df["Organism_Name"].unique())
            if plot_type not in ("ColorSquare", "Chaos Game Representation", "Persistant Homology"):
                return None

            # The remaining methods all work on the first three sequences of
            # every organism that has at least three (see filter_and_select).
            filtered_df = df.groupby("Organism_Name").apply(filter_and_select).reset_index(drop=True)
            if filtered_df.empty:
                # Nothing survived the filter; indexing "Sequence" below would
                # raise KeyError on the column-less empty frame.
                return None

            sequences = filtered_df["Sequence"]
            if plot_type == "ColorSquare":
                return plot_color_square(sequences, filtered_df["Organism_Name"].unique())
            if plot_type == "Chaos Game Representation":
                # Take the labels from the filtered frame (as ColorSquare does)
                # so they match the sequences in both membership and order;
                # the unfiltered frame may still contain organisms that were
                # dropped for having fewer than three sequences.
                return plot_fcgr(sequences, filtered_df["Organism_Name"].unique())
            # "Persistant Homology": one label per sequence (not de-duplicated).
            return plot_persistence_homology(sequences, filtered_df["Organism_Name"])

    # Second tab: k-mer frequency bar chart (rendered by plot_micro).
    with ui.nav_panel("Viral Microstructure"):
        ui.panel_title("Kmer Distribution")
        with ui.layout_columns():
            with ui.card():
                # Read by plot_micro as input.kmer(); plot_micro draws nothing
                # unless the value is greater than 0.
                ui.input_slider("kmer", "kmer", 0, 10, 4)
                # Read by plot_micro as input.top_k(); number of rows kept via head().
                ui.input_slider("top_k", "top:", 0, 1000, 15)
                # Read by plot_micro as input.plot_type(); selects the "count" or
                # "percent" column of kmers.csv.
                ui.input_selectize("plot_type", "Select metric:", ["percentage", "count"], multiple=False, selected=None)

        @render.plot()
        def plot_micro():
            """Render a bar chart of the first ``top_k`` rows for k-mers of length ``k``.

            Reads ``kmers.csv`` (columns used: ``k``, ``kmer``, ``count``,
            ``percent``) and plots either counts or percentages depending on
            ``input.plot_type()``.

            Returns:
                The matplotlib figure, or ``None`` when ``input.kmer()`` is 0.
            """
            k = input.kmer()
            top_k = input.top_k()
            plot_type = input.plot_type()
            if k <= 0:
                # Nothing to draw for k == 0; skip reading the CSV altogether.
                return None

            df = pd.read_csv("kmers.csv")
            df = df[df["k"] == k].head(top_k)
            fig, ax = plt.subplots()
            if plot_type == "count":
                ax.bar(df["kmer"], df["count"])
                ax.set_ylabel("Count")
            elif plot_type == "percentage":
                ax.bar(df["kmer"], df["percent"] * 100)
                ax.set_ylabel("Percentage")
            ax.set_title(f"Most common {k}-mers")
            ax.set_xlabel("K-mer")
            # Rotate the existing tick labels instead of overwriting them with
            # set_xticklabels(): that call raises when the number of labels
            # differs from the number of ticks (e.g. repeated k-mers, or no
            # bars drawn). Presumably `kmer` holds strings, so the bars already
            # carry the k-mer text as categorical tick labels.
            ax.tick_params(axis="x", labelrotation=90)
            return fig

    # Third tab: no inputs; shows the image produced by plot_context_size_scaling.
    with ui.nav_panel("Viral Model Training"):
        ui.panel_title("Does context size matter for a nucleotide model?")

        def plot_loss_rates(df, model_type):
            """Plot every loss column of *df* resampled onto a shared 0-1 axis.

            Args:
                df: Frame with a ``Step`` column (dropped) and one loss column
                    per run. NaNs are removed per column before resampling.
                model_type: Text inserted into the plot title.

            Returns:
                The matplotlib figure holding one line per usable column.
            """
            x = np.linspace(0, 1, 1000)
            # Legend labels by column position (presumably the context window
            # sizes, per the plot title - confirm against the CSV layout).
            labels = ["32", "64", "128", "256", "512", "1024"]
            curves = []
            for position, col in enumerate(df.drop(columns=["Step"]).columns):
                y = df[col].dropna().astype("float", errors="ignore").values
                if len(y) < 2:
                    # interp1d needs at least two points; skip empty or
                    # single-value columns instead of raising.
                    continue
                f = interp1d(np.linspace(0, 1, len(y)), y)
                # Fall back to the column name when there are more columns than
                # predefined labels (previously an IndexError).
                label = labels[position] if position < len(labels) else str(col)
                curves.append((label, f(x)))

            # Create the figure only after all data is prepared so a failure
            # above does not leave an orphaned figure behind.
            fig, ax = plt.subplots()
            for label, curve in curves:
                ax.plot(x, curve, label=label)
            if curves:
                ax.legend()
            ax.set_title(f"Loss rates for a {model_type} parameter model across context windows")
            ax.set_xlabel("Training steps")
            ax.set_ylabel("Loss rate")
            return fig

        # delete_file=True lets Shiny remove the temporary SVG once it has been
        # sent, so the files do not pile up in the temp directory.
        @render.image(delete_file=True)
        def plot_context_size_scaling():
            """Render the 14M context-size loss curves from ``14m.csv`` as an SVG.

            Returns:
                An image dict pointing at a temporary SVG file, or ``None`` if
                no figure was produced.
            """
            import os
            import tempfile

            df = pd.read_csv("14m.csv")
            fig = plot_loss_rates(df, "14M")
            if not fig:
                return None

            fd, path = tempfile.mkstemp(suffix=".svg")
            # mkstemp returns an already-open descriptor; savefig opens the
            # path itself, so close ours right away to avoid leaking it.
            os.close(fd)
            try:
                fig.savefig(path)
            except Exception:
                os.remove(path)
                raise
            finally:
                # Figures made with plt.subplots() stay registered in pyplot
                # until closed; release it on every render.
                plt.close(fig)
            return {"src": str(path), "width": "600px", "format": "svg"}

    # Fourth tab: three multi-select filters feeding plot_model_scaling.
    with ui.nav_panel("Model loss analysis"):
        ui.panel_title("Neurips stuff")
        with ui.card():
            # Values are strings here; plot_loss_rates_model converts each one
            # with int() before comparing against the `param_type` column.
            ui.input_selectize(
                "param_type",
                "Select Param Type:",
                ["14", "31", "70", "160", "410"],
                multiple=True,
                selected=["14", "70"],
            )
            # Compared against the `model_type` column.
            ui.input_selectize(
                "model_type",
                "Select Model Type:",
                ["pythia", "denseformer", "evo"],
                multiple=True,
                selected=["pythia", "denseformer"],
            )
            # Compared against the `loss_type` column.
            ui.input_selectize(
                "loss_type",
                "Select Loss Type:",
                ["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"],
                multiple=True,
                selected=["compliment", "cross_entropy", "headless"],
            )

        def plot_loss_rates_model(df, param_types, loss_types, model_types):
            """Plot one resampled loss curve per (param, loss, model) combination.

            Args:
                df: Frame with ``param_type``, ``loss_type``, ``model_type`` and
                    ``loss_interp`` columns.
                param_types: Iterable of parameter sizes; each is converted with
                    ``int()`` before being compared with ``df["param_type"]``.
                loss_types: Iterable of loss-type names to include.
                model_types: Iterable of model-type names to include.

            Returns:
                The matplotlib figure. Combinations with fewer than two data
                points are skipped because they cannot be interpolated.
            """
            x = np.linspace(0, 1, 1000)
            curves = []
            for param_type in param_types:
                # Build each mask once per level instead of once per innermost
                # iteration.
                param_mask = df["param_type"] == int(param_type)
                for loss_type in loss_types:
                    loss_mask = param_mask & (df["loss_type"] == loss_type)
                    for model_type in model_types:
                        y = df.loc[loss_mask & (df["model_type"] == model_type), "loss_interp"].values
                        if len(y) < 2:
                            # interp1d raises on fewer than two points.
                            continue
                        f = interp1d(np.linspace(0, 1, len(y)), y)
                        curves.append((f"{param_type}_{loss_type}_{model_type}", f(x)))

            fig, ax = plt.subplots()
            for label, curve in curves:
                ax.plot(x, curve, label=label)
            if curves:
                # Avoid matplotlib's "no artists with labels" warning on an
                # empty selection.
                ax.legend()
            ax.set_xlabel("Training steps")
            ax.set_ylabel("Loss rate")
            return fig

        # delete_file=True lets Shiny remove the temporary SVG once it has been
        # sent, so the files do not pile up in the temp directory.
        @render.image(delete_file=True)
        def plot_model_scaling():
            """Render loss curves for the selected param/loss/model types as an SVG.

            Reads ``training_data_5.csv`` and keeps only the rows with
            ``epoch_interp`` above 0.035 before plotting.

            Returns:
                An image dict pointing at a temporary SVG file, or ``None`` if
                no figure was produced.
            """
            import os
            import tempfile

            df = pd.read_csv("training_data_5.csv")
            df = df[df["epoch_interp"] > 0.035]
            fig = plot_loss_rates_model(
                df, input.param_type(), input.loss_type(), input.model_type()
            )
            if not fig:
                return None

            fd, path = tempfile.mkstemp(suffix=".svg")
            # mkstemp returns an already-open descriptor; savefig opens the
            # path itself, so close ours right away to avoid leaking it.
            os.close(fd)
            try:
                fig.savefig(path)
            except Exception:
                os.remove(path)
                raise
            finally:
                # Figures made with plt.subplots() stay registered in pyplot
                # until closed; release it on every render.
                plt.close(fig)
            return {"src": str(path), "width": "600px", "format": "svg"}

    # Fifth tab: parameter count versus best loss (rendered by plot_big_boy_model).
    with ui.nav_panel("Scaling Laws"):
        ui.panel_title("Params & Losses")
        with ui.card():
            # One line is drawn per selected model type.
            ui.input_selectize(
                "model_type_scale",
                "Select Model Type:",
                ["pythia", "denseformer", "evo"],
                multiple=True,
                selected=["evo", "denseformer"],
            )
            # NOTE: although this is a multi-select, plot_loss_rates_model_scale
            # only uses the first selected entry (it indexes loss_type[0]).
            ui.input_selectize(
                "loss_type_scale",
                "Select Loss Type:",
                ["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"],
                multiple=True,
                selected=["cross_entropy"],
            )

        def plot_loss_rates_model_scale(df, loss_type, model_types):
            """Plot the minimum loss per parameter count for each model type.

            Args:
                df: Frame with ``loss_type``, ``model_type``, ``num_params`` and
                    ``loss_interp`` columns.
                loss_type: Sequence of selected loss types; only the first
                    entry is used. An empty selection yields an empty plot.
                model_types: Iterable of model-type names; one line each.

            Returns:
                The matplotlib figure.
            """
            series = []
            if len(loss_type) > 0:
                # Previously an empty selection raised IndexError on [0].
                df = df[df["loss_type"] == loss_type[0]]
                for model_type in model_types:
                    subset = df[df["model_type"] == model_type]
                    # Best (lowest) loss per distinct parameter count. groupby
                    # ignores NaN keys, which int() could not convert anyway.
                    best = subset.groupby("num_params")["loss_interp"].min()
                    points = sorted(
                        zip((int(p) for p in best.index), best.to_list()),
                        key=lambda point: point[0],
                    )
                    params = [p for p, _ in points]
                    losses = [loss for _, loss in points]
                    series.append((model_type, params, losses))

            fig, ax = plt.subplots()
            for label, params, losses in series:
                ax.plot(params, losses, label=label)
            if series:
                ax.legend()
            ax.set_xlabel("Params")
            ax.set_ylabel("Loss")
            return fig

        # delete_file=True lets Shiny remove the temporary SVG once it has been
        # sent, so the files do not pile up in the temp directory.
        @render.image(delete_file=True)
        def plot_big_boy_model():
            """Render the params-vs-loss scaling plot from ``training_data_5.csv`` as an SVG.

            Returns:
                An image dict pointing at a temporary SVG file, or ``None`` if
                no figure was produced.
            """
            import os
            import tempfile

            df = pd.read_csv("training_data_5.csv")
            fig = plot_loss_rates_model_scale(
                df, input.loss_type_scale(), input.model_type_scale()
            )
            if not fig:
                return None

            fd, path = tempfile.mkstemp(suffix=".svg")
            # mkstemp returns an already-open descriptor; savefig opens the
            # path itself, so close ours right away to avoid leaking it.
            os.close(fd)
            try:
                fig.savefig(path)
            except Exception:
                os.remove(path)
                raise
            finally:
                # Figures made with plt.subplots() stay registered in pyplot
                # until closed; release it on every render.
                plt.close(fig)
            return {"src": str(path), "width": "600px", "format": "svg"}