# Author: jennyzzt
# Commit: 59da1c6 ("init app")
import matplotlib.pyplot as plt
import pydantic
import time
import numpy as np
from tqdm import tqdm, trange
import torch
from torch import nn
from diffusers import StableDiffusionPipeline
import clip
from dreamsim import dreamsim
from ribs.archives import GridArchive
from ribs.schedulers import Scheduler
from ribs.emitters import GaussianEmitter
import itertools
from ribs.visualize import grid_archive_heatmap
# Decide once whether a GPU is present; device and precision both follow it.
_CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _CUDA_AVAILABLE else "cpu")
torch.cuda.empty_cache()
print("Torch device:", DEVICE)
# Half precision on the GPU, full precision on the CPU.
TORCH_DTYPE = torch.float16 if _CUDA_AVAILABLE else torch.float32
print("Torch dtype:", TORCH_DTYPE)
SD_CHECKPOINT = "lambdalabs/miniSD-diffusers"
# Image size in pixels; also the size of the blank tiles used for empty cells.
IMG_HEIGHT, IMG_WIDTH = 256, 256
# Layout of one batch of latents fed to the diffusion pipeline:
# (batch, channels, height, width).
BATCH_SIZE = 4
SD_IN_CHANNELS = 4
SD_IN_HEIGHT, SD_IN_WIDTH = 32, 32
SD_IN_SHAPE = (BATCH_SIZE, SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH)
# Load the Stable Diffusion pipeline once at import time. The safety checker
# is dropped (and marked as not required) purely for faster inference.
SDPIPE = StableDiffusionPipeline.from_pretrained(
    SD_CHECKPOINT,
    requires_safety_checker=False,
    safety_checker=None,
    torch_dtype=TORCH_DTYPE,
)
# Silence the pipeline's per-call progress bar.
SDPIPE.set_progress_bar_config(disable=True)
SDPIPE = SDPIPE.to(DEVICE)
# Archive grid shape: number of cells along each measure dimension
# (passed as `dims` to GridArchive).
GRID_SIZE = (20, 20)
SEED = 123
# Seed NumPy and PyTorch once, at import time.
np.random.seed(SEED)
torch.manual_seed(SEED)
# The initial population size and the number of QD iterations are
# parameters of run_qdhf (`init_pop`, `total_itrs`), not module constants.
class DivProj(nn.Module):
    """Linear projection of input features into a low-dimensional diversity space.

    Squared Euclidean distances between projected points act as the learned
    diversity distance.
    """

    def __init__(self, input_dim, latent_dim=2):
        super().__init__()
        # One linear layer, kept inside nn.Sequential so that parameter names
        # stay "proj.0.weight" / "proj.0.bias".
        layers = [nn.Linear(in_features=input_dim, out_features=latent_dim)]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        """Get diversity representations."""
        return self.proj(x)

    @staticmethod
    def _sq_dist(a, b):
        """Squared L2 distance between a and b along the last axis."""
        return (a - b).square().sum(-1)

    def calc_dis(self, x1, x2):
        """Calculate diversity distance as (squared) L2 distance."""
        return self._sq_dist(self.forward(x1), self.forward(x2))

    def triplet_delta_dis(self, ref, x1, x2):
        """Return d(ref, x1) - d(ref, x2) in the projected space.

        A positive value means x2 is closer to ref than x1.
        """
        anchor = self.forward(ref)
        first = self.forward(x1)
        second = self.forward(x2)
        return self._sq_dist(anchor, first) - self._sq_dist(anchor, second)
def loss_fn(y, delta_dis, margin=0.05):
    """Triplet hinge loss on DivProj delta distances.

    Args:
        y: tensor of binary preference labels (0 or 1). y = 1 means x2 is more
            similar to ref than x1. Labels are mapped to +1 / -1 (2y - 1).
        delta_dis: d(ref, x1) - d(ref, x2), as produced by
            ``DivProj.triplet_delta_dis``.
        margin: hinge margin; 0.05 by default.

    Returns:
        Scalar tensor: mean of max(0, margin - (2y - 1) * delta_dis).
    """
    signed = y * 2 - 1
    # clamp(min=0) works on whatever device/dtype the inputs already use,
    # instead of comparing against a tensor allocated on a fixed DEVICE.
    return torch.clamp(margin - signed * delta_dis, min=0.0).mean()
