Fix training
Browse files- rct_diffusion_pipeline.py +2 -2
- train_model.py +6 -3
rct_diffusion_pipeline.py
CHANGED
@@ -28,7 +28,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
28 |
self.unet = UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
|
29 |
down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
|
30 |
up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
|
31 |
-
block_out_channels=(
|
32 |
|
33 |
self.unet.to(device='cuda', dtype=torch.float16)
|
34 |
|
@@ -167,7 +167,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
167 |
colors3.append(c3)
|
168 |
|
169 |
# now put those weights into a tensor
|
170 |
-
class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3)
|
171 |
|
172 |
# we need those class labels for the 12 channels
|
173 |
#new_class_labels = torch.Tensor(size=(batch_size, 12, self.get_class_labels_size()))
|
|
|
28 |
self.unet = UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
|
29 |
down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
|
30 |
up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
|
31 |
+
block_out_channels=(64, 128, 256), norm_num_groups=32)
|
32 |
|
33 |
self.unet.to(device='cuda', dtype=torch.float16)
|
34 |
|
|
|
167 |
colors3.append(c3)
|
168 |
|
169 |
# now put those weights into a tensor
|
170 |
+
class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)
|
171 |
|
172 |
# we need those class labels for the 12 channels
|
173 |
#new_class_labels = torch.Tensor(size=(batch_size, 12, self.get_class_labels_size()))
|
train_model.py
CHANGED
@@ -108,11 +108,14 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
|
|
108 |
clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
|
109 |
clean_images = torch.reshape(clean_images, (batch_size, 12, 256, 256))
|
110 |
|
111 |
-
noise = torch.randn(clean_images.shape).to('cuda')
|
112 |
-
timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_size, ))
|
|
|
113 |
noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
|
114 |
noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
|
115 |
-
|
|
|
|
|
116 |
loss.backward()
|
117 |
|
118 |
optimizer.step()
|
|
|
108 |
clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
|
109 |
clean_images = torch.reshape(clean_images, (batch_size, 12, 256, 256))
|
110 |
|
111 |
+
noise = torch.randn(clean_images.shape).to(device='cuda', dtype=torch.float16)
|
112 |
+
timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_size, ))
|
113 |
+
timesteps = timesteps.to(dtype=torch.int, device='cuda')
|
114 |
noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
|
115 |
noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
|
116 |
+
|
117 |
+
noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
|
118 |
+
loss = F.mse_loss(noise_pred, noise).to(device='cuda', dtype=torch.float16)
|
119 |
loss.backward()
|
120 |
|
121 |
optimizer.step()
|