import io
import json
import logging
import math
import os
import pathlib
import random

import beartype
import einops.layers.torch
import gradio as gr
import numpy as np
import open_clip
import requests
import saev.nn
import torch
from jaxtyping import Float, jaxtyped
from PIL import Image, ImageDraw
from torch import Tensor
from torchvision.transforms import v2

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger("app.py")
####################
# Global Constants #
####################

DEBUG = True
"""Whether we are debugging."""

n_sae_latents = 3
"""Number of SAE latents to show."""

n_sae_examples = 4
"""Number of SAE examples per latent to show."""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"""Hardware accelerator, if any."""

vit_ckpt = "ViT-B-16/openai"
"""CLIP checkpoint."""

n_patches_per_img: int = 196
"""Number of patches per image in vit_ckpt."""

max_frequency = 1e-2
"""Maximum frequency. Any feature that fires more often than this is ignored."""

CWD = pathlib.Path(__file__).parent

r2_url = "https://pub-289086e849214430853bc87bd8964988.r2.dev/"

logger.info("Set global constants.")
###########
# Helpers #
###########


def get_cache_dir() -> str:
    """
    Get cache directory from environment variables, defaulting to the current working directory (.)

    Returns:
        A path to a cache directory (might not exist yet).
    """
    cache_dir = ""
    for var in ("HF_HOME", "HF_HUB_CACHE"):
        cache_dir = cache_dir or os.environ.get(var, "")
    return cache_dir or "."
def load_model(fpath: str | pathlib.Path, *, device: str = "cpu") -> torch.nn.Module:
    """
    Loads a linear layer from disk.

    The checkpoint is a single file: the first line is a JSON dict of
    `torch.nn.Linear` kwargs, and the remainder is the serialized state dict.
    """
    with open(fpath, "rb") as fd:
        kwargs = json.loads(fd.readline().decode())
        buffer = io.BytesIO(fd.read())

    model = torch.nn.Linear(**kwargs)
    state_dict = torch.load(buffer, weights_only=True, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model
def get_dataset_img(i: int) -> Image.Image:
    return Image.open(requests.get(r2_url + image_fpaths[i], stream=True).raw)
def make_img(
    img: Image.Image, patches: Float[Tensor, " n_patches"], *, upper: float | None = None
) -> Image.Image:
    # Resize to 512x512 and center-crop to 448x448, matching human_transform below.
    resize_size_px = (512, 512)
    resize_w_px, resize_h_px = resize_size_px
    crop_size_px = (448, 448)
    crop_w_px, crop_h_px = crop_size_px
    crop_coords_px = (
        (resize_w_px - crop_w_px) // 2,
        (resize_h_px - crop_h_px) // 2,
        (resize_w_px + crop_w_px) // 2,
        (resize_h_px + crop_h_px) // 2,
    )
    img = img.resize(resize_size_px).crop(crop_coords_px)
    img = add_highlights(img, patches.numpy(), upper=upper, opacity=0.5)
    return img
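# Note on patch geometry (inferred from the constants above): with
# n_patches_per_img = 196 the ViT uses a 14x14 patch grid, so on the 448x448 crop
# each patch maps to a 448 // 14 = 32 px square. add_highlights() below relies on
# this square layout when it draws one rectangle per patch.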
##########
# Models #
##########


class SplitClip(torch.nn.Module):
    def __init__(self, *, n_end_layers: int):
        super().__init__()
        if vit_ckpt.startswith("hf-hub:"):
            clip, _ = open_clip.create_model_from_pretrained(
                vit_ckpt, cache_dir=get_cache_dir()
            )
        else:
            arch, ckpt = vit_ckpt.split("/")
            clip, _ = open_clip.create_model_from_pretrained(
                arch, pretrained=ckpt, cache_dir=get_cache_dir()
            )
        model = clip.visual
        model.proj = None
        model.output_tokens = True  # type: ignore
        self.vit = model.eval()
        assert not isinstance(self.vit, open_clip.timm_model.TimmModel)
        self.n_end_layers = n_end_layers

    @staticmethod
    def _expand_token(token, batch_size: int):
        return token.view(1, 1, -1).expand(batch_size, -1, -1)
    def forward_start(self, x: Float[Tensor, "batch channels width height"]):
        x = self.vit.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # class embeddings and positional embeddings
        x = torch.cat(
            [self._expand_token(self.vit.class_embedding, x.shape[0]).to(x.dtype), x],
            dim=1,
        )
        # shape = [*, grid ** 2 + 1, width]
        x = x + self.vit.positional_embedding.to(x.dtype)
        x = self.vit.patch_dropout(x)
        x = self.vit.ln_pre(x)
        for r in self.vit.transformer.resblocks[: -self.n_end_layers]:
            x = r(x)
        return x

    def forward_end(self, x: Float[Tensor, "batch n_patches dim"]):
        for r in self.vit.transformer.resblocks[-self.n_end_layers :]:
            x = r(x)
        x = self.vit.ln_post(x)
        pooled, _ = self.vit._global_pool(x)
        if self.vit.proj is not None:
            pooled = pooled @ self.vit.proj
        return pooled
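# The split exists so the SAE can read and edit activations between the two halves:
# forward_start() runs the CLIP ViT up to the last n_end_layers blocks, and
# forward_end() runs the remaining blocks plus pooling. Chaining them, e.g.
#
#     x_BPD = split_vit.forward_start(x)
#     pooled_BD = split_vit.forward_end(x_BPD)
#
# should reproduce the unsplit visual encoder's pooled output (minus the projection,
# which is disabled above); get_pred_dist() below relies on exactly this composition.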
# ViT
split_vit = SplitClip(n_end_layers=1)
split_vit = split_vit.to(device)
logger.info("Initialized CLIP ViT.")

# Linear classifier
clf_ckpt_fpath = CWD / "ckpts" / "clf.pt"
clf = load_model(clf_ckpt_fpath)
clf = clf.to(device).eval()
logger.info("Loaded linear classifier.")

# SAE
sae_ckpt_fpath = CWD / "ckpts" / "sae.pt"
sae = saev.nn.load(sae_ckpt_fpath.as_posix())
sae.to(device).eval()
logger.info("Loaded SAE.")
############
# Datasets #
############

human_transform = v2.Compose([
    v2.Resize((512, 512), interpolation=v2.InterpolationMode.NEAREST),
    v2.CenterCrop((448, 448)),
    v2.ToImage(),
    einops.layers.torch.Rearrange("channels width height -> width height channels"),
])

arch, ckpt = vit_ckpt.split("/")
_, vit_transform = open_clip.create_model_from_pretrained(
    arch, pretrained=ckpt, cache_dir=get_cache_dir()
)

with open(CWD / "data" / "image_fpaths.json") as fd:
    image_fpaths = json.load(fd)

with open(CWD / "data" / "image_labels.json") as fd:
    image_labels = json.load(fd)

# TODO:
# This dataset needs to be the CUB2011 dataset. But that means we need to calculate
# top_img_i based on CUB2011, not on iNat21 train-mini.
# examples_dataset = saev.activations.ImageFolder(
#     "/research/nfs_su_809/workspace/stevens.994/datasets/inat21/train_mini",
#     transform=v2.Compose([
#         v2.Resize(size=(512, 512)),
#         v2.CenterCrop(size=(448, 448)),
#     ]),
# )

logger.info("Loaded all datasets.")
#############
# Variables #
#############


def load_tensor(path: str | pathlib.Path) -> Tensor:
    return torch.load(path, weights_only=True, map_location="cpu")


top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
top_values = load_tensor(CWD / "data" / "top_values.pt")
sparsity = load_tensor(CWD / "data" / "sparsity.pt")

mask = torch.ones((sae.cfg.d_sae), dtype=bool)
mask = mask & (sparsity < max_frequency)
#############
# Inference #
#############


def get_image(image_i: int) -> list[Image.Image | int]:
    image = get_dataset_img(image_i)
    image = human_transform(image)
    return [Image.fromarray(image.numpy()), image_labels[image_i]]


def get_random_class_image(cls: int) -> Image.Image:
    indices = [i for i, tgt in enumerate(image_labels) if tgt == cls]
    i = random.choice(indices)
    image = get_dataset_img(i)
    image = human_transform(image)
    return Image.fromarray(image.numpy())
def get_sae_examples(
    image_i: int, patches: list[int]
) -> list[None | Image.Image | int]:
    """
    Given a particular cell, returns some highlighted images showing what feature fires most on this cell.
    """
    if not patches:
        return [None] * (n_sae_latents * n_sae_examples) + [-1] * n_sae_latents

    img = get_dataset_img(image_i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)
    vit_acts_MD = x_BPD[0, patches].to(device)

    _, f_x_MS, _ = sae(vit_acts_MD)
    f_x_S = f_x_MS.sum(axis=0)

    latents = torch.argsort(f_x_S, descending=True).cpu()
    latents = latents[mask[latents]][:n_sae_latents].tolist()

    images = []
    for latent in latents:
        img_patch_pairs, seen_i_im = [], set()
        for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
            if i_im in seen_i_im:
                continue

            # TODO: look up the example image once examples_dataset (commented out
            # above) is available; until then there are no example images to show.
            # example = examples_dataset[i_im]
            example = None
            if example is None:
                continue

            img_patch_pairs.append((example["image"], values_p))
            seen_i_im.add(i_im)

        # How to scale values.
        upper = None
        if top_values[latent].numel() > 0:
            upper = top_values[latent].max().item()

        latent_images = [
            make_img(img, patches, upper=upper)
            for img, patches in img_patch_pairs[:n_sae_examples]
        ]
        while len(latent_images) < n_sae_examples:
            latent_images += [None]
        images.extend(latent_images)

    return images + latents
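# get_sae_examples() returns n_sae_latents * n_sae_examples (= 12) images followed by
# n_sae_latents (= 3) latent ids, which is exactly the flat output list that the
# Gradio wiring below (sae_example_images + top_latent_numbers) expects.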
def get_pred_dist(i: int) -> dict[int, float]:
    img = get_dataset_img(i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)
    x_BD = split_vit.forward_end(x_BPD)
    logits_BC = clf(x_BD)

    probs = torch.nn.functional.softmax(logits_BC[0], dim=0).cpu().tolist()
    return {i: prob for i, prob in enumerate(probs)}
def get_modified_dist(
    image_i: int,
    patches: list[int],
    latent1: int,
    latent2: int,
    latent3: int,
    value1: float,
    value2: float,
    value3: float,
) -> dict[int, float]:
    img = get_dataset_img(image_i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)

    cls_B1D, x_BPD = x_BPD[:, :1, :], x_BPD[:, 1:, :]

    x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
    err_BPD = x_BPD - x_hat_BPD

    values = torch.tensor(
        [
            unscaled(float(value), top_values[latent].max().item())
            for value, latent in [
                (value1, latent1),
                (value2, latent2),
                (value3, latent3),
            ]
        ],
        device=device,
    )
    patches = torch.tensor(patches, device=device)
    latents = torch.tensor([latent1, latent2, latent3], device=device)
    f_x_BPS[:, patches[:, None], latents[None, :]] = values

    # Reproduce the SAE forward pass after f_x
    modified_x_hat_BPD = (
        einops.einsum(
            f_x_BPS,
            sae.W_dec,
            "batch patches d_sae, d_sae d_vit -> batch patches d_vit",
        )
        + sae.b_dec
    )

    modified_BPD = torch.cat([cls_B1D, err_BPD + modified_x_hat_BPD], axis=1)

    modified_BD = split_vit.forward_end(modified_BPD)
    logits_BC = clf(modified_BD)
    probs = torch.nn.functional.softmax(logits_BC[0], dim=0).cpu().tolist()
    return {i: prob for i, prob in enumerate(probs)}
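# How the intervention works: the SAE decomposes each patch activation as
# x = x_hat + err with x_hat = f_x @ W_dec + b_dec. get_modified_dist() overwrites
# the three chosen latents of f_x on the selected patches, re-decodes to get a new
# x_hat, and adds back the original reconstruction error err so that everything the
# SAE does not model is left untouched before the edited activations flow through
# forward_end() and the classifier.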
def unscaled(x: float, max_obs: float) -> float:
    """Scale x from [-20, 20] to [-20 * max_obs, 20 * max_obs]."""
    return map_range(x, (-20.0, 20.0), (-20.0 * max_obs, 20.0 * max_obs))


def map_range(
    x: float,
    domain: tuple[float | int, float | int],
    range: tuple[float | int, float | int],
):
    a, b = domain
    c, d = range
    if not (a <= x <= b):
        raise ValueError(f"x={x:.3f} must be in {[a, b]}.")
    return c + (x - a) * (d - c) / (b - a)
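# Worked example (values chosen for illustration):
# map_range(10.0, (-20.0, 20.0), (0.0, 1.0)) == 0.75, and
# unscaled(20.0, max_obs=2.5) == 50.0, i.e. the endpoints of the [-20, 20] domain
# map to +/- 20 * max_obs, where max_obs is the largest observed activation.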
def add_highlights(
    img: Image.Image,
    patches: Float[np.ndarray, " n_patches"],
    *,
    upper: float | None = None,
    opacity: float = 0.9,
) -> Image.Image:
    if not len(patches):
        return img

    iw_np, ih_np = int(math.sqrt(len(patches))), int(math.sqrt(len(patches)))
    iw_px, ih_px = img.size
    pw_px, ph_px = iw_px // iw_np, ih_px // ih_np
    assert iw_np * ih_np == len(patches)

    # Create a transparent overlay
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # Using semi-transparent red (255, 0, 0, alpha)
    for p, val in enumerate(patches):
        assert upper is not None
        val /= upper + 1e-9
        x_np, y_np = p % iw_np, p // ih_np
        draw.rectangle(
            [
                (x_np * pw_px, y_np * ph_px),
                (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
            ],
            fill=(int(val * 256), 0, 0, int(opacity * val * 256)),
        )

    # Composite the original image and the overlay
    return Image.alpha_composite(img.convert("RGBA"), overlay)
#############
# Interface #
#############

with gr.Blocks() as demo:
    image_number = gr.Number(label="Test Example", precision=0)
    class_number = gr.Number(label="Test Class", precision=0)
    input_image = gr.Image(label="Input Image")
    get_input_image_btn = gr.Button(value="Get Input Image")
    get_input_image_btn.click(
        get_image,
        inputs=[image_number],
        outputs=[input_image, class_number],
        api_name="get-image",
    )

    get_random_class_image_btn = gr.Button(value="Get Random Class Image")
    get_random_class_image_btn.click(
        get_random_class_image,
        inputs=[class_number],
        outputs=[input_image],
        api_name="get-random-class-image",
    )
    patch_numbers = gr.CheckboxGroup(
        label="Image Patch", choices=list(range(n_patches_per_img))
    )
    top_latent_numbers = [
        gr.Number(label=f"Top Latents #{j + 1}", precision=0)
        for j in range(n_sae_latents)
    ]
    # Ordered latent-major to match the flat list returned by get_sae_examples().
    sae_example_images = [
        gr.Image(label=f"Latent #{j + 1}, Example #{i + 1}")
        for j in range(n_sae_latents)
        for i in range(n_sae_examples)
    ]
    get_sae_examples_btn = gr.Button(value="Get SAE Examples")
    get_sae_examples_btn.click(
        get_sae_examples,
        inputs=[image_number, patch_numbers],
        outputs=sae_example_images + top_latent_numbers,
        api_name="get-sae-examples",
    )

    pred_dist = gr.Label(label="Pred. Dist.")
    get_pred_dist_btn = gr.Button(value="Get Pred. Distribution")
    get_pred_dist_btn.click(
        get_pred_dist,
        inputs=[image_number],
        outputs=[pred_dist],
        api_name="get-preds",
    )

    latent_numbers = [gr.Number(label=f"Latent {i + 1}", precision=0) for i in range(3)]
    value_sliders = [
        gr.Slider(label=f"Value {i + 1}", minimum=-10, maximum=10) for i in range(3)
    ]

    get_modified_dist_btn = gr.Button(value="Get Modified Label")
    get_modified_dist_btn.click(
        get_modified_dist,
        inputs=[image_number, patch_numbers] + latent_numbers + value_sliders,
        outputs=[pred_dist],
        api_name="get-modified",
    )
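# The api_name arguments expose each handler as a named endpoint, so the demo can be
# driven programmatically. A minimal sketch with gradio_client (the URL is a
# placeholder for wherever this Space is served):
#
#     from gradio_client import Client
#     client = Client("http://localhost:7860")
#     img, label = client.predict(0, api_name="/get-image")
#
# Here 0 is the test-example index; the endpoint returns the image and its class.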
if __name__ == "__main__":
    demo.launch()