def _preference_labels(ref_feats, x1_feats, x2_feats):
    """Return a bool tensor, True where x2 is more cosine-similar to ref than x1."""
    cos = torch.nn.functional.cosine_similarity
    gt_dis = cos(ref_feats, x2_feats, dim=-1) - cos(ref_feats, x1_feats, dim=-1)
    return gt_dis > 0


def fit_div_proj(inputs, dreamsim_features, latent_dim, batch_size=32):
    """Trains the DivProj model on ground-truth labels.

    The labels are derived from DreamSim features: for each (ref, x1, x2)
    triplet, the label is 1 when x2 is more cosine-similar to ref than x1.

    Args:
        inputs: tensor whose ``[:, 0]``, ``[:, 1]``, ``[:, 2]`` slices are the
            ref, x1 and x2 inputs of each triplet; the model input size is
            ``inputs.shape[-1]``.
        dreamsim_features: tensor with the same triplet layout, used only to
            compute the ground-truth preference labels.
        latent_dim: output dimension of the learned projection.
        batch_size: triplets sampled (with replacement) per optimizer step.

    Returns:
        (model, acc): the trained DivProj cast to TORCH_DTYPE, and the
        validation accuracy of the last epoch that ran.

    Raises:
        ValueError: if there are too few triplets for a non-empty 75/25
            train/validation split.
    """
    t = time.time()
    n_pref_data = inputs.shape[0]
    n_train = int(n_pref_data * 0.75)
    n_val = n_pref_data - n_train
    if n_train < 1 or n_val < 1:
        raise ValueError(
            f"Need at least 2 preference triplets to fit DivProj, got {n_pref_data}."
        )
    model = DivProj(input_dim=inputs.shape[-1], latent_dim=latent_dim)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Split data into train (first 75%) and val (remaining triplets).
    ref, x1, x2 = inputs[:, 0], inputs[:, 1], inputs[:, 2]
    ref_train, x1_train, x2_train = ref[:n_train], x1[:n_train], x2[:n_train]
    ref_val, x1_val, x2_val = ref[n_train:], x1[n_train:], x2[n_train:]
    # Split DreamSim features the same way.
    ref_gt = dreamsim_features[:, 0]
    x1_gt = dreamsim_features[:, 1]
    x2_gt = dreamsim_features[:, 2]
    ref_gt_train, x1_gt_train, x2_gt_train = (
        ref_gt[:n_train],
        x1_gt[:n_train],
        x2_gt[:n_train],
    )
    ref_gt_val, x1_gt_val, x2_gt_val = ref_gt[n_train:], x1_gt[n_train:], x2_gt[n_train:]

    val_acc = []
    n_iters_per_epoch = max(n_train // batch_size, 1)
    for epoch in range(200):
        for _ in range(n_iters_per_epoch):
            optimizer.zero_grad()
            idx = np.random.choice(n_train, batch_size)
            # Model delta distance for the sampled triplets (in float32).
            delta_dis = model.triplet_delta_dis(
                ref_train[idx].float(), x1_train[idx].float(), x2_train[idx].float()
            )
            # Preference labels from DreamSim features, as 0/1 in TORCH_DTYPE.
            gt = _preference_labels(
                ref_gt_train[idx], x1_gt_train[idx], x2_gt_train[idx]
            ).to(TORCH_DTYPE)
            loss = loss_fn(gt, delta_dis)
            loss.backward()
            optimizer.step()

        # Validate on the full held-out split.
        with torch.no_grad():
            delta_dis = model.triplet_delta_dis(
                ref_val.float(), x1_val.float(), x2_val.float()
            )
            pred = delta_dis > 0
            gt = _preference_labels(ref_gt_val, x1_gt_val, x2_gt_val)
            acc = (pred == gt).sum().item() / n_val
        val_acc.append(acc)
        # Early stopping: the mean of the last 10 accuracies is below the mean
        # of the 10 before the newest one. Up to float rounding this is
        # equivalent to val_acc[-1] < val_acc[-11].
        if epoch > 10 and np.mean(val_acc[-10:]) < np.mean(val_acc[-11:-1]):
            break
    print(
        f"{np.round(time.time()- t, 1)}s ({epoch+1} epochs) | DivProj (n={n_pref_data}) fitted with val acc.: {acc}"
    )
    return model.to(TORCH_DTYPE), acc
@torch.no_grad()
def compute_diversity_measures(clip_features, diversity_model):
    """Project features through ``diversity_model`` and return a NumPy array.

    Runs with autograd disabled; the result is moved to the CPU.
    """
    projected = diversity_model(clip_features)
    return projected.detach().cpu().numpy()
def tensor_to_list(tensor):
    """Flatten a batch tensor into a float32 NumPy array of shape (batch, -1).

    Despite the name, the result is a 2-D ``np.ndarray``, not a Python list.
    """
    host = tensor.detach().cpu().numpy()
    n_rows = host.shape[0]
    return host.astype(np.float32).reshape(n_rows, -1)
def list_to_tensor(list_):
    """Turn flattened latents back into a latent batch tensor.

    Args:
        list_: sequence/array of flattened solutions, each with
            SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH values.

    Returns:
        Tensor of shape (len(list_), SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH)
        with dtype TORCH_DTYPE on DEVICE.
    """
    # Use the shared channel constant rather than a hard-coded 4 so this stays
    # in sync with SD_IN_SHAPE and tensor_to_list's inverse.
    sols = np.asarray(list_).reshape(
        len(list_), SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH
    )
    return torch.tensor(sols, dtype=TORCH_DTYPE, device=DEVICE)
def create_scheduler(
    sols,
    objs,
    clip_features,
    diversity_model,
    seed=None,
):
    """Build a fresh GridArchive and Gaussian-emitter Scheduler from evaluated solutions.

    Args:
        sols: tensor of latents; each is flattened into one archive solution.
        objs: objective value of each solution.
        clip_features: features of each solution, projected by
            ``diversity_model`` to obtain the archive measures.
        diversity_model: model mapping features to measures.
        seed: seed for the archive and emitter RNGs. ``None`` (default) falls
            back to the module-level SEED.

    Returns:
        (archive, scheduler)
    """
    # Honour an explicit seed; otherwise use the module-wide SEED.
    rng_seed = SEED if seed is None else seed
    measures = compute_diversity_measures(clip_features, diversity_model)
    # Per-measure (low, high) ranges from the 1st/99th percentiles, so a few
    # outliers do not stretch the grid.
    archive_bounds = np.array(
        [np.quantile(measures, 0.01, axis=0), np.quantile(measures, 0.99, axis=0)]
    ).T
    sols = tensor_to_list(sols)
    # Set up archive.
    archive = GridArchive(
        solution_dim=len(sols[0]), dims=GRID_SIZE, ranges=archive_bounds, seed=rng_seed
    )
    # Add initial solutions to the archive.
    archive.add(sols, objs, measures)
    # Set up the GaussianEmitter, started from elites already in the archive.
    emitters = [
        GaussianEmitter(
            archive=archive,
            sigma=0.1,
            initial_solutions=archive.sample_elites(BATCH_SIZE)["solution"],
            batch_size=BATCH_SIZE,
            seed=rng_seed,
        )
    ]
    # Return the archive and scheduler.
    return archive, Scheduler(archive, emitters)
def plot_archive(archive):
    """Draw a heatmap of ``archive`` and return the pyplot module.

    Objective values are colored on a fixed [0, 100] scale. The heatmap is
    drawn in a figure with a fixed label; any earlier figure with that label
    is closed first, so calling this once per iteration does not keep piling
    up open figures.

    Returns:
        ``matplotlib.pyplot``, with the new heatmap as the current figure.
    """
    fig_label = "qdhf_archive"
    plt.close(fig_label)  # No-op if no such figure exists yet.
    plt.figure(fig_label, figsize=(6, 4.5))
    grid_archive_heatmap(archive, vmin=0, vmax=100)
    plt.xlabel("Diversity Metric 1")
    plt.ylabel("Diversity Metric 2")
    return plt
def run_qdhf(prompt:str, init_pop: int=200, total_itrs: int=200):
    """Run QDHF latent space illumination for ``prompt``.

    This is a generator. After every QD iteration it yields ``(archive, plt)``,
    where ``plt`` is the pyplot module with the archive heatmap as the current
    figure. When exhausted it returns ``(archive, plt)`` (the StopIteration
    value).

    Args:
        prompt: text prompt used both for Stable Diffusion and the CLIP score.
        init_pop: number of random latents used to seed the first archive and
            diversity projection; rounded down to a multiple of BATCH_SIZE.
        total_itrs: number of QD iterations.
    """
    INIT_POP = init_pop
    TOTAL_ITRS = total_itrs
    # This tutorial uses ViT-B/32, you may use other checkpoints depending on your resources and need.
    # NOTE(review): CLIP (and DreamSim below) are reloaded on every call to
    # run_qdhf; consider caching them if this is called repeatedly.
    CLIP_MODEL, CLIP_PREPROCESS = clip.load("ViT-B/32", device=DEVICE)
    CLIP_MODEL.eval()
    # Freeze CLIP: it is only used for scoring and feature extraction.
    for p in CLIP_MODEL.parameters():
        p.requires_grad_(False)
    def compute_clip_scores(imgs, text, return_clip_features=False):
        """Computes CLIP scores for a batch of images and a given text prompt.

        Returns a float32 NumPy array of objective values (one per image) and,
        if ``return_clip_features`` is set, also the CLIP image embeddings
        cast to TORCH_DTYPE.
        """
        img_tensor = torch.stack([CLIP_PREPROCESS(img) for img in imgs]).to(DEVICE)
        tokenized_text = clip.tokenize([text]).to(DEVICE)
        img_logits, _text_logits = CLIP_MODEL(img_tensor, tokenized_text)
        # Only one prompt was tokenized, so column 0 is each image's logit.
        img_logits = img_logits.detach().cpu().numpy().astype(np.float32)[:, 0]
        # NOTE(review): presumably CLIP logits are ~100 * cosine similarity,
        # which would make this ~1 / cosine similarity — confirm.
        img_logits = 1 / img_logits * 100
        # Remap the objective from minimizing [0, 10] to maximizing [0, 100]
        img_logits = (10.0 - img_logits) * 10.0
        if return_clip_features:
            clip_features = CLIP_MODEL.encode_image(img_tensor).to(TORCH_DTYPE)
            return img_logits, clip_features
        else:
            return img_logits
    DREAMSIM_MODEL, DREAMSIM_PREPROCESS = dreamsim(
        pretrained=True, dreamsim_type="open_clip_vitb32", device=DEVICE
    )
    def evaluate_lsi(
        latents,
        prompt,
        return_features=False,
        diversity_model=None,
    ):
        """Evaluates the objective of LSI for a batch of latents and a given text prompt.

        Generates one image per latent with SDPIPE, scores it with CLIP and
        embeds it with DreamSim. Measures are computed only when
        ``diversity_model`` is given (otherwise ``None``).

        Returns ``(objs, measures)``, or
        ``(objs, measures, clip_features, dreamsim_features)`` when
        ``return_features`` is true.
        """
        images = SDPIPE(
            prompt,
            num_images_per_prompt=latents.shape[0],
            latents=latents,
            # num_inference_steps=1, # For testing.
        ).images
        objs, clip_features = compute_clip_scores(
            images,
            prompt,
            return_clip_features=True,
        )
        images = torch.cat([DREAMSIM_PREPROCESS(img) for img in images]).to(DEVICE)
        # NOTE(review): not wrapped in torch.no_grad(); if DreamSim parameters
        # require grad, these features keep an autograd graph alive while
        # they are accumulated below — confirm.
        dreamsim_features = DREAMSIM_MODEL.embed(images)
        if diversity_model is not None:
            measures = compute_diversity_measures(clip_features, diversity_model)
        else:
            measures = None
        if return_features:
            return objs, measures, clip_features, dreamsim_features
        else:
            return objs, measures
    update_schedule = [1, 21, 51, 101] # Iterations on which to update the archive.
    n_pref_data = 1000 # Number of preferences used in each update.
    archive = None
    # NOTE(review): `best`, `final_itr` and `div_proj_acc` are computed but
    # never reported or returned.
    best = 0.0
    for itr in trange(1, TOTAL_ITRS + 1):
        # Update archive and scheduler if needed.
        if itr in update_schedule:
            if archive is None:
                tqdm.write("Initializing archive and diversity projection.")
                all_sols = []
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []
                # Sample random solutions and get judgment on similarity.
                # NOTE(review): if INIT_POP < BATCH_SIZE, n_batches is 0 and the
                # torch.concat calls below receive empty lists and raise.
                n_batches = INIT_POP // BATCH_SIZE
                for _ in range(n_batches):
                    sols = torch.randn(SD_IN_SHAPE, device=DEVICE, dtype=TORCH_DTYPE)
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_sols.append(sols)
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_sols = torch.concat(all_sols, dim=0)
                all_clip_features = torch.concat(all_clip_features, dim=0)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)
                # Initialize the diversity projection model.
                # Each preference is a random (ref, x1, x2) triplet of indices.
                # np.random.choice samples with replacement, so a triplet may
                # repeat an index.
                div_proj_data = []
                div_proj_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    div_proj_data.append(all_clip_features[idx])
                    div_proj_labels.append(all_dreamsim_features[idx])
                div_proj_data = torch.concat(div_proj_data, dim=0)
                div_proj_labels = torch.concat(div_proj_labels, dim=0)
                # Regroup the flat rows into (n_pref_data, 3, feature_dim).
                div_proj_data = div_proj_data.reshape(n_pref_data, 3, -1)
                div_proj_label = div_proj_labels.reshape(n_pref_data, 3, -1)
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_label,
                    latent_dim=2,
                )
            else:
                tqdm.write("Updating archive and diversity projection.")
                # Get all the current solutions and collect feedback.
                all_sols = list_to_tensor(archive.data("solution"))
                n_batches = np.ceil(len(all_sols) / BATCH_SIZE).astype(int)
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []
                for i in range(n_batches):
                    # The last slice may hold fewer than BATCH_SIZE latents.
                    sols = all_sols[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_clip_features = torch.concat(
                    all_clip_features, dim=0
                ) # (number of archive solutions, feature dim)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)
                # Update the diversity projection model.
                # New triplets are appended to all earlier ones, so the
                # training set grows by n_pref_data at every update.
                additional_features = []
                additional_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    additional_features.append(all_clip_features[idx])
                    additional_labels.append(all_dreamsim_features[idx])
                additional_features = torch.concat(additional_features, dim=0)
                additional_labels = torch.concat(additional_labels, dim=0)
                additional_div_proj_data = additional_features.reshape(n_pref_data, 3, -1)
                additional_div_proj_label = additional_labels.reshape(n_pref_data, 3, -1)
                div_proj_data = torch.concat(
                    (div_proj_data, additional_div_proj_data), axis=0
                )
                div_proj_label = torch.concat(
                    (div_proj_label, additional_div_proj_label), axis=0
                )
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_label,
                    latent_dim=2,
                )
            # Rebuild the archive from scratch with measures from the
            # (re)fitted diversity model.
            archive, scheduler = create_scheduler(
                all_sols,
                all_objs,
                all_clip_features,
                diversity_model,
                seed=SEED,
            )
        # Primary QD loop.
        sols = scheduler.ask()
        sols = list_to_tensor(sols)
        objs, measures, clip_features, dreamsim_features = evaluate_lsi(
            sols, prompt, return_features=True, diversity_model=diversity_model
        )
        best = max(best, max(objs))
        scheduler.tell(objs, measures)
        # This can be used as a flag to save on the final iteration, but note that
        # we do not save results in this tutorial.
        final_itr = itr == TOTAL_ITRS
        # Update the summary statistics for the archive.
        qd_score, coverage = archive.stats.norm_qd_score, archive.stats.coverage
        tqdm.write(f"QD score: {np.round(qd_score, 2)} Coverage: {coverage * 100}")
        # Assigning `plt` here makes it a local name throughout run_qdhf; it is
        # only read after these assignments, so that is safe.
        plt = plot_archive(archive)
        yield archive, plt
    plt = plot_archive(archive)
    return archive, plt
