from typing import Optional

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from datasets import load_dataset
from diffusers import DDPMPipeline
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers import DiffusionPipeline
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoTokenizer
class RCTDiffusionPipeline(DiffusionPipeline):
    """DDPM pipeline conditioned on weighted object-description and color labels.

    Each generated sample is a 12-channel 256x256 tensor that is split into
    four 3-channel views, each returned as an RGB PIL image.

    The label vocabularies are built at construction time from the
    ``frutiemax/rct_dataset`` dataset. The model is placed on CUDA in
    float16, so a CUDA device is required.
    """

    # Geometry of one generated sample: four RGB views stacked on the channel axis.
    _NUM_VIEWS = 4
    _VIEW_CHANNELS = 3
    _IMAGE_SIZE = 256

    def __init__(self):
        """Build the label vocabularies, the scheduler and the conditional UNet."""
        super().__init__()
        self.object_description_dict = {}
        self.color1_dict = {}
        self.color2_dict = {}
        self.color3_dict = {}
        self.load_dictionaries_from_dataset()

        self.scheduler = DDPMScheduler()

        # The conditioning passed as `encoder_hidden_states` is the packed label
        # vector (see `pack_labels_to_tensor`), so the cross-attention input
        # width has to equal the total number of labels.
        hidden_dim = self.get_class_labels_size()
        self.unet = UNet2DConditionModel(
            sample_size=self._IMAGE_SIZE,
            in_channels=self._NUM_VIEWS * self._VIEW_CHANNELS,
            out_channels=self._NUM_VIEWS * self._VIEW_CHANNELS,
            down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),
            up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'),
            cross_attention_dim=hidden_dim,
            block_out_channels=(64, 128, 256),
            norm_num_groups=32,
        )
        self.unet.to(device='cuda', dtype=torch.float16)

    def load_dictionaries_from_dataset(self):
        """Fill the four label vocabularies from the dataset's ``train`` split.

        Every distinct value is assigned the next free index in order of first
        appearance. For the three color columns the value ``'none'`` is
        skipped; object descriptions are not filtered.
        """
        dataset = load_dataset('frutiemax/rct_dataset')['train']
        color_vocabularies = (
            ('color1', self.color1_dict),
            ('color2', self.color2_dict),
            ('color3', self.color3_dict),
        )
        for row in dataset:
            description = row['object_description']
            if description not in self.object_description_dict:
                self.object_description_dict[description] = len(self.object_description_dict)
            for column, vocabulary in color_vocabularies:
                color = row[column]
                if color != 'none' and color not in vocabulary:
                    vocabulary[color] = len(vocabulary)

    def print_class_tokens_to_csv(self):
        """Write each vocabulary as (label, index) rows to a CSV file in the working directory."""
        exports = (
            (self.object_description_dict, 'object_descriptions_tokens.csv'),
            (self.color1_dict, 'color1_tokens.csv'),
            (self.color2_dict, 'color2_tokens.csv'),
            (self.color3_dict, 'color3_tokens.csv'),
        )
        for vocabulary, filename in exports:
            pd.DataFrame(vocabulary.items()).to_csv(filename)

    @staticmethod
    def _weights_from_classifiers(vocabulary: dict, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Return a dense weight vector over *vocabulary* for (label, weight) pairs.

        Labels that are not in the vocabulary are ignored. If a label occurs
        more than once, the last weight wins.
        """
        weights = np.zeros(len(vocabulary))
        for label, weight in classifiers:
            index = vocabulary.get(label)
            if index is not None:
                weights[index] = weight
        return weights

    def get_object_description_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Return the weight vector over the object-description vocabulary."""
        return self._weights_from_classifiers(self.object_description_dict, classifiers)

    def get_color1_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Return the weight vector over the ``color1`` vocabulary."""
        return self._weights_from_classifiers(self.color1_dict, classifiers)

    def get_color2_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Return the weight vector over the ``color2`` vocabulary."""
        return self._weights_from_classifiers(self.color2_dict, classifiers)

    def get_color3_weights(self, classifiers: list[tuple[str, float]]) -> np.ndarray:
        """Return the weight vector over the ``color3`` vocabulary."""
        return self._weights_from_classifiers(self.color3_dict, classifiers)

    def get_class_labels_size(self):
        """Return the total number of labels across the four vocabularies."""
        return (
            len(self.object_description_dict)
            + len(self.color1_dict)
            + len(self.color2_dict)
            + len(self.color3_dict)
        )

    def pack_labels_to_tensor(self, num_images, object_descriptions, colors1, colors2, colors3) -> torch.Tensor:
        """Concatenate the per-sample weight vectors into one conditioning tensor.

        Args:
            num_images: Number of samples to pack.
            object_descriptions: Indexable by sample; each entry is a NumPy
                weight vector over the object-description vocabulary.
            colors1: Same, over the ``color1`` vocabulary.
            colors2: Same, over the ``color2`` vocabulary.
            colors3: Same, over the ``color3`` vocabulary.

        Returns:
            A tensor of shape ``(num_images, 1, get_class_labels_size())`` in
            the default torch dtype, laid out per sample as
            ``[object descriptions | color1 | color2 | color3]``.
        """
        segments = (
            (object_descriptions, len(self.object_description_dict)),
            (colors1, len(self.color1_dict)),
            (colors2, len(self.color2_dict)),
            (colors3, len(self.color3_dict)),
        )
        num_labels = self.get_class_labels_size()
        # zeros rather than an uninitialised tensor, so no slot can ever hold garbage.
        class_labels = torch.zeros((num_images, num_labels))
        for batch_index in range(num_images):
            offset = 0
            for weights, width in segments:
                class_labels[batch_index, offset:offset + width] = torch.from_numpy(weights[batch_index])
                offset += width
        return class_labels.reshape(num_images, 1, num_labels)

    def __call__(
        self,
        object_description: list[list[tuple[str, float]]],
        color1: list[list[tuple[str, float]]],
        color2: Optional[list[list[tuple[str, float]]]] = None,
        color3: Optional[list[list[tuple[str, float]]]] = None,
        batch_size=1,
        num_inference_steps=20,
        generator: Optional[torch.Generator] = None,
    ) -> Optional[list]:
        """Generate images for a batch of weighted label prompts.

        Args:
            object_description: One list of ``(label, weight)`` pairs per sample.
            color1: One list of ``(label, weight)`` pairs per sample.
            color2: Optional, same layout as ``color1``; ``None`` means all-zero
                weights for every sample.
            color3: Optional, same layout as ``color1``; ``None`` means all-zero
                weights for every sample.
            batch_size: Number of samples; every provided prompt list must have
                exactly this length.
            num_inference_steps: Number of scheduler timesteps.
            generator: Optional RNG used for the initial noise and for the
                scheduler's sampling noise. ``None`` uses torch's global RNG.

        Returns:
            A flat list of ``batch_size * 4`` RGB PIL images (the four views of
            sample 0, then of sample 1, ...), or ``None`` when a prompt list's
            length does not match ``batch_size``.
        """
        # Mismatched inputs are reported by returning None (kept for callers
        # that rely on this instead of an exception).
        for prompts in (object_description, color1):
            if len(prompts) != batch_size:
                return None
        for prompts in (color2, color3):
            if prompts is not None and len(prompts) != batch_size:
                return None

        no_labels = [[] for _ in range(batch_size)]
        color2 = no_labels if color2 is None else color2
        color3 = no_labels if color3 is None else color3

        object_descriptions = [self.get_object_description_weights(p) for p in object_description]
        colors1 = [self.get_color1_weights(p) for p in color1]
        colors2 = [self.get_color2_weights(p) for p in color2]
        colors3 = [self.get_color3_weights(p) for p in color3]

        class_labels = self.pack_labels_to_tensor(
            batch_size, object_descriptions, colors1, colors2, colors3
        ).to(device='cuda', dtype=torch.float16)

        self.scheduler.set_timesteps(num_inference_steps)

        # One latent of shape (1, 12, H, W) per sample; the leading 1 is the
        # batch axis the UNet expects when samples are denoised one at a time.
        # Noise is drawn on the generator's own device (CPU by default) and
        # then moved, so both CPU and CUDA generators are accepted.
        channels = self._NUM_VIEWS * self._VIEW_CHANNELS
        noise_device = generator.device if generator is not None else 'cpu'
        latents = torch.randn(
            (batch_size, 1, channels, self._IMAGE_SIZE, self._IMAGE_SIZE),
            generator=generator,
            device=noise_device,
        ).to(device='cuda', dtype=torch.float16)

        with torch.no_grad(), tqdm(total=num_inference_steps) as progress_bar:
            for step_index, t in enumerate(self.scheduler.timesteps):
                progress_bar.set_description(f'Inference step {step_index}')
                for batch_index in range(batch_size):
                    sample = latents[batch_index]
                    # Condition each sample on its own labels only; the batch
                    # axis of the conditioning must match the sample's (1).
                    labels = class_labels[batch_index:batch_index + 1]
                    noise_residual = self.unet(sample, t, encoder_hidden_states=labels).sample
                    latents[batch_index] = self.scheduler.step(
                        noise_residual, t, sample, generator=generator
                    ).prev_sample
                progress_bar.update(1)

        # ToPILImage scales float tensors by 255 and casts to uint8, so values
        # outside [0, 1] would wrap around; clamp first (on the GPU).
        # NOTE(review): assumes the model was trained on images in [0, 1]; if
        # training used [-1, 1], rescale with (x / 2 + 0.5) before clamping.
        views = latents.clamp(0, 1).reshape(
            batch_size, self._NUM_VIEWS, self._VIEW_CHANNELS, self._IMAGE_SIZE, self._IMAGE_SIZE
        ).to('cpu')

        tensor_to_pil = T.ToPILImage('RGB')
        return [
            tensor_to_pil(views[batch_index, view_index])
            for batch_index in range(batch_size)
            for view_index in range(self._NUM_VIEWS)
        ]