Training code
- rct_diffusion_pipeline.py +198 -8
- test_pipeline.py +6 -5
- train_model.py +128 -0
rct_diffusion_pipeline.py
CHANGED
@@ -1,15 +1,205 @@

The previous 15-line skeleton (the two imports, the class declaration, and empty method stubs) is replaced with the full pipeline:
from diffusers import DiffusionPipeline
from diffusers import DDPMPipeline
from diffusers import DDPMScheduler, UNet2DModel
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoTokenizer
from datasets import load_dataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

class RCTDiffusionPipeline(DiffusionPipeline):
    def __init__(self):
        super().__init__()

        # dictionaries that map the classes of object_description, color1, color2 and color3 to token indices
        self.object_description_dict = {}
        self.color1_dict = {}
        self.color2_dict = {}
        self.color3_dict = {}
        self.load_dictionaries_from_dataset()

        self.scheduler = DDPMScheduler()

        # the number of hidden features is dependent on the loaded dictionaries!
        # 12 channels = 4 views x 3 RGB channels, stacked along the channel axis
        self.unet = UNet2DModel(sample_size=256, in_channels=12, out_channels=12,
                                down_block_types=('DownBlock2D', 'DownBlock2D', 'AttnDownBlock2D'),
                                up_block_types=('UpBlock2D', 'UpBlock2D', 'AttnUpBlock2D'),
                                block_out_channels=(16, 32, 64), norm_num_groups=16)

        self.unet.to('cuda')

    def load_dictionaries_from_dataset(self):
        dataset = load_dataset('frutiemax/rct_dataset')
        dataset = dataset['train']

        # assign each distinct label the next free index, in dataset order
        for row in dataset:
            if not row['object_description'] in self.object_description_dict:
                self.object_description_dict[row['object_description']] = len(self.object_description_dict)
            if not row['color1'] in self.color1_dict and row['color1'] != 'none':
                self.color1_dict[row['color1']] = len(self.color1_dict)
            if not row['color2'] in self.color2_dict and row['color2'] != 'none':
                self.color2_dict[row['color2']] = len(self.color2_dict)
            if not row['color3'] in self.color3_dict and row['color3'] != 'none':
                self.color3_dict[row['color3']] = len(self.color3_dict)
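The insertion-order indexing above is easiest to see on toy data (hypothetical rows, not the actual dataset contents):

    # each first occurrence of a label gets the next free index
    rows = [{'object_description': 'pine'},
            {'object_description': 'oak'},
            {'object_description': 'pine'}]
    tokens = {}
    for row in rows:
        if row['object_description'] not in tokens:
            tokens[row['object_description']] = len(tokens)
    assert tokens == {'pine': 0, 'oak': 1}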
    # helper function to dump the class tokens for inspection
    def print_class_tokens_to_csv(self):
        object_descriptions = pd.DataFrame(self.object_description_dict.items())
        object_descriptions.to_csv('object_descriptions_tokens.csv')

        color1 = pd.DataFrame(self.color1_dict.items())
        color1.to_csv('color1_tokens.csv')

        color2 = pd.DataFrame(self.color2_dict.items())
        color2.to_csv('color2_tokens.csv')

        color3 = pd.DataFrame(self.color3_dict.items())
        color3.to_csv('color3_tokens.csv')

    # helper functions to build weight tables
    def get_object_description_weights(self, classifiers: list[tuple[str, float]]) -> np.array:
        result = np.zeros(len(self.object_description_dict))

        for classifier in classifiers:
            id, weight = classifier
            if id in self.object_description_dict:
                weight_index = self.object_description_dict[id]
                result[weight_index] = weight
        return result

    def get_color1_weights(self, classifiers: list[tuple[str, float]]) -> np.array:
        result = np.zeros(len(self.color1_dict))

        for classifier in classifiers:
            id, weight = classifier
            if id in self.color1_dict:
                weight_index = self.color1_dict[id]
                result[weight_index] = weight
        return result

    def get_color2_weights(self, classifiers: list[tuple[str, float]]) -> np.array:
        result = np.zeros(len(self.color2_dict))

        for classifier in classifiers:
            id, weight = classifier
            if id in self.color2_dict:
                weight_index = self.color2_dict[id]
                result[weight_index] = weight
        return result

    def get_color3_weights(self, classifiers: list[tuple[str, float]]) -> np.array:
        result = np.zeros(len(self.color3_dict))

        for classifier in classifiers:
            id, weight = classifier
            if id in self.color3_dict:
                weight_index = self.color3_dict[id]
                result[weight_index] = weight
        return result

    def get_class_labels_size(self):
        return len(self.object_description_dict) + len(self.color1_dict) + len(self.color2_dict) + len(self.color3_dict)

    def pack_labels_to_tensor(self, num_images, object_descriptions : np.array, colors1 : np.array, colors2 : np.array, colors3 : np.array) -> torch.Tensor:
        num_labels = self.get_class_labels_size()
        class_labels = torch.Tensor(size=(num_images, num_labels))

        # concatenate the four weight tables into one label vector per image
        for batch_index in range(num_images):
            offset = 0
            class_labels[batch_index, offset:offset + len(self.object_description_dict)] = torch.from_numpy(object_descriptions[batch_index])

            offset += len(self.object_description_dict)
            class_labels[batch_index, offset:offset + len(self.color1_dict)] = torch.from_numpy(colors1[batch_index])

            offset += len(self.color1_dict)
            class_labels[batch_index, offset:offset + len(self.color2_dict)] = torch.from_numpy(colors2[batch_index])

            offset += len(self.color2_dict)
            class_labels[batch_index, offset:offset + len(self.color3_dict)] = torch.from_numpy(colors3[batch_index])
        return class_labels
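A minimal sketch of the resulting layout, assuming toy dictionaries of sizes 2 and 1 (hypothetical labels, not the real vocabulary):

    import numpy as np

    # object_description_dict = {'pine': 0, 'oak': 1}, color1_dict = {'green': 0}
    obj = np.array([0.0, 1.0])   # 'oak' with weight 1.0
    c1 = np.array([1.0])         # 'green' with weight 1.0
    packed = np.concatenate([obj, c1])   # -> [0.0, 1.0, 1.0]
    # pack_labels_to_tensor performs the same concatenation, one torch row per image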
    def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]],
                 color2 : list[list[tuple[str, float]]] = None, color3 : list[list[tuple[str, float]]] = None,
                 batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):

        # check that the labels match the batch size
        if len(object_description) != batch_size:
            return None

        if len(color1) != batch_size:
            return None

        if color2 is not None and len(color2) != batch_size:
            return None

        if color3 is not None and len(color3) != batch_size:
            return None

        # build the labels for each batch
        object_descriptions = []
        colors1 = []
        colors2 = []
        colors3 = []

        for batch_index in range(batch_size):
            obj_desc = self.get_object_description_weights(object_description[batch_index])
            c1 = self.get_color1_weights(color1[batch_index])

            if color2 is not None:
                c2 = self.get_color2_weights(color2[batch_index])
            else:
                c2 = self.get_color2_weights([])

            if color3 is not None:
                c3 = self.get_color3_weights(color3[batch_index])
            else:
                c3 = self.get_color3_weights([])

            object_descriptions.append(obj_desc)
            colors1.append(c1)
            colors2.append(c2)
            colors3.append(c3)

        # now put those weights into a tensor
        class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3)
        class_labels = class_labels.to('cuda')

        # set the inference steps
        self.scheduler.set_timesteps(num_inference_steps)

        # start from pure noise: 4 views of 3 RGB channels per sample
        noise_batches = torch.Tensor(size=(batch_size, 4, 3, 256, 256)).to('cuda')
        for batch_index in range(batch_size):
            for view_index in range(4):
                noise = torch.randn(3, 256, 256).to('cuda')
                noise_batches[batch_index, view_index] = noise

        # reshape the data so it's (batch_size, 12, 256, 256)
        noise_batches = torch.reshape(noise_batches, (batch_size, 12, 256, 256)).to('cuda')

        # now call the model for the num_inference_steps iterations
        progress_bar = tqdm(total=num_inference_steps)
        step = 0
        for t in self.scheduler.timesteps:
            progress_bar.set_description(f'Inference step {step}')
            with torch.no_grad():
                noise_residual = self.unet(noise_batches, t, class_labels=class_labels).sample
                previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches).prev_sample
                noise_batches = previous_noisy_sample
            progress_bar.update(1)
            step = step + 1

        # reshape the data so we get back 4 RGB images
        noise_batches = torch.reshape(noise_batches, (batch_size, 4, 3, 256, 256)).to('cpu')

        # convert those tensors to PIL images
        output_images = []
        tensor_to_pil = T.ToPILImage('RGB')

        for batch_index in range(batch_size):
            for image_index in range(4):
                output_images.append(tensor_to_pil(noise_batches[batch_index, image_index]))

        # for now just return the images
        return output_images
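T.ToPILImage expects float tensors in [0, 1], which matches the T.ToTensor preprocessing used in train_model.py; since the sampler can overshoot that range, a clamp before conversion is a cheap safeguard (a sketch, not part of the commit):

    images = noise_batches.clamp(0.0, 1.0)
    first_view = T.ToPILImage('RGB')(images[0, 0])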
test_pipeline.py
CHANGED
@@ -1,8 +1,9 @@

The manual DDPMScheduler/UNet2DModel setup is removed; the test now drives the pipeline directly:

from rct_diffusion_pipeline import RCTDiffusionPipeline

torch_device = "cuda"

pipeline = RCTDiffusionPipeline()
pipeline.print_class_tokens_to_csv()
output = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
print('test')
train_model.py
ADDED
@@ -0,0 +1,128 @@

The new training script:

from datasets import load_dataset
from PIL.Image import Image
import PIL
from PIL.Image import Resampling
import numpy as np
from rct_diffusion_pipeline import RCTDiffusionPipeline
import torch
import torchvision.transforms as T
import torch.nn.functional as F
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm

def save_and_test(pipeline, epoch):
    # render a fixed test prompt and save the 4 generated views to disk
    outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
    for image_index in range(len(outputs)):
        file_name = f'out{image_index}_{epoch}.png'
        outputs[image_index].save(file_name)

    model_file = f'rct_foliage_{epoch}.pth'
    pipeline.save_pretrained(model_file)
def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
    dataset = load_dataset('frutiemax/rct_dataset')
    dataset = dataset['train']

    # every object has 4 views, so the dataset holds 4 rows per image
    num_images = int(dataset.num_rows / 4)

    # get all the entries for the 4 views, split into four lists
    views = []

    for view_index in range(4):
        entries = [entry for entry in dataset if entry['view'] == view_index]
        views.append(entries)

    # convert those images to 256x256 by scaling them up and centering them on a blank canvas
    image_views = []
    for view_index in range(4):
        images = []
        for entry in views[view_index]:
            image = entry['image']

            # largest integer scale factor that keeps the sprite inside 256x256
            scale_factor = int(np.minimum(256 / image.width, 256 / image.height))
            image = image.resize((scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)

            new_image = PIL.Image.new('RGB', (256, 256))
            new_image.paste(image, box=(int((256 - image.width) / 2), int((256 - image.height) / 2)))
            images.append(new_image)
        image_views.append(images)

    del views

    # convert those views to tensors
    targets = torch.Tensor(size=(num_images, 4, 3, 256, 256))
    pillow_to_tensor = T.ToTensor()

    for image_index in range(num_images):
        for view_index in range(4):
            targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index])
    del image_views
    del entries

    # stack the 4 RGB views along the channel axis: (num_images, 12, 256, 256)
    targets = torch.reshape(targets, (num_images, 12, 256, 256))
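The reshape relies on the view and channel dimensions being adjacent; a quick standalone sanity check of that layout (a sketch, not in the commit):

    import torch

    t = torch.arange(2 * 4 * 3 * 2 * 2).reshape(2, 4, 3, 2, 2)   # (N, views, C, H, W)
    flat = torch.reshape(t, (2, 12, 2, 2))
    # channel block [v*3 : v*3+3] of image n is exactly view v of image n
    assert torch.equal(flat[1, 3:6], t[1, 1])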
    # get the labels; the 4 views of an image share the same labels, so view 0 is enough
    view0_entries = [row for row in dataset if row['view'] == 0]
    obj_descriptions = [row['object_description'] for row in view0_entries]
    colors1 = [row['color1'] for row in view0_entries]
    colors2 = [row['color2'] for row in view0_entries]
    colors3 = [row['color3'] for row in view0_entries]

    del view0_entries

    # convert the descriptions, color1, color2 and color3 to lists of (label, weight=1.0) tuples
    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
    colors1 = [[(color1, 1.0)] for color1 in colors1]
    colors2 = [[(color2, 1.0)] for color2 in colors2]
    colors3 = [[(color3, 1.0)] for color3 in colors3]

    # convert those tuples to numpy arrays using the helper functions of the model
    model = RCTDiffusionPipeline()
    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
    colors3 = [model.get_color3_weights(color3) for color3 in colors3]

    # finally, convert those numpy arrays to a tensor
    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
    del obj_descriptions
    del colors1
    del colors2
    del colors3
    del dataset
    optimizer = torch.optim.Adam(model.unet.parameters(), lr=start_learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_images * epochs
    )
    # note: the scheduler is stepped once per batch, so with batch_size=4 this
    # schedule only advances through about a quarter of its cosine decay

    # train for `epochs` epochs over the dataset with a random noise level per batch
    progress_bar = tqdm(total=epochs)
    for epoch in range(epochs):
        # create a noisy version of each sprite
        for batch_index in range(0, num_images, batch_size):
            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
            batch_end = np.minimum(num_images, batch_index + batch_size)
            clean_images = targets[batch_index:batch_end].to('cuda')
            batch_labels = class_labels[batch_index:batch_end].to('cuda')

            noise = torch.randn(clean_images.shape).to('cuda')
            # use the actual slice size so the last, possibly partial, batch doesn't mismatch
            timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (clean_images.shape[0],)).to('cuda')
            noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps)
            noise_pred = model.unet(noisy_images, timesteps, batch_labels, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if (epoch + 1) % save_model_interval == 0:
            save_and_test(model, epoch)
        progress_bar.update(1)


if __name__ == '__main__':
    train_model()
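The commit checkpoints through pipeline.save_pretrained; since only the UNet carries trained weights here, a plain state-dict round trip is a minimal alternative sketch (the file name is hypothetical):

    # save just the trained denoiser
    torch.save(model.unet.state_dict(), 'rct_foliage_unet.pth')

    # later: rebuild the pipeline and restore the weights before sampling
    pipeline = RCTDiffusionPipeline()
    pipeline.unet.load_state_dict(torch.load('rct_foliage_unet.pth'))
    images = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])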