import matplotlib.pyplot as plt
import time
import itertools

import numpy as np
from tqdm import tqdm, trange

import torch
from torch import nn
from torchvision.utils import make_grid

from diffusers import StableDiffusionPipeline
import clip
from dreamsim import dreamsim

from ribs.archives import GridArchive
from ribs.schedulers import Scheduler
from ribs.emitters import GaussianEmitter
from ribs.visualize import grid_archive_heatmap

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
print("Torch device:", DEVICE)

# Use float16 for GPU, float32 for CPU.
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
print("Torch dtype:", TORCH_DTYPE)

IMG_WIDTH = 256
IMG_HEIGHT = 256
SD_IN_HEIGHT = 32
SD_IN_WIDTH = 32
SD_CHECKPOINT = "lambdalabs/miniSD-diffusers"

BATCH_SIZE = 4
SD_IN_CHANNELS = 4
SD_IN_SHAPE = (
    BATCH_SIZE,
    SD_IN_CHANNELS,
    SD_IN_HEIGHT,
    SD_IN_WIDTH,
)

SDPIPE = StableDiffusionPipeline.from_pretrained(
    SD_CHECKPOINT,
    torch_dtype=TORCH_DTYPE,
    safety_checker=None,  # For faster inference.
    requires_safety_checker=False,
)
SDPIPE.set_progress_bar_config(disable=True)
SDPIPE = SDPIPE.to(DEVICE)

GRID_SIZE = (20, 20)
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)

# INIT_POP = 200  # Initial population.
# TOTAL_ITRS = 200  # Total number of iterations.


class DivProj(nn.Module):
    def __init__(self, input_dim, latent_dim=2):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=latent_dim),
        )

    def forward(self, x):
        """Get diversity representations."""
        x = self.proj(x)
        return x

    def calc_dis(self, x1, x2):
        """Calculate diversity distance as (squared) L2 distance."""
        x1 = self.forward(x1)
        x2 = self.forward(x2)
        return torch.sum(torch.square(x1 - x2), -1)

    def triplet_delta_dis(self, ref, x1, x2):
        """Calculate delta distance comparing x1 and x2 to ref."""
        x1 = self.forward(x1)
        x2 = self.forward(x2)
        ref = self.forward(ref)
        return (torch.sum(torch.square(ref - x1), -1) -
                torch.sum(torch.square(ref - x2), -1))


# Triplet loss with margin 0.05. The binary preference label y (1 if x2 is
# more similar to ref than x1, else 0) is mapped to {-1, 1} inside the loss.
loss_fn = lambda y, delta_dis: torch.max(
    torch.tensor([0.0]).to(DEVICE), 0.05 - (y * 2 - 1) * delta_dis
).mean()
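
# A minimal sanity check of the triplet loss above (illustrative only; the
# underscore-prefixed names and values here are arbitrary and not used
# elsewhere). With y = 1 (x2 closer to ref), a delta_dis above the margin
# should give zero loss, while a negative delta_dis should be penalized.
_y = torch.tensor([1.0], device=DEVICE)
_loss_separated = loss_fn(_y, torch.tensor([0.5], device=DEVICE))  # -> 0.0
_loss_violated = loss_fn(_y, torch.tensor([-0.5], device=DEVICE))  # -> 0.55
assert _loss_separated.item() == 0.0
assert abs(_loss_violated.item() - 0.55) < 1e-6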
def fit_div_proj(inputs, dreamsim_features, latent_dim, batch_size=32):
    """Trains the DivProj model on ground-truth preference labels."""
    t = time.time()
    model = DivProj(input_dim=inputs.shape[-1], latent_dim=latent_dim)
    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    n_pref_data = inputs.shape[0]
    ref = inputs[:, 0]
    x1 = inputs[:, 1]
    x2 = inputs[:, 2]

    n_train = int(n_pref_data * 0.75)
    n_val = n_pref_data - n_train

    # Split data into train and val.
    ref_train = ref[:n_train]
    x1_train = x1[:n_train]
    x2_train = x2[:n_train]
    ref_val = ref[n_train:]
    x1_val = x1[n_train:]
    x2_val = x2[n_train:]

    # Split DreamSim features into train and val.
    ref_dreamsim_features = dreamsim_features[:, 0]
    x1_dreamsim_features = dreamsim_features[:, 1]
    x2_dreamsim_features = dreamsim_features[:, 2]
    ref_gt_train = ref_dreamsim_features[:n_train]
    x1_gt_train = x1_dreamsim_features[:n_train]
    x2_gt_train = x2_dreamsim_features[:n_train]
    ref_gt_val = ref_dreamsim_features[n_train:]
    x1_gt_val = x1_dreamsim_features[n_train:]
    x2_gt_val = x2_dreamsim_features[n_train:]

    val_acc = []
    n_iters_per_epoch = max(n_train // batch_size, 1)
    for epoch in range(200):
        for _ in range(n_iters_per_epoch):
            optimizer.zero_grad()
            idx = np.random.choice(n_train, batch_size)
            batch_ref = ref_train[idx].float()
            batch1 = x1_train[idx].float()
            batch2 = x2_train[idx].float()

            # Get delta distance from model.
            delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)

            # Get preference labels from DreamSim features: y = 1 if x2 is
            # more similar to ref than x1 (by cosine similarity), else 0.
            gt_dis = torch.nn.functional.cosine_similarity(
                ref_gt_train[idx], x2_gt_train[idx], dim=-1
            ) - torch.nn.functional.cosine_similarity(
                ref_gt_train[idx], x1_gt_train[idx], dim=-1
            )
            gt = (gt_dis > 0).to(TORCH_DTYPE)

            loss = loss_fn(gt, delta_dis)
            loss.backward()
            optimizer.step()

        # Validate.
        n_correct = 0
        n_total = 0
        with torch.no_grad():
            idx = np.arange(n_val)
            batch_ref = ref_val[idx].float()
            batch1 = x1_val[idx].float()
            batch2 = x2_val[idx].float()
            delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)
            pred = delta_dis > 0
            gt_dis = torch.nn.functional.cosine_similarity(
                ref_gt_val[idx], x2_gt_val[idx], dim=-1
            ) - torch.nn.functional.cosine_similarity(
                ref_gt_val[idx], x1_gt_val[idx], dim=-1
            )
            gt = gt_dis > 0
            n_correct += (pred == gt).sum().item()
            n_total += len(idx)

        acc = n_correct / n_total
        val_acc.append(acc)

        # Early stopping when the 10-epoch moving average of val accuracy
        # stops improving.
        if epoch > 10 and np.mean(val_acc[-10:]) < np.mean(val_acc[-11:-1]):
            break

    print(
        f"{np.round(time.time() - t, 1)}s ({epoch + 1} epochs) | "
        f"DivProj (n={n_pref_data}) fitted with val acc.: {acc}"
    )
    return model.to(TORCH_DTYPE), acc


def compute_diversity_measures(clip_features, diversity_model):
    with torch.no_grad():
        measures = diversity_model(clip_features).detach().cpu().numpy()
    return measures


def tensor_to_list(tensor):
    sols = tensor.detach().cpu().numpy().astype(np.float32)
    return sols.reshape(sols.shape[0], -1)


def list_to_tensor(list_):
    sols = np.array(list_).reshape(
        len(list_), SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH
    )
    return torch.tensor(sols, dtype=TORCH_DTYPE, device=DEVICE)
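
# Quick round-trip check for the conversion helpers above (illustrative only;
# the underscore-prefixed latents are arbitrary). tensor_to_list casts to
# float32 and list_to_tensor casts back to TORCH_DTYPE, so the round trip is
# exact when the input is already in TORCH_DTYPE.
_latents = torch.randn(
    2, SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH, device=DEVICE, dtype=TORCH_DTYPE
)
assert torch.equal(list_to_tensor(tensor_to_list(_latents)), _latents)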
def create_scheduler(
    sols,
    objs,
    clip_features,
    diversity_model,
    seed=None,
):
    measures = compute_diversity_measures(clip_features, diversity_model)
    archive_bounds = np.array(
        [np.quantile(measures, 0.01, axis=0), np.quantile(measures, 0.99, axis=0)]
    ).T

    sols = tensor_to_list(sols)

    # Set up archive.
    archive = GridArchive(
        solution_dim=len(sols[0]), dims=GRID_SIZE, ranges=archive_bounds, seed=SEED
    )

    # Add initial solutions to the archive.
    archive.add(sols, objs, measures)

    # Set up the GaussianEmitter.
    emitters = [
        GaussianEmitter(
            archive=archive,
            sigma=0.1,
            initial_solutions=archive.sample_elites(BATCH_SIZE)["solution"],
            batch_size=BATCH_SIZE,
            seed=SEED,
        )
    ]

    # Return the archive and scheduler.
    return archive, Scheduler(archive, emitters)


def plot_archive(archive):
    plt.figure(figsize=(6, 4.5))
    grid_archive_heatmap(archive, vmin=0, vmax=100)
    plt.xlabel("Diversity Metric 1")
    plt.ylabel("Diversity Metric 2")
    return plt


def run_qdhf(prompt: str, init_pop: int = 200, total_itrs: int = 200):
    INIT_POP = init_pop  # Initial population size.
    TOTAL_ITRS = total_itrs  # Total number of QD iterations.

    # This tutorial uses ViT-B/32; other checkpoints may be used depending on
    # your resources and needs.
    CLIP_MODEL, CLIP_PREPROCESS = clip.load("ViT-B/32", device=DEVICE)
    CLIP_MODEL.eval()
    for p in CLIP_MODEL.parameters():
        p.requires_grad_(False)

    def compute_clip_scores(imgs, text, return_clip_features=False):
        """Computes CLIP scores for a batch of images and a given text prompt."""
        img_tensor = torch.stack([CLIP_PREPROCESS(img) for img in imgs]).to(DEVICE)
        tokenized_text = clip.tokenize([text]).to(DEVICE)
        img_logits, _text_logits = CLIP_MODEL(img_tensor, tokenized_text)
        img_logits = img_logits.detach().cpu().numpy().astype(np.float32)[:, 0]
        img_logits = 1 / img_logits * 100
        # Remap the objective from minimizing [0, 10] to maximizing [0, 100].
        img_logits = (10.0 - img_logits) * 10.0

        if return_clip_features:
            clip_features = CLIP_MODEL.encode_image(img_tensor).to(TORCH_DTYPE)
            return img_logits, clip_features
        return img_logits

    DREAMSIM_MODEL, DREAMSIM_PREPROCESS = dreamsim(
        pretrained=True, dreamsim_type="open_clip_vitb32", device=DEVICE
    )

    def evaluate_lsi(
        latents,
        prompt,
        return_features=False,
        diversity_model=None,
    ):
        """Evaluates the objective of LSI for a batch of latents and a given text prompt."""
        images = SDPIPE(
            prompt,
            num_images_per_prompt=latents.shape[0],
            latents=latents,
            # num_inference_steps=1,  # For testing.
        ).images

        objs, clip_features = compute_clip_scores(
            images,
            prompt,
            return_clip_features=True,
        )

        images = torch.cat([DREAMSIM_PREPROCESS(img) for img in images]).to(DEVICE)
        dreamsim_features = DREAMSIM_MODEL.embed(images)

        if diversity_model is not None:
            measures = compute_diversity_measures(clip_features, diversity_model)
        else:
            measures = None

        if return_features:
            return objs, measures, clip_features, dreamsim_features
        return objs, measures

    update_schedule = [1, 21, 51, 101]  # Iterations on which to update the archive.
    n_pref_data = 1000  # Number of preferences used in each update.

    archive = None
    best = 0.0
    for itr in trange(1, TOTAL_ITRS + 1):
        # Update archive and scheduler if needed.
        if itr in update_schedule:
            if archive is None:
                tqdm.write("Initializing archive and diversity projection.")

                all_sols = []
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []

                # Sample random solutions and get judgment on similarity.
                n_batches = INIT_POP // BATCH_SIZE
                for _ in range(n_batches):
                    sols = torch.randn(SD_IN_SHAPE, device=DEVICE, dtype=TORCH_DTYPE)
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_sols.append(sols)
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_sols = torch.concat(all_sols, dim=0)
                all_clip_features = torch.concat(all_clip_features, dim=0)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)

                # Initialize the diversity projection model.
                div_proj_data = []
                div_proj_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    div_proj_data.append(all_clip_features[idx])
                    div_proj_labels.append(all_dreamsim_features[idx])
                div_proj_data = torch.concat(div_proj_data, dim=0)
                div_proj_labels = torch.concat(div_proj_labels, dim=0)
                div_proj_data = div_proj_data.reshape(n_pref_data, 3, -1)
                div_proj_labels = div_proj_labels.reshape(n_pref_data, 3, -1)
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_labels,
                    latent_dim=2,
                )
            else:
                tqdm.write("Updating archive and diversity projection.")

                # Get all the current solutions and collect feedback.
                all_sols = list_to_tensor(archive.data("solution"))
                n_batches = np.ceil(len(all_sols) / BATCH_SIZE).astype(int)
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []
                for i in range(n_batches):
                    sols = all_sols[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_clip_features = torch.concat(
                    all_clip_features, dim=0
                )  # (num_solutions, feature_dim)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)

                # Update the diversity projection model.
                additional_features = []
                additional_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    additional_features.append(all_clip_features[idx])
                    additional_labels.append(all_dreamsim_features[idx])
                additional_features = torch.concat(additional_features, dim=0)
                additional_labels = torch.concat(additional_labels, dim=0)
                additional_div_proj_data = additional_features.reshape(
                    n_pref_data, 3, -1
                )
                additional_div_proj_labels = additional_labels.reshape(
                    n_pref_data, 3, -1
                )
                div_proj_data = torch.concat(
                    (div_proj_data, additional_div_proj_data), dim=0
                )
                div_proj_labels = torch.concat(
                    (div_proj_labels, additional_div_proj_labels), dim=0
                )
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_labels,
                    latent_dim=2,
                )

            archive, scheduler = create_scheduler(
                all_sols,
                all_objs,
                all_clip_features,
                diversity_model,
                seed=SEED,
            )

        # Primary QD loop.
        sols = scheduler.ask()
        sols = list_to_tensor(sols)
        objs, measures, clip_features, dreamsim_features = evaluate_lsi(
            sols, prompt, return_features=True, diversity_model=diversity_model
        )
        best = max(best, max(objs))
        scheduler.tell(objs, measures)

        # This can be used as a flag to save on the final iteration, but note
        # that we do not save results in this tutorial.
        final_itr = itr == TOTAL_ITRS

        # Update the summary statistics for the archive.
        qd_score, coverage = archive.stats.norm_qd_score, archive.stats.coverage
        tqdm.write(f"QD score: {np.round(qd_score, 2)} Coverage: {coverage * 100:.2f}%")

        heatmap = plot_archive(archive)
        yield archive, heatmap

    # Note: in a generator, this return value only surfaces as the value of
    # the StopIteration raised when the generator is exhausted.
    heatmap = plot_archive(archive)
    return archive, heatmap
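
# A tiny self-contained demo of plot_archive (illustrative only; the toy
# archive, its bounds, and the single zero solution are arbitrary choices,
# not part of the QDHF pipeline).
_toy_archive = GridArchive(
    solution_dim=2, dims=GRID_SIZE, ranges=[(-1.0, 1.0), (-1.0, 1.0)], seed=SEED
)
_toy_archive.add(np.zeros((1, 2)), np.array([50.0]), np.array([[0.0, 0.0]]))
plot_archive(_toy_archive)
plt.close()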
def many_pictures(archive, prompt: str):
    # Modify this to determine how many images to plot along each dimension.
    img_freq = (
        4,  # Number of columns of images.
        4,  # Number of rows of images.
    )

    # List of images.
    imgs = []

    # Convert archive to a df with solutions available.
    df = archive.data(return_type="pandas")

    # Compute the min and max measures for which solutions were found.
    measure_bounds = np.array(
        [
            (df["measures_0"].min(), df["measures_0"].max()),
            (df["measures_1"].min(), df["measures_1"].max()),
        ]
    )

    archive_bounds = np.array(
        [archive.boundaries[0][[0, -1]], archive.boundaries[1][[0, -1]]]
    )

    delta_measures_0 = (archive_bounds[0][1] - archive_bounds[0][0]) / img_freq[0]
    delta_measures_1 = (archive_bounds[1][1] - archive_bounds[1][0]) / img_freq[1]

    for col, row in itertools.product(range(img_freq[1]), range(img_freq[0])):
        # Compute bounds of a box in measure space.
        measures_0_low = archive_bounds[0][0] + delta_measures_0 * row
        measures_0_high = archive_bounds[0][0] + delta_measures_0 * (row + 1)
        measures_1_low = archive_bounds[1][0] + delta_measures_1 * col
        measures_1_high = archive_bounds[1][0] + delta_measures_1 * (col + 1)

        if row == 0:
            measures_0_low = measure_bounds[0][0]
        if col == 0:
            measures_1_low = measure_bounds[1][0]
        if row == img_freq[0] - 1:
            measures_0_high = measure_bounds[0][1]
        if col == img_freq[1] - 1:
            measures_1_high = measure_bounds[1][1]

        # Query for a solution with measures within this box.
        query_string = (
            f"{measures_0_low} <= measures_0 & measures_0 <= {measures_0_high} & "
            f"{measures_1_low} <= measures_1 & measures_1 <= {measures_1_high}"
        )
        df_box = df.query(query_string)

        if not df_box.empty:
            # Randomly sample a solution from the box. Stable Diffusion
            # solutions have SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH
            # dimensions, so the final solution column is solution_(x-1).
            sol = (
                df_box.loc[
                    :,
                    "solution_0" : "solution_{}".format(
                        SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH - 1
                    ),
                ]
                .sample(n=1)
                .iloc[0]
            )

            # Convert the latent vector solution to an image.
            latents = torch.tensor(sol.to_numpy()).reshape(
                (1, SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH)
            )
            latents = latents.to(TORCH_DTYPE).to(DEVICE)
            img = SDPIPE(
                prompt,
                num_images_per_prompt=1,
                latents=latents,
                # num_inference_steps=1,  # For testing.
            ).images[0]
            img = torch.from_numpy(np.array(img)).permute(2, 0, 1) / 255.0
            imgs.append(img)
        else:
            imgs.append(torch.zeros((3, IMG_HEIGHT, IMG_WIDTH)))

    def create_archive_tick_labels(measure_range, num_ticks):
        delta = (measure_range[1] - measure_range[0]) / num_ticks
        ticklabels = [
            round(delta * p + measure_range[0], 3) for p in range(num_ticks + 1)
        ]
        return ticklabels

    plt.figure(figsize=(img_freq[0] * 2, img_freq[1] * 2))
    img_grid = make_grid(imgs, nrow=img_freq[0], padding=0)
    img_grid = np.transpose(img_grid.cpu().numpy(), (1, 2, 0))
    plt.imshow(img_grid)

    plt.xlabel("")
    num_x_ticks = img_freq[0]
    x_ticklabels = create_archive_tick_labels(measure_bounds[0], num_x_ticks)
    x_tick_range = img_grid.shape[1]
    x_ticks = np.arange(0, x_tick_range + 1e-9, step=x_tick_range / num_x_ticks)
    plt.xticks(x_ticks, x_ticklabels)

    plt.ylabel("")
    num_y_ticks = img_freq[1]
    y_ticklabels = create_archive_tick_labels(measure_bounds[1], num_y_ticks)
    y_ticklabels.reverse()
    y_tick_range = img_grid.shape[0]
    y_ticks = np.arange(0, y_tick_range + 1e-9, step=y_tick_range / num_y_ticks)
    plt.yticks(y_ticks, y_ticklabels)

    plt.tight_layout()
    return plt
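

# End-to-end usage sketch (assumptions: the prompt and output file name below
# are placeholders, and a full run at these settings is compute-intensive and
# expects a CUDA GPU). run_qdhf is a generator, so each iteration yields the
# current archive and a heatmap plot.
if __name__ == "__main__":
    PROMPT = "an image of a bear in a national park"  # Placeholder prompt.
    archive = None
    for archive, heatmap in run_qdhf(PROMPT, init_pop=200, total_itrs=200):
        heatmap.close("all")  # Swap in heatmap.savefig(...) to keep per-iteration heatmaps.
    if archive is not None:
        many_pictures(archive, PROMPT).savefig("qdhf_image_grid.png")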