init app
- .gitignore +1 -0
- app.py +69 -0
- generate_examples.py +46 -0
- qdhf_things.py +558 -0
- requirements.txt +12 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__/
app.py
ADDED
@@ -0,0 +1,69 @@
import gradio as gr
from qdhf_things import run_qdhf, many_pictures
from generate_examples import EXAMPLE_PROMPTS
import os
import io

# Get the absolute path to the examples directory
EXAMPLES_DIR = os.path.abspath("./examples")

def generate_images(prompt, init_pop, total_itrs):
    init_pop = int(init_pop)
    total_itrs = int(total_itrs)

    # Use placeholder if prompt is empty
    if not prompt.strip():
        prompt = "a duck crossing the street"

    archive_plots = []
    for archive, plt_fig in run_qdhf(prompt, init_pop, total_itrs):
        buf = io.BytesIO()
        plt_fig.savefig(buf, format='png')
        buf.seek(0)
        archive_plots.append(buf.getvalue())

    final_archive_plot = archive_plots[-1]
    generated_images = many_pictures(archive, prompt)

    # Save the final archive plot and generated images as temporary files
    temp_archive_file = "temp_archive_plot.png"
    temp_images_file = "temp_generated_images.png"

    with open(temp_archive_file, 'wb') as f:
        f.write(final_archive_plot)

    generated_images.savefig(temp_images_file)

    return temp_archive_file, temp_images_file

def show_example(example):
    index = EXAMPLE_PROMPTS.index(example)
    archive_plot_path = os.path.join(EXAMPLES_DIR, f"archive_{index}.mp4")
    images_path = os.path.join(EXAMPLES_DIR, f"archive_pics_{index}.png")
    return archive_plot_path, images_path
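# Gradio UI: the left column collects the prompt and the QD settings, the right
# column shows the archive plot (a video slot, used for the precomputed examples)
# and a grid of pictures decoded from the archive. The example buttons below load
# precomputed results from ./examples via show_example.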
with gr.Blocks() as demo:
    gr.Markdown("# Quality Diversity through Human Feedback")

    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your prompt here", placeholder="a duck crossing the street")
            init_pop = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Initial Population")
            total_itrs = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Total Iterations")
            generate_button = gr.Button("Generate", variant="primary")

        with gr.Column(scale=2):
            archive_output = gr.Video(label="Archive Plot")
            images_output = gr.Image(label="Generated Pictures")

    generate_button.click(generate_images,
                          inputs=[prompt_input, init_pop, total_itrs],
                          outputs=[archive_output, images_output])

    gr.Markdown("## Examples:")
    for example in EXAMPLE_PROMPTS:
        example_button = gr.Button(example)
        example_button.click(show_example, inputs=example_button, outputs=[archive_output, images_output])

if __name__ == "__main__":
    demo.launch()
generate_examples.py
ADDED
@@ -0,0 +1,46 @@
import os
import pickle
import imageio
import matplotlib.pyplot as plt
from qdhf_things import run_qdhf, many_pictures

EXAMPLE_PROMPTS = [
    'an image of a cat on the sofa',
    'an image of a bear in a national park',
    'a photo of an astronaut riding a horse on mars',
    'a drawing of a tree behind a fence',
    'a painting of a sunset over the ocean',
    'a sketch of a racoon sitting on a mushroom',
    'a picture of a dragon flying over a castle',
    'a photo of a robot playing the guitar',
]

if __name__ == '__main__':
    print('Hello! I am a script!')

    for i, example_prompt in enumerate(EXAMPLE_PROMPTS):
        # Initialize list to store images for GIF
        images = []

        # Run QDHF
        for archive, plt in run_qdhf(example_prompt):
            # Save current plot to a temporary file
            temp_filename = f'./examples/temp_plot_{i}.png'
            plt.savefig(temp_filename)
            plt.close()

            # Read the saved image and append to images list
            images.append(imageio.imread(temp_filename))
            os.remove(temp_filename)

        # Create a GIF from the images
        gif_filename = f'./examples/archive_{i}.gif'
        imageio.mimsave(gif_filename, images, duration=0.5)  # Adjust duration as needed
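        # Note: app.py's show_example looks for ./examples/archive_{index}.mp4,
        # so the GIFs written here would need converting to MP4 (or show_example
        # updating) before the example videos load in the UI.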

        # Save archive with pickle
        pickle.dump(archive, open(f'./examples/archive_{i}.pkl', 'wb'))

        # Save the final archive plot
        plt = many_pictures(archive, example_prompt)
        plt.savefig(f'./examples/archive_pics_{i}.png')
        plt.close()
qdhf_things.py
ADDED
@@ -0,0 +1,558 @@
import matplotlib.pyplot as plt
import pydantic
import time
import numpy as np
from tqdm import tqdm, trange
import torch
from torch import nn
from diffusers import StableDiffusionPipeline
import clip
from dreamsim import dreamsim
from ribs.archives import GridArchive
from ribs.schedulers import Scheduler
from ribs.emitters import GaussianEmitter
import itertools
from ribs.visualize import grid_archive_heatmap


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
print("Torch device:", DEVICE)

# Use float16 for GPU, float32 for CPU.
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
print("Torch dtype:", TORCH_DTYPE)
IMG_WIDTH = 256
IMG_HEIGHT = 256
SD_IN_HEIGHT = 32
SD_IN_WIDTH = 32
SD_CHECKPOINT = "lambdalabs/miniSD-diffusers"

BATCH_SIZE = 4
SD_IN_CHANNELS = 4
SD_IN_SHAPE = (
    BATCH_SIZE,
    SD_IN_CHANNELS,
    SD_IN_HEIGHT,
    SD_IN_WIDTH,
)

SDPIPE = StableDiffusionPipeline.from_pretrained(
    SD_CHECKPOINT,
    torch_dtype=TORCH_DTYPE,
    safety_checker=None,  # For faster inference.
    requires_safety_checker=False,
)

SDPIPE.set_progress_bar_config(disable=True)
SDPIPE = SDPIPE.to(DEVICE)

GRID_SIZE = (20, 20)
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)

# INIT_POP = 200  # Initial population.
# TOTAL_ITRS = 200  # Total number of iterations.

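# DivProj is the learned diversity projection for QDHF: a single linear layer
# that maps CLIP image features to a low-dimensional (here 2-D) measure space,
# trained so that squared L2 distances in that space agree with the preference
# triplets used below.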
class DivProj(nn.Module):
    def __init__(self, input_dim, latent_dim=2):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=latent_dim),
        )

    def forward(self, x):
        """Get diversity representations."""
        x = self.proj(x)
        return x

    def calc_dis(self, x1, x2):
        """Calculate diversity distance as (squared) L2 distance."""
        x1 = self.forward(x1)
        x2 = self.forward(x2)
        return torch.sum(torch.square(x1 - x2), -1)

    def triplet_delta_dis(self, ref, x1, x2):
        """Calculate delta distance comparing x1 and x2 to ref."""
        x1 = self.forward(x1)
        x2 = self.forward(x2)
        ref = self.forward(ref)
        return (torch.sum(torch.square(ref - x1), -1) -
                torch.sum(torch.square(ref - x2), -1))


# Triplet loss with margin 0.05.
# The binary preference labels are scaled to y = 1 or -1 for the loss, where y = 1 means x2 is more similar to ref than x1.
loss_fn = lambda y, delta_dis: torch.max(
    torch.tensor([0.0]).to(DEVICE), 0.05 - (y * 2 - 1) * delta_dis
).mean()
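# Worked example of the hinge above: with y = 1 (x2 judged closer to ref than x1)
# the loss is max(0, 0.05 - delta_dis), so it is zero once the model places x1 at
# least 0.05 farther from ref than x2, and grows linearly otherwise. The labels
# here are not collected from people; fit_div_proj derives them from DreamSim
# cosine similarities as a stand-in for human feedback.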
def fit_div_proj(inputs, dreamsim_features, latent_dim, batch_size=32):
    """Trains the DivProj model on ground-truth labels."""
    t = time.time()
    model = DivProj(input_dim=inputs.shape[-1], latent_dim=latent_dim)
    model.to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    n_pref_data = inputs.shape[0]
    ref = inputs[:, 0]
    x1 = inputs[:, 1]
    x2 = inputs[:, 2]

    n_train = int(n_pref_data * 0.75)
    n_val = n_pref_data - n_train

    # Split data into train and val.
    ref_train = ref[:n_train]
    x1_train = x1[:n_train]
    x2_train = x2[:n_train]
    ref_val = ref[n_train:]
    x1_val = x1[n_train:]
    x2_val = x2[n_train:]

    # Split DreamSim features into train and val.
    ref_dreamsim_features = dreamsim_features[:, 0]
    x1_dreamsim_features = dreamsim_features[:, 1]
    x2_dreamsim_features = dreamsim_features[:, 2]
    ref_gt_train = ref_dreamsim_features[:n_train]
    x1_gt_train = x1_dreamsim_features[:n_train]
    x2_gt_train = x2_dreamsim_features[:n_train]
    ref_gt_val = ref_dreamsim_features[n_train:]
    x1_gt_val = x1_dreamsim_features[n_train:]
    x2_gt_val = x2_dreamsim_features[n_train:]

    val_acc = []
    n_iters_per_epoch = max(n_train // batch_size, 1)
    for epoch in range(200):
        for _ in range(n_iters_per_epoch):
            optimizer.zero_grad()

            idx = np.random.choice(n_train, batch_size)
            batch_ref = ref_train[idx].float()
            batch1 = x1_train[idx].float()
            batch2 = x2_train[idx].float()

            # Get delta distance from model.
            delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)

            # Get preference labels from DreamSim features.
            gt_dis = torch.nn.functional.cosine_similarity(
                ref_gt_train[idx], x2_gt_train[idx], dim=-1
            ) - torch.nn.functional.cosine_similarity(
                ref_gt_train[idx], x1_gt_train[idx], dim=-1
            )
            gt = (gt_dis > 0).to(TORCH_DTYPE)  # y = 1 when x2 is more similar to ref than x1.

            loss = loss_fn(gt, delta_dis)
            loss.backward()
            optimizer.step()

        # Validate.
        n_correct = 0
        n_total = 0
        with torch.no_grad():
            idx = np.arange(n_val)
            batch_ref = ref_val[idx].float()
            batch1 = x1_val[idx].float()
            batch2 = x2_val[idx].float()
            delta_dis = model.triplet_delta_dis(batch_ref, batch1, batch2)
            pred = delta_dis > 0
            gt_dis = torch.nn.functional.cosine_similarity(
                ref_gt_val[idx], x2_gt_val[idx], dim=-1
            ) - torch.nn.functional.cosine_similarity(
                ref_gt_val[idx], x1_gt_val[idx], dim=-1
            )
            gt = gt_dis > 0
            n_correct += (pred == gt).sum().item()
            n_total += len(idx)

        acc = n_correct / n_total
        val_acc.append(acc)

        # Early stopping if val_acc does not improve for 10 epochs.
        if epoch > 10 and np.mean(val_acc[-10:]) < np.mean(val_acc[-11:-1]):
            break

    print(
        f"{np.round(time.time() - t, 1)}s ({epoch + 1} epochs) | DivProj (n={n_pref_data}) fitted with val acc.: {acc}"
    )

    return model.to(TORCH_DTYPE), acc


def compute_diversity_measures(clip_features, diversity_model):
    with torch.no_grad():
        measures = diversity_model(clip_features).detach().cpu().numpy()
    return measures


def tensor_to_list(tensor):
    sols = tensor.detach().cpu().numpy().astype(np.float32)
    return sols.reshape(sols.shape[0], -1)


def list_to_tensor(list_):
    sols = np.array(list_).reshape(
        len(list_), 4, SD_IN_HEIGHT, SD_IN_WIDTH
    )  # Hard-coded for now.
    return torch.tensor(sols, dtype=TORCH_DTYPE, device=DEVICE)
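# create_scheduler builds a fresh pyribs setup: archive bounds come from the
# 1st/99th percentiles of the current measures, the evaluated solutions are
# re-added to a new GridArchive, and a GaussianEmitter (sigma=0.1) seeded from
# the archive's elites proposes new latents.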
def create_scheduler(
    sols,
    objs,
    clip_features,
    diversity_model,
    seed=None,
):
    measures = compute_diversity_measures(clip_features, diversity_model)
    archive_bounds = np.array(
        [np.quantile(measures, 0.01, axis=0), np.quantile(measures, 0.99, axis=0)]
    ).T

    sols = tensor_to_list(sols)

    # Set up archive.
    archive = GridArchive(
        solution_dim=len(sols[0]), dims=GRID_SIZE, ranges=archive_bounds, seed=SEED
    )

    # Add initial solutions to the archive.
    archive.add(sols, objs, measures)

    # Set up the GaussianEmitter.
    emitters = [
        GaussianEmitter(
            archive=archive,
            sigma=0.1,
            initial_solutions=archive.sample_elites(BATCH_SIZE)["solution"],
            batch_size=BATCH_SIZE,
            seed=SEED,
        )
    ]

    # Return the archive and scheduler.
    return archive, Scheduler(archive, emitters)


def plot_archive(archive):
    plt.figure(figsize=(6, 4.5))
    grid_archive_heatmap(archive, vmin=0, vmax=100)
    plt.xlabel("Diversity Metric 1")
    plt.ylabel("Diversity Metric 2")
    return plt
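# run_qdhf is a generator that drives the whole QDHF loop for one prompt: it
# loads CLIP (objective) and DreamSim (simulated feedback), seeds the search
# with random Stable Diffusion latents, fits and later refits the diversity
# projection and rebuilds the archive at the iterations in update_schedule,
# then runs the ask/evaluate/tell loop, yielding (archive, heatmap plot) every
# iteration.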
def run_qdhf(prompt: str, init_pop: int = 200, total_itrs: int = 200):
    INIT_POP = init_pop
    TOTAL_ITRS = total_itrs

    # This uses the ViT-B/32 CLIP checkpoint; other checkpoints work depending on your resources and needs.
    CLIP_MODEL, CLIP_PREPROCESS = clip.load("ViT-B/32", device=DEVICE)
    CLIP_MODEL.eval()
    for p in CLIP_MODEL.parameters():
        p.requires_grad_(False)

    def compute_clip_scores(imgs, text, return_clip_features=False):
        """Computes CLIP scores for a batch of images and a given text prompt."""
        img_tensor = torch.stack([CLIP_PREPROCESS(img) for img in imgs]).to(DEVICE)
        tokenized_text = clip.tokenize([text]).to(DEVICE)
        img_logits, _text_logits = CLIP_MODEL(img_tensor, tokenized_text)
        img_logits = img_logits.detach().cpu().numpy().astype(np.float32)[:, 0]
        img_logits = 1 / img_logits * 100
        # Remap the objective from minimizing [0, 10] to maximizing [0, 100].
        img_logits = (10.0 - img_logits) * 10.0

        if return_clip_features:
            clip_features = CLIP_MODEL.encode_image(img_tensor).to(TORCH_DTYPE)
            return img_logits, clip_features
        else:
            return img_logits

    DREAMSIM_MODEL, DREAMSIM_PREPROCESS = dreamsim(
        pretrained=True, dreamsim_type="open_clip_vitb32", device=DEVICE
    )

    def evaluate_lsi(
        latents,
        prompt,
        return_features=False,
        diversity_model=None,
    ):
        """Evaluates the objective of LSI for a batch of latents and a given text prompt."""

        images = SDPIPE(
            prompt,
            num_images_per_prompt=latents.shape[0],
            latents=latents,
            # num_inference_steps=1,  # For testing.
        ).images

        objs, clip_features = compute_clip_scores(
            images,
            prompt,
            return_clip_features=True,
        )

        images = torch.cat([DREAMSIM_PREPROCESS(img) for img in images]).to(DEVICE)
        dreamsim_features = DREAMSIM_MODEL.embed(images)

        if diversity_model is not None:
            measures = compute_diversity_measures(clip_features, diversity_model)
        else:
            measures = None

        if return_features:
            return objs, measures, clip_features, dreamsim_features
        else:
            return objs, measures

    update_schedule = [1, 21, 51, 101]  # Iterations on which to update the archive.
    n_pref_data = 1000  # Number of preferences used in each update.

    archive = None

    best = 0.0
    for itr in trange(1, TOTAL_ITRS + 1):
        # Update archive and scheduler if needed.
        if itr in update_schedule:
            if archive is None:
                tqdm.write("Initializing archive and diversity projection.")

                all_sols = []
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []

                # Sample random solutions and get judgment on similarity.
                n_batches = INIT_POP // BATCH_SIZE
                for _ in range(n_batches):
                    sols = torch.randn(SD_IN_SHAPE, device=DEVICE, dtype=TORCH_DTYPE)
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_sols.append(sols)
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_sols = torch.concat(all_sols, dim=0)
                all_clip_features = torch.concat(all_clip_features, dim=0)
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)

                # Initialize the diversity projection model.
                div_proj_data = []
                div_proj_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    div_proj_data.append(all_clip_features[idx])
                    div_proj_labels.append(all_dreamsim_features[idx])
                div_proj_data = torch.concat(div_proj_data, dim=0)
                div_proj_labels = torch.concat(div_proj_labels, dim=0)
                div_proj_data = div_proj_data.reshape(n_pref_data, 3, -1)
                div_proj_label = div_proj_labels.reshape(n_pref_data, 3, -1)
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_label,
                    latent_dim=2,
                )

            else:
                tqdm.write("Updating archive and diversity projection.")

                # Get all the current solutions and collect feedback.
                all_sols = list_to_tensor(archive.data("solution"))
                n_batches = np.ceil(len(all_sols) / BATCH_SIZE).astype(int)
                all_clip_features = []
                all_dreamsim_features = []
                all_objs = []
                for i in range(n_batches):
                    sols = all_sols[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
                    objs, _, clip_features, dreamsim_features = evaluate_lsi(
                        sols, prompt, return_features=True
                    )
                    all_clip_features.append(clip_features)
                    all_dreamsim_features.append(dreamsim_features)
                    all_objs.append(objs)
                all_clip_features = torch.concat(
                    all_clip_features, dim=0
                )  # n_pref_data * 3, dim
                all_dreamsim_features = torch.concat(all_dreamsim_features, dim=0)
                all_objs = np.concatenate(all_objs, axis=0)

                # Update the diversity projection model.
                additional_features = []
                additional_labels = []
                for _ in range(n_pref_data):
                    idx = np.random.choice(all_sols.shape[0], 3)
                    additional_features.append(all_clip_features[idx])
                    additional_labels.append(all_dreamsim_features[idx])
                additional_features = torch.concat(additional_features, dim=0)
                additional_labels = torch.concat(additional_labels, dim=0)
                additional_div_proj_data = additional_features.reshape(n_pref_data, 3, -1)
                additional_div_proj_label = additional_labels.reshape(n_pref_data, 3, -1)
                div_proj_data = torch.concat(
                    (div_proj_data, additional_div_proj_data), axis=0
                )
                div_proj_label = torch.concat(
                    (div_proj_label, additional_div_proj_label), axis=0
                )
                diversity_model, div_proj_acc = fit_div_proj(
                    div_proj_data,
                    div_proj_label,
                    latent_dim=2,
                )

            archive, scheduler = create_scheduler(
                all_sols,
                all_objs,
                all_clip_features,
                diversity_model,
                seed=SEED,
            )

        # Primary QD loop.
        sols = scheduler.ask()
        sols = list_to_tensor(sols)
        objs, measures, clip_features, dreamsim_features = evaluate_lsi(
            sols, prompt, return_features=True, diversity_model=diversity_model
        )
        best = max(best, max(objs))
        scheduler.tell(objs, measures)

        # This can be used as a flag to save results on the final iteration
        # (results are not saved here).
        final_itr = itr == TOTAL_ITRS

        # Update the summary statistics for the archive.
        qd_score, coverage = archive.stats.norm_qd_score, archive.stats.coverage

        tqdm.write(f"QD score: {np.round(qd_score, 2)} Coverage: {coverage * 100}")

        plt = plot_archive(archive)
        yield archive, plt

    plt = plot_archive(archive)
    return archive, plt
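# many_pictures renders a grid of images from the finished archive: measure
# space is split into img_freq boxes, one elite latent is sampled per box and
# decoded with Stable Diffusion, empty boxes are left black, and the axis ticks
# show the measure ranges covered.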
def many_pictures(archive, prompt: str):
    # Modify this to determine how many images to plot along each dimension.
    img_freq = (
        4,  # Number of columns of images.
        4,  # Number of rows of images.
    )

    # List of images.
    imgs = []

    # Convert archive to a df with solutions available.
    df = archive.data(return_type="pandas")

    # Compute the min and max measures for which solutions were found.
    measure_bounds = np.array(
        [
            (df["measures_0"].min(), df["measures_0"].max()),
            (df["measures_1"].min(), df["measures_1"].max()),
        ]
    )

    archive_bounds = np.array(
        [archive.boundaries[0][[0, -1]], archive.boundaries[1][[0, -1]]]
    )

    delta_measures_0 = (archive_bounds[0][1] - archive_bounds[0][0]) / img_freq[0]
    delta_measures_1 = (archive_bounds[1][1] - archive_bounds[1][0]) / img_freq[1]

    for col, row in itertools.product(range(img_freq[1]), range(img_freq[0])):
        # Compute bounds of a box in measure space.
        measures_0_low = archive_bounds[0][0] + delta_measures_0 * row
        measures_0_high = archive_bounds[0][0] + delta_measures_0 * (row + 1)
        measures_1_low = archive_bounds[1][0] + delta_measures_1 * col
        measures_1_high = archive_bounds[1][0] + delta_measures_1 * (col + 1)

        if row == 0:
            measures_0_low = measure_bounds[0][0]
        if col == 0:
            measures_1_low = measure_bounds[1][0]
        if row == img_freq[0] - 1:
            measures_0_high = measure_bounds[0][1]
        if col == img_freq[1] - 1:
            measures_1_high = measure_bounds[1][1]

        # Query for a solution with measures within this box.
        query_string = (
            f"{measures_0_low} <= measures_0 & measures_0 <= {measures_0_high} & "
            f"{measures_1_low} <= measures_1 & measures_1 <= {measures_1_high}"
        )
        df_box = df.query(query_string)

        if not df_box.empty:
            # Randomly sample a solution from the box.
            # Stable Diffusion solutions have SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH
            # dimensions, so the final solution col is solution_(x-1).
            sol = (
                df_box.loc[
                    :,
                    "solution_0":"solution_{}".format(
                        SD_IN_CHANNELS * SD_IN_HEIGHT * SD_IN_WIDTH - 1
                    ),
                ]
                .sample(n=1)
                .iloc[0]
            )

            # Convert the latent vector solution to an image.
            latents = torch.tensor(sol.to_numpy()).reshape(
                (1, SD_IN_CHANNELS, SD_IN_HEIGHT, SD_IN_WIDTH)
            )
            latents = latents.to(TORCH_DTYPE).to(DEVICE)
            img = SDPIPE(
                prompt,
                num_images_per_prompt=1,
                latents=latents,
                # num_inference_steps=1,  # For testing.
            ).images[0]

            img = torch.from_numpy(np.array(img)).permute(2, 0, 1) / 255.0
            imgs.append(img)
        else:
            imgs.append(torch.zeros((3, IMG_HEIGHT, IMG_WIDTH)))

    from torchvision.utils import make_grid

    def create_archive_tick_labels(measure_range, num_ticks):
        delta = (measure_range[1] - measure_range[0]) / num_ticks
        ticklabels = [round(delta * p + measure_range[0], 3) for p in range(num_ticks + 1)]
        return ticklabels

    plt.figure(figsize=(img_freq[0] * 2, img_freq[0] * 2))
    img_grid = make_grid(imgs, nrow=img_freq[0], padding=0)
    img_grid = np.transpose(img_grid.cpu().numpy(), (1, 2, 0))
    plt.imshow(img_grid)

    plt.xlabel("")
    num_x_ticks = img_freq[0]
    x_ticklabels = create_archive_tick_labels(measure_bounds[0], num_x_ticks)
    x_tick_range = img_grid.shape[1]
    x_ticks = np.arange(0, x_tick_range + 1e-9, step=x_tick_range / num_x_ticks)
    plt.xticks(x_ticks, x_ticklabels)

    plt.ylabel("")
    num_y_ticks = img_freq[1]
    y_ticklabels = create_archive_tick_labels(measure_bounds[1], num_y_ticks)
    y_ticklabels.reverse()
    y_tick_range = img_grid.shape[0]
    y_ticks = np.arange(0, y_tick_range + 1e-9, step=y_tick_range / num_y_ticks)
    plt.yticks(y_ticks, y_ticklabels)
    plt.tight_layout()

    return plt
requirements.txt
ADDED
@@ -0,0 +1,12 @@
streamlit
matplotlib
imageio
pydantic
torch
diffusers
dreamsim
ribs
ftfy
regex
tqdm
git+https://github.com/openai/CLIP.git