import itertools
import time

import clip
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from dreamsim import dreamsim
from torch import nn
from torchvision.utils import make_grid
from tqdm import tqdm, trange

from ribs.archives import GridArchive
from ribs.emitters import GaussianEmitter
from ribs.schedulers import Scheduler
from ribs.visualize import grid_archive_heatmap

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
print("Torch device:", DEVICE)

# Use half precision on GPU to save memory; fall back to float32 on CPU,
# where float16 ops are poorly supported.
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
print("Torch dtype:", TORCH_DTYPE)

# miniSD generates 256x256 images from 4x32x32 latents.
IMG_WIDTH = 256
IMG_HEIGHT = 256
SD_IN_HEIGHT = 32
SD_IN_WIDTH = 32
SD_CHECKPOINT = "lambdalabs/miniSD-diffusers"

BATCH_SIZE = 4
SD_IN_CHANNELS = 4
SD_IN_SHAPE = (
    BATCH_SIZE,
    SD_IN_CHANNELS,
    SD_IN_HEIGHT,
    SD_IN_WIDTH,
)

SDPIPE = StableDiffusionPipeline.from_pretrained(
    SD_CHECKPOINT,
    torch_dtype=TORCH_DTYPE,
    safety_checker=None,
    requires_safety_checker=False,
)
SDPIPE.set_progress_bar_config(disable=True)
SDPIPE = SDPIPE.to(DEVICE)

GRID_SIZE = (20, 20)
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)


class DivProj(nn.Module):
    """Projects feature vectors into a low-dimensional diversity (measure) space."""

    def __init__(self, input_dim, latent_dim=2):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=latent_dim),
        )

    def forward(self, x):
        """Get diversity representations."""
        x = self.proj(x)
        return x

    def calc_dis(self, x1, x2):
        """Calculate diversity distance as (squared) L2 distance."""
        x1 = self.forward(x1)
        x2 = self.forward(x2)
        return torch.sum(torch.square(x1 - x2), -1)

    def triplet_delta_dis(self, ref, x1, x2):
        """Calculate delta distance comparing x1 and x2 to ref."""
        x1 = self.forward(x1)
        x2 = self.forward(x2)
        ref = self.forward(ref)
        return (torch.sum(torch.square(ref - x1), -1) -
                torch.sum(torch.square(ref - x2), -1))


# Triplet hinge loss: labels y in {0, 1} are mapped to {-1, +1}, so the loss
# is zero once delta_dis agrees with the label by a margin of 0.05.
loss_fn = lambda y, delta_dis: torch.max(
    torch.tensor([0.0]).to(DEVICE), 0.05 - (y * 2 - 1) * delta_dis
).mean()
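
# Quick sanity check of the loss semantics (a sketch with toy tensors, not
# part of the pipeline): a correctly ordered pair beyond the margin gives
# zero loss, while a misordered pair is penalized.
#   y = torch.ones(1).to(DEVICE)
#   loss_fn(y, torch.tensor([1.0]).to(DEVICE))   # -> 0.0
#   loss_fn(y, torch.tensor([-1.0]).to(DEVICE))  # -> 1.05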


def fit_div_proj(inputs, dreamsim_features, latent_dim, batch_size=32):
    """Trains the DivProj model on ground-truth labels."""
    t = time.time()
    model = DivProj(input_dim=inputs.shape[-1], latent_dim=latent_dim)
    model.to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    n_pref_data = inputs.shape[0]
    ref = inputs[:, 0]
    x1 = inputs[:, 1]
    x2 = inputs[:, 2]

    # 75/25 train/validation split.
    n_train = int(n_pref_data * 0.75)
    n_val = n_pref_data - n_train

    ref_train = ref[:n_train]
    x1_train = x1[:n_train]
    x2_train = x2[:n_train]
    ref_val = ref[n_train:]
    x1_val = x1[n_train:]
    x2_val = x2[n_train:]

    # Ground-truth similarity comes from DreamSim features.
    ref_dreamsim_features = dreamsim_features[:, 0]
    x1_dreamsim_features = dreamsim_features[:, 1]
    x2_dreamsim_features = dreamsim_features[:, 2]
    ref_gt_train = ref_dreamsim_features[:n_train]
    x1_gt_train = x1_dreamsim_features[:n_train]
    x2_gt_train = x2_dreamsim_features[:n_train]
    ref_gt_val = ref_dreamsim_features[n_train:]
    x1_gt_val = x1_dreamsim_features[n_train:]
    x2_gt_val = x2_dreamsim_features[n_train:]

    val_acc = []
    n_iters_per_epoch = max(n_train // batch_size, 1)
    for epoch in range(200):
        for _ in range(n_iters_per_epoch):
            optimizer.zero_grad()

            idx = np.random.choice(n_train, batch_size)
            batch_ref = ref_train[idx].float()
            batch1 = x1_train[idx].float()
            batch2 = x2_train[idx].float()

            delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)

            # The label is 1 if x1 is closer to ref than x2 under DreamSim
            # cosine similarity, else 0.
            gt_dis = torch.nn.functional.cosine_similarity(
                ref_gt_train[idx], x2_gt_train[idx], dim=-1
            ) - torch.nn.functional.cosine_similarity(
                ref_gt_train[idx], x1_gt_train[idx], dim=-1
            )
            gt = (gt_dis > 0).to(TORCH_DTYPE)

            loss = loss_fn(gt, delta_dis)
            loss.backward()
            optimizer.step()

        # Validation accuracy: does the model rank each triplet the same way
        # as DreamSim?
        n_correct = 0
        n_total = 0
        with torch.no_grad():
            idx = np.arange(n_val)
            batch_ref = ref_val[idx].float()
            batch1 = x1_val[idx].float()
            batch2 = x2_val[idx].float()
            delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)
            pred = delta_dis > 0
            gt_dis = torch.nn.functional.cosine_similarity(
                ref_gt_val[idx], x2_gt_val[idx], dim=-1
            ) - torch.nn.functional.cosine_similarity(
                ref_gt_val[idx], x1_gt_val[idx], dim=-1
            )
            gt = gt_dis > 0
            n_correct += (pred == gt).sum().item()
            n_total += len(idx)

        acc = n_correct / n_total
        val_acc.append(acc)

        # Early stopping: stop once the 10-epoch moving average of validation
        # accuracy no longer improves.
        if epoch > 10 and np.mean(val_acc[-10:]) < np.mean(val_acc[-11:-1]):
            break

    print(
        f"{np.round(time.time() - t, 1)}s ({epoch + 1} epochs) | "
        f"DivProj (n={n_pref_data}) fitted with val acc.: {acc}"
    )

    return model.to(TORCH_DTYPE), acc
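
# Example call (a sketch; the names are illustrative): given CLIP-feature
# triplets `triplet_feats` of shape (n_pref_data, 3, clip_dim) and matching
# DreamSim features `triplet_gt` of shape (n_pref_data, 3, dreamsim_dim):
#   diversity_model, acc = fit_div_proj(triplet_feats, triplet_gt, latent_dim=2)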


def compute_diversity_measures(clip_features, diversity_model):
    """Projects CLIP features into the learned 2D measure space."""
    with torch.no_grad():
        measures = diversity_model(clip_features).detach().cpu().numpy()
    return measures


def tensor_to_list(tensor):
    """Flattens a batch of latents into flat float32 solutions for pyribs."""
    sols = tensor.detach().cpu().numpy().astype(np.float32)
    return sols.reshape(sols.shape[0], -1)


def list_to_tensor(list_):
    """Reshapes flat solutions back into a batch of SD latent tensors."""
    sols = np.array(list_).reshape(
        len(list_), SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH
    )
    return torch.tensor(sols, dtype=TORCH_DTYPE, device=DEVICE)
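
# Round-trip note (a sketch of the invariant these helpers maintain): for a
# latent batch `z` of shape SD_IN_SHAPE,
#   list_to_tensor(tensor_to_list(z))
# recovers `z` up to the float32 cast. This is how solutions move between the
# pyribs scheduler (flat arrays) and the SD pipeline (4D latents).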


def create_scheduler(
    sols,
    objs,
    clip_features,
    diversity_model,
    seed=None,
):
    measures = compute_diversity_measures(clip_features, diversity_model)
    # Bound the archive at the 1st/99th percentiles of the observed measures
    # so that outliers do not stretch the grid.
    archive_bounds = np.array(
        [np.quantile(measures, 0.01, axis=0), np.quantile(measures, 0.99, axis=0)]
    ).T

    sols = tensor_to_list(sols)

    archive = GridArchive(
        solution_dim=len(sols[0]), dims=GRID_SIZE, ranges=archive_bounds, seed=seed
    )
    archive.add(sols, objs, measures)

    emitters = [
        GaussianEmitter(
            archive=archive,
            sigma=0.1,
            initial_solutions=archive.sample_elites(BATCH_SIZE)["solution"],
            batch_size=BATCH_SIZE,
            seed=seed,
        )
    ]

    return archive, Scheduler(archive, emitters)
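
# Design note (inferred from how run_qdhf uses this function): the archive is
# rebuilt from scratch, rather than updated in place, whenever the diversity
# projection is refitted. Refitting changes what the measure axes mean, so
# old cell coordinates are no longer comparable; the surviving solutions are
# re-added above under their newly computed measures.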


def plot_archive(archive):
    plt.figure(figsize=(6, 4.5))
    grid_archive_heatmap(archive, vmin=0, vmax=100)
    plt.xlabel("Diversity Metric 1")
    plt.ylabel("Diversity Metric 2")
    return plt


def run_qdhf(prompt: str, init_pop: int = 200, total_itrs: int = 200):
    INIT_POP = init_pop
    TOTAL_ITRS = total_itrs

    # A frozen CLIP model scores how well each image matches the prompt.
    CLIP_MODEL, CLIP_PREPROCESS = clip.load("ViT-B/32", device=DEVICE)
    CLIP_MODEL.eval()
    for p in CLIP_MODEL.parameters():
        p.requires_grad_(False)

    def compute_clip_scores(imgs, text, return_clip_features=False):
        """Computes CLIP scores for a batch of images and a given text prompt."""
        img_tensor = torch.stack([CLIP_PREPROCESS(img) for img in imgs]).to(DEVICE)
        tokenized_text = clip.tokenize([text]).to(DEVICE)
        img_logits, _text_logits = CLIP_MODEL(img_tensor, tokenized_text)
        img_logits = img_logits.detach().cpu().numpy().astype(np.float32)[:, 0]
        # Rescale the CLIP logits into a maximization objective that lies
        # roughly in [0, 100].
        img_logits = 1 / img_logits * 100
        img_logits = (10.0 - img_logits) * 10.0

        if return_clip_features:
            clip_features = CLIP_MODEL.encode_image(img_tensor).to(TORCH_DTYPE)
            return img_logits, clip_features
        else:
            return img_logits

    # DreamSim provides the ground-truth perceptual similarity used to train
    # the diversity projection.
    DREAMSIM_MODEL, DREAMSIM_PREPROCESS = dreamsim(
        pretrained=True, dreamsim_type="open_clip_vitb32", device=DEVICE
    )

    def evaluate_lsi(
        latents,
        prompt,
        return_features=False,
        diversity_model=None,
    ):
        """Evaluates the objective of LSI for a batch of latents and a given text prompt."""
        images = SDPIPE(
            prompt,
            num_images_per_prompt=latents.shape[0],
            latents=latents,
        ).images

        objs, clip_features = compute_clip_scores(
            images,
            prompt,
            return_clip_features=True,
        )

        images = torch.cat([DREAMSIM_PREPROCESS(img) for img in images]).to(DEVICE)
        dreamsim_features = DREAMSIM_MODEL.embed(images)

        if diversity_model is not None:
            measures = compute_diversity_measures(clip_features, diversity_model)
        else:
            measures = None

        if return_features:
            return objs, measures, clip_features, dreamsim_features
        else:
            return objs, measures

    # Iterations at which the diversity projection is (re)fitted and the
    # archive rebuilt.
    update_schedule = [1, 21, 51, 101]
    n_pref_data = 1000

    archive = None

    best = 0.0
    for itr in trange(1, TOTAL_ITRS + 1):
        if itr in update_schedule:
            if archive is None:
                tqdm.write("Initializing archive and diversity projection.")

                # Bootstrap: sample random latents and evaluate them.
                all_sols = []
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []

                n_batches = INIT_POP // BATCH_SIZE
                for _ in range(n_batches):
                    sols = torch.randn(SD_IN_SHAPE, device=DEVICE, dtype=TORCH_DTYPE)
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_sols.append(sols)
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_sols = torch.concat(all_sols, dim=0)
                all_clip_features = torch.concat(all_clip_features, dim=0)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)

                # Build triplets of CLIP features (model inputs) paired with
                # DreamSim features (ground-truth labels).
                div_proj_data = []
                div_proj_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    div_proj_data.append(all_clip_features[idx])
                    div_proj_labels.append(all_dreamsim_features[idx])
                div_proj_data = torch.concat(div_proj_data, dim=0)
                div_proj_labels = torch.concat(div_proj_labels, dim=0)
                div_proj_data = div_proj_data.reshape(n_pref_data, 3, -1)
                div_proj_label = div_proj_labels.reshape(n_pref_data, 3, -1)
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_label,
                    latent_dim=2,
                )
            else:
                tqdm.write("Updating archive and diversity projection.")

                # Re-evaluate all solutions currently in the archive.
                all_sols = list_to_tensor(archive.data("solution"))
                n_batches = np.ceil(len(all_sols) / BATCH_SIZE).astype(int)
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []
                for i in range(n_batches):
                    sols = all_sols[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_clip_features = torch.concat(all_clip_features, dim=0)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)

                # Grow the preference dataset with fresh triplets and refit
                # the diversity projection on all data collected so far.
                additional_features = []
                additional_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    additional_features.append(all_clip_features[idx])
                    additional_labels.append(all_dreamsim_features[idx])
                additional_features = torch.concat(additional_features, dim=0)
                additional_labels = torch.concat(additional_labels, dim=0)
                additional_div_proj_data = additional_features.reshape(
                    n_pref_data, 3, -1
                )
                additional_div_proj_label = additional_labels.reshape(
                    n_pref_data, 3, -1
                )
                div_proj_data = torch.concat(
                    (div_proj_data, additional_div_proj_data), axis=0
                )
                div_proj_label = torch.concat(
                    (div_proj_label, additional_div_proj_label), axis=0
                )
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_label,
                    latent_dim=2,
                )

            archive, scheduler = create_scheduler(
                all_sols,
                all_objs,
                all_clip_features,
                diversity_model,
                seed=SEED,
            )

        # Standard pyribs ask-tell step: sample new latents, evaluate them,
        # and insert them into the archive.
        sols = scheduler.ask()
        sols = list_to_tensor(sols)
        objs, measures, clip_features, dreamsim_features = evaluate_lsi(
            sols, prompt, return_features=True, diversity_model=diversity_model
        )
        best = max(best, max(objs))
        scheduler.tell(objs, measures)

        qd_score, coverage = archive.stats.norm_qd_score, archive.stats.coverage
        tqdm.write(f"QD score: {np.round(qd_score, 2)} Coverage: {coverage * 100}")

        plt = plot_archive(archive)
        yield archive, plt

    plt = plot_archive(archive)
    return archive, plt


def many_pictures(archive, prompt: str):
    """Generates a grid of images from solutions spread across the archive."""
    # Number of images along each measure axis.
    img_freq = (
        4,
        4,
    )

    imgs = []

    df = archive.data(return_type="pandas")

    # Min/max of the measures that actually occur in the archive.
    measure_bounds = np.array(
        [
            (df["measures_0"].min(), df["measures_0"].max()),
            (df["measures_1"].min(), df["measures_1"].max()),
        ]
    )

    # Outer boundaries of the archive grid.
    archive_bounds = np.array(
        [archive.boundaries[0][[0, -1]], archive.boundaries[1][[0, -1]]]
    )

    delta_measures_0 = (archive_bounds[0][1] - archive_bounds[0][0]) / img_freq[0]
    delta_measures_1 = (archive_bounds[1][1] - archive_bounds[1][0]) / img_freq[1]

    for col, row in itertools.product(range(img_freq[1]), range(img_freq[0])):
        # Bounds of this grid box in measure space.
        measures_0_low = archive_bounds[0][0] + delta_measures_0 * row
        measures_0_high = archive_bounds[0][0] + delta_measures_0 * (row + 1)
        measures_1_low = archive_bounds[1][0] + delta_measures_1 * col
        measures_1_high = archive_bounds[1][0] + delta_measures_1 * (col + 1)

        # Snap the outermost boxes to the occupied measure range.
        if row == 0:
            measures_0_low = measure_bounds[0][0]
        if col == 0:
            measures_1_low = measure_bounds[1][0]
        if row == img_freq[0] - 1:
            measures_0_high = measure_bounds[0][1]
        if col == img_freq[1] - 1:
            measures_1_high = measure_bounds[1][1]

        # Query for solutions whose measures fall within this box.
        query_string = (
            f"{measures_0_low} <= measures_0 & measures_0 <= {measures_0_high} & "
            f"{measures_1_low} <= measures_1 & measures_1 <= {measures_1_high}"
        )
        df_box = df.query(query_string)

        if not df_box.empty:
            # Sample one solution from the box and decode it into an image.
            sol = (
                df_box.loc[
                    :,
                    "solution_0":"solution_{}".format(
                        SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH - 1
                    ),
                ]
                .sample(n=1)
                .iloc[0]
            )

            latents = torch.tensor(sol.to_numpy()).reshape(
                (1, SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH)
            )
            latents = latents.to(TORCH_DTYPE).to(DEVICE)
            img = SDPIPE(
                prompt,
                num_images_per_prompt=1,
                latents=latents,
            ).images[0]

            img = torch.from_numpy(np.array(img)).permute(2, 0, 1) / 255.0
            imgs.append(img)
        else:
            # Leave a black cell for unoccupied boxes.
            imgs.append(torch.zeros((3, IMG_HEIGHT, IMG_WIDTH)))

    def create_archive_tick_labels(measure_range, num_ticks):
        delta = (measure_range[1] - measure_range[0]) / num_ticks
        ticklabels = [
            round(delta * p + measure_range[0], 3) for p in range(num_ticks + 1)
        ]
        return ticklabels

    plt.figure(figsize=(img_freq[0] * 2, img_freq[0] * 2))
    img_grid = make_grid(imgs, nrow=img_freq[0], padding=0)
    img_grid = np.transpose(img_grid.cpu().numpy(), (1, 2, 0))
    plt.imshow(img_grid)

    # Label the axes with the measure values spanned by the grid.
    plt.xlabel("")
    num_x_ticks = img_freq[0]
    x_ticklabels = create_archive_tick_labels(measure_bounds[0], num_x_ticks)
    x_tick_range = img_grid.shape[1]
    x_ticks = np.arange(0, x_tick_range + 1e-9, step=x_tick_range / num_x_ticks)
    plt.xticks(x_ticks, x_ticklabels)

    plt.ylabel("")
    num_y_ticks = img_freq[1]
    y_ticklabels = create_archive_tick_labels(measure_bounds[1], num_y_ticks)
    y_ticklabels.reverse()
    y_tick_range = img_grid.shape[0]
    y_ticks = np.arange(0, y_tick_range + 1e-9, step=y_tick_range / num_y_ticks)
    plt.yticks(y_ticks, y_ticklabels)
    plt.tight_layout()

    return plt
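

# Example usage (a sketch, not part of the original pipeline): run QDHF-LSI on
# a prompt and save the resulting heatmap and image grid. Assumes a GPU with
# enough memory for miniSD; model weights are downloaded on first use. The
# prompt below is illustrative. `run_qdhf` is a generator, so we drain it and
# keep the final archive and figure.
if __name__ == "__main__":
    PROMPT = "an image of a bear in a national park"
    archive, fig = None, None
    for archive, fig in run_qdhf(PROMPT, init_pop=200, total_itrs=200):
        pass  # each iteration yields the current archive and its heatmap
    fig.savefig("archive_heatmap.png")
    many_pictures(archive, PROMPT).savefig("archive_images.png")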