frutiemax committed on
Commit d751051 · Parent: 2a7e546

Training code

Files changed (3)
  1. rct_diffusion_pipeline.py +198 -8
  2. test_pipeline.py +6 -5
  3. train_model.py +128 -0
rct_diffusion_pipeline.py CHANGED
@@ -1,15 +1,205 @@
 from diffusers import DiffusionPipeline
+from diffusers import DDPMPipeline
+from diffusers import DDPMScheduler, UNet2DModel
 import torch
+import torchvision.transforms as T
+from PIL import Image
+from transformers import AutoTokenizer
+from datasets import load_dataset
+import numpy as np
+import pandas as pd
+from tqdm.auto import tqdm
 
 class RCTDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, unet, scheduler):
+    def __init__(self):
         super().__init__()
-        self.register_modules(unet=unet, scheduler=scheduler)
-
-    def __call__(self):
-        image = torch.randn((1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size))
-        timestep = 1
-
-        model_output = self.unet(image, timestep).sample
-        scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
-        return scheduler_output
+
+        # dictionaries that keep the different classes of object description, color1, color2 and color3
+        self.object_description_dict = {}
+        self.color1_dict = {}
+        self.color2_dict = {}
+        self.color3_dict = {}
+        self.load_dictionaries_from_dataset()
+
+        self.scheduler = DDPMScheduler()
+
+        # the number of hidden features is dependent on the loaded dictionaries!
+        self.unet = UNet2DModel(sample_size=256, in_channels=12, out_channels=12,
+                                down_block_types=('DownBlock2D', 'DownBlock2D', 'AttnDownBlock2D'),
+                                up_block_types=('UpBlock2D', 'UpBlock2D', 'AttnUpBlock2D'),
+                                block_out_channels=(16, 32, 64), norm_num_groups=16)
+
+        self.unet.to('cuda')
+
+    def load_dictionaries_from_dataset(self):
+        dataset = load_dataset('frutiemax/rct_dataset')
+        dataset = dataset['train']
+
+        for row in dataset:
+            if not row['object_description'] in self.object_description_dict:
+                self.object_description_dict[row['object_description']] = len(self.object_description_dict)
+            if not row['color1'] in self.color1_dict and row['color1'] != 'none':
+                self.color1_dict[row['color1']] = len(self.color1_dict)
+            if not row['color2'] in self.color2_dict and row['color2'] != 'none':
+                self.color2_dict[row['color2']] = len(self.color2_dict)
+            if not row['color3'] in self.color3_dict and row['color3'] != 'none':
+                self.color3_dict[row['color3']] = len(self.color3_dict)
+
+    # helper function to dump the class tokens
+    def print_class_tokens_to_csv(self):
+        object_descriptions = pd.DataFrame(self.object_description_dict.items())
+        object_descriptions.to_csv('object_descriptions_tokens.csv')
+
+        color1 = pd.DataFrame(self.color1_dict.items())
+        color1.to_csv('color1_tokens.csv')
+
+        color2 = pd.DataFrame(self.color2_dict.items())
+        color2.to_csv('color2_tokens.csv')
+
+        color3 = pd.DataFrame(self.color3_dict.items())
+        color3.to_csv('color3_tokens.csv')
+
+    # helper functions to build weight tables
+    def get_object_description_weights(self, classifiers : list[tuple[str, float]]) -> np.ndarray:
+        result = np.zeros(len(self.object_description_dict))
+
+        for classifier in classifiers:
+            id, weight = classifier
+            if id in self.object_description_dict:
+                weight_index = self.object_description_dict[id]
+                result[weight_index] = weight
+        return result
+
+    def get_color1_weights(self, classifiers : list[tuple[str, float]]) -> np.ndarray:
+        result = np.zeros(len(self.color1_dict))
+
+        for classifier in classifiers:
+            id, weight = classifier
+            if id in self.color1_dict:
+                weight_index = self.color1_dict[id]
+                result[weight_index] = weight
+        return result
+
+    def get_color2_weights(self, classifiers : list[tuple[str, float]]) -> np.ndarray:
+        result = np.zeros(len(self.color2_dict))
+
+        for classifier in classifiers:
+            id, weight = classifier
+            if id in self.color2_dict:
+                weight_index = self.color2_dict[id]
+                result[weight_index] = weight
+        return result
+
+    def get_color3_weights(self, classifiers : list[tuple[str, float]]) -> np.ndarray:
+        result = np.zeros(len(self.color3_dict))
+
+        for classifier in classifiers:
+            id, weight = classifier
+            if id in self.color3_dict:
+                weight_index = self.color3_dict[id]
+                result[weight_index] = weight
+        return result
+
+    def get_class_labels_size(self):
+        return len(self.object_description_dict) + len(self.color1_dict) + len(self.color2_dict) + len(self.color3_dict)
+
+    def pack_labels_to_tensor(self, num_images, object_descriptions : np.ndarray, colors1 : np.ndarray, colors2 : np.ndarray, colors3 : np.ndarray) -> torch.Tensor:
+        num_labels = self.get_class_labels_size()
+        class_labels = torch.empty(size=(num_images, num_labels))
+
+        for batch_index in range(num_images):
+            offset = 0
+            class_labels[batch_index, offset:offset + len(self.object_description_dict)] = torch.from_numpy(object_descriptions[batch_index])
+
+            offset += len(self.object_description_dict)
+            class_labels[batch_index, offset:offset + len(self.color1_dict)] = torch.from_numpy(colors1[batch_index])
+
+            offset += len(self.color1_dict)
+            class_labels[batch_index, offset:offset + len(self.color2_dict)] = torch.from_numpy(colors2[batch_index])
+
+            offset += len(self.color2_dict)
+            class_labels[batch_index, offset:offset + len(self.color3_dict)] = torch.from_numpy(colors3[batch_index])
+        return class_labels
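
A minimal sketch of the layout pack_labels_to_tensor produces: the four weight tables are simply concatenated into one row per image. The two small dictionaries below are hypothetical stand-ins for the ones built from frutiemax/rct_dataset.

    import numpy as np
    import torch

    # hypothetical class dictionaries (label -> index)
    object_description_dict = {'aleppo pine tree': 0, 'oak tree': 1}
    color1_dict = {'dark green': 0, 'yellow': 1}

    # weight table per category, as in get_object_description_weights
    obj_weights = np.zeros(len(object_description_dict))
    obj_weights[object_description_dict['aleppo pine tree']] = 1.0
    c1_weights = np.zeros(len(color1_dict))
    c1_weights[color1_dict['dark green']] = 1.0

    # one row of class_labels: [object_description | color1 | color2 | color3]
    row = torch.from_numpy(np.concatenate([obj_weights, c1_weights]))
    print(row)  # tensor([1., 0., 1., 0.], dtype=torch.float64)
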
+
+    def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]],
+                 color2 : list[list[tuple[str, float]]] = None, color3 : list[list[tuple[str, float]]] = None,
+                 batch_size=1, num_inference_steps=20, generator=None):  # generator is not used yet
+
+        # check that the labels are the correct size
+        if len(object_description) != batch_size:
+            return None
+
+        if len(color1) != batch_size:
+            return None
+
+        if color2 is not None and len(color2) != batch_size:
+            return None
+
+        if color3 is not None and len(color3) != batch_size:
+            return None
+
+        # build the labels for each batch
+        object_descriptions = []
+        colors1 = []
+        colors2 = []
+        colors3 = []
+
+        for batch_index in range(batch_size):
+            obj_desc = self.get_object_description_weights(object_description[batch_index])
+            c1 = self.get_color1_weights(color1[batch_index])
+
+            if color2 is not None:
+                c2 = self.get_color2_weights(color2[batch_index])
+            else:
+                c2 = self.get_color2_weights([])
+
+            if color3 is not None:
+                c3 = self.get_color3_weights(color3[batch_index])
+            else:
+                c3 = self.get_color3_weights([])
+
+            object_descriptions.append(obj_desc)
+            colors1.append(c1)
+            colors2.append(c2)
+            colors3.append(c3)
+
+        # now put those weights into a tensor
+        class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3)
+        class_labels = class_labels.to('cuda')
+
+        # set the inference steps
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        noise_batches = torch.empty(size=(batch_size, 4, 3, 256, 256)).to('cuda')
+        for batch_index in range(batch_size):
+            for view_index in range(4):
+                noise = torch.randn(3, 256, 256).to('cuda')
+                noise_batches[batch_index, view_index] = noise
+
+        # reshape the data so it's (batch_size, 12, 256, 256)
+        noise_batches = torch.reshape(noise_batches, (batch_size, 12, 256, 256)).to('cuda')
+
+        # now call the model for the n iterations
+        progress_bar = tqdm(total=num_inference_steps)
+        step = 0
+        for t in self.scheduler.timesteps:
+            progress_bar.set_description(f'Inference step {step}')
+            with torch.no_grad():
+                noise_residual = self.unet(noise_batches, t, class_labels=class_labels).sample
+                previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches).prev_sample
+                noise_batches = previous_noisy_sample
+            progress_bar.update(1)
+            step = step + 1
+
+        # reshape the data so we get back 4 RGB images
+        noise_batches = torch.reshape(noise_batches, (batch_size, 4, 3, 256, 256)).to('cpu')
+
+        # convert those tensors to PIL images
+        output_images = []
+        tensor_to_pil = T.ToPILImage('RGB')
 
+        for batch_index in range(batch_size):
+            for image_index in range(4):
+                output_images.append(tensor_to_pil(noise_batches[batch_index, image_index]))
+
+        # for now just return the images
+        return output_images
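
The pipeline treats the four isometric views of a sprite as a single 12-channel image. A short sketch showing that the (batch, 4, 3, 256, 256) to (batch, 12, 256, 256) reshape used above is lossless, so each view's RGB planes can be recovered after denoising:

    import torch

    batch = torch.randn(2, 4, 3, 256, 256)         # (batch, view, rgb, h, w)
    stacked = batch.reshape(2, 12, 256, 256)       # what the UNet consumes
    restored = stacked.reshape(2, 4, 3, 256, 256)  # what ToPILImage receives
    assert torch.equal(batch, restored)
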
test_pipeline.py CHANGED
@@ -1,8 +1,9 @@
 from rct_diffusion_pipeline import RCTDiffusionPipeline
-from diffusers import DDPMScheduler, UNet2DModel
 
-scheduler = DDPMScheduler()
-unet = UNet2DModel()
 
-pipeline = RCTDiffusionPipeline(unet=unet, scheduler=scheduler)
-output = pipeline()
+torch_device = "cuda"
+
+pipeline = RCTDiffusionPipeline()
+pipeline.print_class_tokens_to_csv()
+output = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
+print('test')
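
Since __call__ returns plain PIL images (four views per batch item), the test output can be inspected by saving each one, in the same way save_and_test does in train_model.py below (file names here are illustrative):

    for image_index, image in enumerate(output):
        image.save(f'test_out{image_index}.png')
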
train_model.py ADDED
@@ -0,0 +1,128 @@
+from datasets import load_dataset
+import PIL
+from PIL.Image import Resampling
+import numpy as np
+from rct_diffusion_pipeline import RCTDiffusionPipeline
+import torch
+import torchvision.transforms as T
+import torch.nn.functional as F
+from diffusers.optimization import get_cosine_schedule_with_warmup
+from tqdm.auto import tqdm
+
+def save_and_test(pipeline, epoch):
+    outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
+    for image_index in range(len(outputs)):
+        file_name = f'out{image_index}_{epoch}.png'
+        outputs[image_index].save(file_name)
+
+    model_file = f'rct_foliage_{epoch}.pth'
+    pipeline.save_pretrained(model_file)
+
+def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
+    dataset = load_dataset('frutiemax/rct_dataset')
+    dataset = dataset['train']
+
+    # every object has 4 views, so the number of distinct objects is a quarter of the rows
+    num_images = int(dataset.num_rows / 4)
+
+    # get all the entries for the 4 views, split into four lists
+    views = []
+
+    for view_index in range(4):
+        entries = [entry for entry in dataset if entry['view'] == view_index]
+        views.append(entries)
+
+    # convert those images to 256x256 by upscaling and centring them on a square canvas
+    # (assumes the source sprites are at most 256 pixels per side, so the integer scale factor is at least 1)
+    image_views = []
+    for view_index in range(4):
+        images = []
+        for entry in views[view_index]:
+            image = entry['image']
+
+            scale_factor = int(np.minimum(256 / image.width, 256 / image.height))
+            image = image.resize((scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)
+
+            new_image = PIL.Image.new('RGB', (256, 256))
+            new_image.paste(image, box=(int((256 - image.width)/2), int((256 - image.height)/2)))
+            images.append(new_image)
+        image_views.append(images)
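
The resize-and-paste step above, isolated for a hypothetical 64x112 sprite: the integer scale factor picks the largest nearest-neighbour upscale that still fits, which keeps pixel-art edges crisp, and the sprite is then centred on a black 256x256 canvas:

    import PIL.Image
    from PIL.Image import Resampling

    sprite = PIL.Image.new('RGB', (64, 112))  # stand-in for entry['image']
    scale = int(min(256 / sprite.width, 256 / sprite.height))  # -> 2
    sprite = sprite.resize((scale * sprite.width, scale * sprite.height),
                           resample=Resampling.NEAREST)
    canvas = PIL.Image.new('RGB', (256, 256))
    canvas.paste(sprite, box=((256 - sprite.width) // 2,
                              (256 - sprite.height) // 2))
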
+
+    del views
+
+    # convert those views to tensors
+    targets = torch.empty(size=(num_images, 4, 3, 256, 256))
+    pillow_to_tensor = T.ToTensor()
+
+    for image_index in range(num_images):
+        for view_index in range(4):
+            targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index])
+    del image_views
+    del entries
+
+    targets = torch.reshape(targets, (num_images, 12, 256, 256))
+
+    # get the labels
+    view0_entries = [row for row in dataset if row['view'] == 0]
+    obj_descriptions = [row['object_description'] for row in view0_entries]
+    colors1 = [row['color1'] for row in view0_entries]
+    colors2 = [row['color2'] for row in view0_entries]
+    colors3 = [row['color3'] for row in view0_entries]
+
+    del view0_entries
+
+    # convert those descriptions, color1, color2 and color3 to lists of tuples with label and weight=1.0
+    obj_descriptions = [[(obj_desc, 1.0)] for obj_desc in obj_descriptions]
+    colors1 = [[(color1, 1.0)] for color1 in colors1]
+    colors2 = [[(color2, 1.0)] for color2 in colors2]
+    colors3 = [[(color3, 1.0)] for color3 in colors3]
+
+    # convert those tuples to numpy arrays using the helper functions of the model
+    model = RCTDiffusionPipeline()
+    obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
+    colors1 = [model.get_color1_weights(color1) for color1 in colors1]
+    colors2 = [model.get_color2_weights(color2) for color2 in colors2]
+    colors3 = [model.get_color3_weights(color3) for color3 in colors3]
+
+    # finally, convert those numpy arrays to a tensor
+    class_labels = model.pack_labels_to_tensor(num_images, obj_descriptions, colors1, colors2, colors3)
+    del obj_descriptions
+    del colors1
+    del colors2
+    del colors3
+    del dataset
+
+    optimizer = torch.optim.Adam(model.unet.parameters(), lr=start_learning_rate)
+    lr_scheduler = get_cosine_schedule_with_warmup(
+        optimizer=optimizer,
+        num_warmup_steps=lr_warmup_steps,
+        num_training_steps=num_images * epochs
+    )
+
+    # train for the requested number of epochs, noising each sprite at a random timestep
+    progress_bar = tqdm(total=epochs)
+    for epoch in range(epochs):
+        # create a noisy version of each sprite
+        for batch_index in range(0, num_images, batch_size):
+            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
+            batch_end = np.minimum(num_images, batch_index + batch_size)
+            clean_images = targets[batch_index:batch_end].to('cuda')
+            batch_labels = class_labels[batch_index:batch_end].to('cuda')
+
+            noise = torch.randn(clean_images.shape).to('cuda')
+            # draw one timestep per image actually in the batch (the last batch may be smaller than batch_size)
+            timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (clean_images.shape[0], )).to('cuda')
+            noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps)
+            noise_pred = model.unet(noisy_images, timesteps, batch_labels, return_dict=False)[0]
+            loss = F.mse_loss(noise_pred, noise)
+            loss.backward()
+
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+
+        if (epoch + 1) % save_model_interval == 0:
+            save_and_test(model, epoch)
+        progress_bar.update(1)
+
+
+if __name__ == '__main__':
+    train_model()
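
The heart of the loop above, reduced to a self-contained sketch that runs on CPU with a toy 32x32 unconditional UNet (sizes and names here are illustrative, not the ones train_model.py uses):

    import torch
    import torch.nn.functional as F
    from diffusers import DDPMScheduler, UNet2DModel

    scheduler = DDPMScheduler()
    unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3,
                       down_block_types=('DownBlock2D', 'DownBlock2D'),
                       up_block_types=('UpBlock2D', 'UpBlock2D'),
                       block_out_channels=(16, 32), norm_num_groups=16)
    optimizer = torch.optim.Adam(unet.parameters(), lr=1e-3)

    clean_images = torch.randn(2, 3, 32, 32)  # stand-in for real sprites
    noise = torch.randn_like(clean_images)
    timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (2,))

    # forward-diffuse, predict the noise, regress the prediction onto it
    noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
    noise_pred = unet(noisy_images, timesteps).sample
    loss = F.mse_loss(noise_pred, noise)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
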