| |
|
|
| """ |
| Usage: |
| python probe_animation.py /path/to/xp [/path/to/other/xp ...] |
| Where `/path/to/xp` is a folder containing `probe/probe.0.jsonl` |
| """ |
|
|
| import math |
| from pathlib import Path |
| import json |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import matplotlib.animation as animation |
| import matplotlib |
| from matplotlib import rc |
| from multiprocessing import Pool |
| import argparse |
|
|
| parser = argparse.ArgumentParser() |
| parser.add_argument("probe_folders", help="contains a `probe` folder", nargs="+") |
| args = parser.parse_args() |
|
|
| rc("animation", html="jshtml") |
| matplotlib.rcParams["animation.embed_limit"] = 2**128 |
|
|
| DEFAULT_QUANTILES = [0.001, 0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99, 0.999] |
| DATAS_PER_FILE = {} |
| datas = [] |
| NUM_LAYERS = 0 |
|
|
| for f in args.probe_folders: |
| name = Path(f).name |
| print("Loading ", name) |
| file = None |
| for probe_test in ["probe/probe.0.jsonl", "probe.json"]: |
| if (Path(f) / probe_test).exists(): |
| file = Path(f) / probe_test |
| assert file is not None, "Could not find probe json file" |
| data = file.read_text() |
| datas = [] |
| for line in data.splitlines(): |
| if line == "": |
| continue |
| datas.append(json.loads(line)) |
| datas[-1].setdefault("quantiles", DEFAULT_QUANTILES) |
|
|
| DATAS_PER_FILE[name] = datas |
| |
| |
|
|
| NUM_LAYERS = max( |
| NUM_LAYERS, |
| 1 |
| + max( |
| int(k.split("FSDPTransformer.layers.", 1)[1].split(".")[0]) |
| for k in datas[0]["data"].keys() |
| if k.startswith("FSDPTransformer.layers.") and k.endswith("::w") |
| ), |
| ) |
| for d in datas: |
| d["meta"]["it"] = d["meta"]["global_step"] |
| assert NUM_LAYERS > 0, "Couldn't deduce the model depth" |
|
|
|
|
| def get_mean_quantiles(df, names): |
| means = [] |
| quantiles = [[] for _ in df["quantiles"]] |
| for name in names: |
| d = df["data"][name] |
| for qi, qval in enumerate(d["quantiles"]): |
| quantiles[qi].append(qval) |
| means.append(d["mean"]) |
| return np.array(means), np.array(quantiles) |
|
|
|
|
| |
| |
| |
| |
| |
|
|
| COLORS = [f"tab:{x}" for x in ["blue", "orange", "green", "red"]] |
|
|
|
|
| class Plotter: |
| def __init__(self, ax, run, layers, color, timesteps) -> None: |
| datas = DATAS_PER_FILE[run] |
| for i in timesteps: |
| for layer in layers: |
| if layer not in datas[i]["data"]: |
| raise ValueError(f"Run `{run}`: layer `{layer}` not found!") |
| self.x = np.arange(0, len(layers), 1) |
| (self.mean,) = ax.plot(self.x, self.x, color=color, label=run) |
| self.fills = [] |
| self.animate_data = [get_mean_quantiles(datas[i], layers) for i in timesteps] |
| self.iters = [datas[i]["meta"]["it"] for i in timesteps] |
| self.minimum = min( |
| [ |
| min(np.nanmin(quants[3]), np.nanmin(means)) |
| for means, quants in self.animate_data |
| ] |
| ) |
| self.maximum = max( |
| [ |
| max(np.nanmax(quants[-4]), np.nanmax(means)) |
| for means, quants in self.animate_data |
| ] |
| ) |
| if not math.isfinite(self.minimum) or not math.isfinite(self.maximum): |
| raise ValueError( |
| f"Layer `{layers[0]}`: invalid min/max computed: {self.minimum}/{self.maximum}" |
| ) |
| self.ax = ax |
| self.color = color |
| self.run_name = run |
|
|
| def animate(self, i): |
| for f in self.fills: |
| f.remove() |
| self.fills.clear() |
| means, quants = self.animate_data[i] |
| self.fills += [ |
| self.ax.fill_between( |
| x=self.x, y1=quants[j], y2=quants[-1 - j], alpha=0.2, color=self.color |
| ) |
| for j in [0, 1, 2, 3, 4, 5] |
| ] |
| self.mean.set_ydata(means) |
| self.mean.set_label(f"{self.run_name} it={self.iters[i]}") |
| return self.mean |
|
|
|
|
| def plot_depth_distr_time(layer_fmt, to_file=None, runs=None, subsample=8): |
| plt.ioff() |
| while isinstance(layer_fmt, str) or isinstance(layer_fmt[0], str): |
| layer_fmt = [layer_fmt] |
|
|
| if layer_fmt[0][0].format(0) not in datas[0]["data"].keys(): |
| return |
|
|
| if runs is None: |
| runs = list(DATAS_PER_FILE.keys()) |
| LAYERS = list(range(NUM_LAYERS)) |
| nrows = len(layer_fmt) |
| ncols = len(layer_fmt[0]) |
| fig, axs = plt.subplots( |
| nrows=nrows, |
| ncols=ncols, |
| squeeze=False, |
| sharex=True, |
| figsize=[min(6 * ncols, 14), min(5 * nrows, 8)], |
| layout="compressed", |
| ) |
| timesteps = range(0, len(datas), subsample) |
| plotters = {} |
| for i in range(nrows): |
| for j in range(ncols): |
| print(layer_fmt[i][j]) |
| plotters[(i, j)] = [ |
| Plotter( |
| axs[i, j], |
| run, |
| [layer_fmt[i][j].format(layer) for layer in LAYERS], |
| color, |
| timesteps, |
| ) |
| for run, color in zip(runs, COLORS) |
| ] |
| minimum = min(p.minimum for p in plotters[(i, j)]) |
| maximum = max(p.maximum for p in plotters[(i, j)]) |
| axs[i, j].set_ylim([minimum, maximum]) |
|
|
| def animate(t): |
| out = [] |
| axs[0, 0].legend(loc="upper right") |
| for k, v in plotters.items(): |
| i, j = k |
| axs[i, j].set_title(layer_fmt[i][j]) |
| if i == nrows - 1: |
| axs[i, j].set_xlabel("depth") |
| for p in v: |
| out.append(p.animate(t)) |
| return out |
|
|
| ani = animation.FuncAnimation( |
| fig, animate, frames=len(timesteps), interval=500, blit=True |
| ) |
| if to_file is not None: |
| print("Writing to", to_file) |
| Path(to_file).write_text(ani.to_jshtml()) |
|
|
|
|
| OUTPUT_FOLDER_NAME = "_AND_".join([Path(f).name for f in args.probe_folders]) |
| RENDER_OUT_PATH = (Path("render_out") / OUTPUT_FOLDER_NAME).absolute() |
| RENDER_OUT_PATH.mkdir(parents=True, exist_ok=True) |
|
|
|
|
| def _render_attn(): |
| to_file = RENDER_OUT_PATH / "attention.html" |
| plot_depth_distr_time( |
| [ |
| [ |
| "FSDPTransformer.layers.{}.attention::attn_logits", |
| "FSDPTransformer.layers.{}.attention::attn_entropy", |
| "FSDPTransformer.layers.{}.attention.wo::in", |
| ], |
| [ |
| "FSDPTransformer.layers.{}.attention.wq::out", |
| "FSDPTransformer.layers.{}.attention.wk::out", |
| "FSDPTransformer.layers.{}.attention.wv::out", |
| ], |
| ], |
| to_file=to_file, |
| subsample=1, |
| ) |
|
|
|
|
| def _render_res(): |
| print("## RESIDUAL") |
| to_file = RENDER_OUT_PATH / "residual.html" |
| plot_depth_distr_time( |
| [ |
| [ |
| |
| |
| "FSDPTransformer.layers.{}.feed_forward.w3::out", |
| "FSDPTransformer.layers.{}.attention.wo::out", |
| ], |
| [ |
| "FSDPTransformer.layers.{}::out", |
| "FSDPTransformer.layers.{}::out.g", |
| ], |
| ], |
| to_file=to_file, |
| subsample=1, |
| ) |
|
|
|
|
| def _render_to_file(linear_layer: str): |
| if linear_layer == "__res__": |
| return _render_res() |
| if linear_layer == "__attn__": |
| return _render_attn() |
| if f"{linear_layer.format(0)}::out" not in datas[0]["data"].keys(): |
| return |
| print("## ", linear_layer) |
| to_file = linear_layer.split("{}", 1)[-1].replace(".", "") + ".html" |
| to_file = RENDER_OUT_PATH / to_file |
| plot_depth_distr_time( |
| [ |
| [f"{linear_layer}::{suffix}" for suffix in ["in", "w", "out"]], |
| [f"{linear_layer}::{suffix}.g" for suffix in ["in", "w", "out"]], |
| ], |
| to_file=to_file, |
| subsample=1, |
| ) |
|
|
|
|
| with Pool(10) as p: |
| p.map( |
| _render_to_file, |
| [ |
| "__attn__", |
| |
| "FSDP.module.blocks.{}.mlp.fc1", |
| "FSDP.module.blocks.{}.mlp.fc2", |
| "FSDP.module.blocks.{}.mlp.qkv", |
| "FSDP.module.blocks.{}.mlp.proj", |
| |
| "FSDPTransformer.layers.{}.attention.wq", |
| "FSDPTransformer.layers.{}.attention.wk", |
| "FSDPTransformer.layers.{}.attention.wv", |
| "FSDPTransformer.layers.{}.attention.wo", |
| "FSDPTransformer.layers.{}.feed_forward.w1", |
| "FSDPTransformer.layers.{}.feed_forward.w2", |
| "FSDPTransformer.layers.{}.feed_forward.w3", |
| "__res__", |
| ], |
| ) |
|
|