def many_pictures(archive, prompt:str):
    """Render a grid of images sampled from different regions of the archive.

    The measure space is split into ``img_freq[0] x img_freq[1]`` boxes using
    the archive's outer grid boundaries; the outermost boxes are stretched to
    the min/max measures actually present. For each non-empty box one elite
    is sampled at random and its latent is decoded with SDPIPE using
    ``prompt``; empty boxes get a black tile. Measure 0 increases left to
    right and measure 1 increases bottom to top, matching the tick labels.

    Args:
        archive: pyribs archive whose solutions are flattened SD latents.
        prompt: text prompt passed to the Stable Diffusion pipeline.

    Returns:
        The ``matplotlib.pyplot`` module with the image grid as current figure.
    """
    from torchvision.utils import make_grid

    # Modify this to determine how many images to plot along each dimension.
    img_freq = (
        4,  # Number of columns of images (measure 0, x-axis).
        4,  # Number of rows of images (measure 1, y-axis).
    )
    n_cols, n_rows = img_freq
    # Image tensors in make_grid order: row by row, starting at the top.
    imgs = []
    # Convert archive to a df with solutions available.
    df = archive.data(return_type="pandas")
    # Compute the min and max measures for which solutions were found.
    measure_bounds = np.array(
        [
            (df["measures_0"].min(), df["measures_0"].max()),
            (df["measures_1"].min(), df["measures_1"].max()),
        ]
    )
    # First and last grid boundary of the archive along each measure.
    archive_bounds = np.array(
        [archive.boundaries[0][[0, -1]], archive.boundaries[1][[0, -1]]]
    )
    delta_measures_0 = (archive_bounds[0][1] - archive_bounds[0][0]) / n_cols
    delta_measures_1 = (archive_bounds[1][1] - archive_bounds[1][0]) / n_rows
    # Solutions have SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH entries, stored
    # in columns solution_0 .. solution_{n-1}.
    last_sol_col = f"solution_{SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH - 1}"

    # make_grid fills rows from the top, so visit measure 1 from its highest box
    # down; this matches the reversed y tick labels below.
    for m1_idx, m0_idx in itertools.product(reversed(range(n_rows)), range(n_cols)):
        # Compute bounds of a box in measure space.
        m0_low = archive_bounds[0][0] + delta_measures_0 * m0_idx
        m0_high = archive_bounds[0][0] + delta_measures_0 * (m0_idx + 1)
        m1_low = archive_bounds[1][0] + delta_measures_1 * m1_idx
        m1_high = archive_bounds[1][0] + delta_measures_1 * (m1_idx + 1)
        # Stretch the outermost boxes to the observed extremes so elites lying
        # beyond the grid boundaries (if any) can still be picked.
        if m0_idx == 0:
            m0_low = measure_bounds[0][0]
        if m1_idx == 0:
            m1_low = measure_bounds[1][0]
        if m0_idx == n_cols - 1:
            m0_high = measure_bounds[0][1]
        if m1_idx == n_rows - 1:
            m1_high = measure_bounds[1][1]
        # Elites whose measures fall inside this box (bounds inclusive).
        in_box = df["measures_0"].between(m0_low, m0_high) & df["measures_1"].between(
            m1_low, m1_high
        )
        df_box = df[in_box]
        if df_box.empty:
            imgs.append(torch.zeros((3, IMG_HEIGHT, IMG_WIDTH)))
            continue
        # Randomly sample a solution from the box.
        sol = df_box.loc[:, "solution_0":last_sol_col].sample(n=1).iloc[0]
        # Convert the latent vector solution to an image.
        latents = torch.tensor(sol.to_numpy()).reshape(
            (1, SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH)
        )
        latents = latents.to(TORCH_DTYPE).to(DEVICE)
        img = SDPIPE(
            prompt,
            num_images_per_prompt=1,
            latents=latents,
        ).images[0]
        # (H, W, C) image array -> (C, H, W) float tensor scaled to [0, 1].
        img = torch.from_numpy(np.array(img)).permute(2, 0, 1) / 255.0
        imgs.append(img)

    def create_archive_tick_labels(measure_range, num_ticks):
        """Return num_ticks + 1 evenly spaced labels over measure_range (3 d.p.)."""
        delta = (measure_range[1] - measure_range[0]) / num_ticks
        ticklabels = [round(delta * p + measure_range[0], 3) for p in range(num_ticks + 1)]
        return ticklabels

    plt.figure(figsize=(n_cols * 2, n_rows * 2))
    img_grid = make_grid(imgs, nrow=n_cols, padding=0)
    img_grid = np.transpose(img_grid.cpu().numpy(), (1, 2, 0))
    plt.imshow(img_grid)

    plt.xlabel("")
    num_x_ticks = n_cols
    x_ticklabels = create_archive_tick_labels(measure_bounds[0], num_x_ticks)
    x_tick_range = img_grid.shape[1]
    x_ticks = np.arange(0, x_tick_range + 1e-9, step=x_tick_range / num_x_ticks)
    plt.xticks(x_ticks, x_ticklabels)

    plt.ylabel("")
    num_y_ticks = n_rows
    y_ticklabels = create_archive_tick_labels(measure_bounds[1], num_y_ticks)
    # imshow draws pixel row 0 at the top, so y labels run from high to low.
    y_ticklabels.reverse()
    y_tick_range = img_grid.shape[0]
    y_ticks = np.arange(0, y_tick_range + 1e-9, step=y_tick_range / num_y_ticks)
    plt.yticks(y_ticks, y_ticklabels)
    plt.tight_layout()
    return plt