frutiemax committed on
Commit
04d70cd
1 Parent(s): 88deab4

Fix training

Browse files
Files changed (2) hide show
  1. rct_diffusion_pipeline.py +2 -2
  2. train_model.py +6 -3
rct_diffusion_pipeline.py CHANGED
@@ -28,7 +28,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
28
  self.unet = UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
29
  down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
30
  up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
31
- block_out_channels=(12, 24, 30), norm_num_groups=6)
32
 
33
  self.unet.to(device='cuda', dtype=torch.float16)
34
 
@@ -167,7 +167,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
167
  colors3.append(c3)
168
 
169
  # now put those weights into a tensor
170
- class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3)
171
 
172
  # we need those class labels for the 12 channels
173
  #new_class_labels = torch.Tensor(size=(batch_size, 12, self.get_class_labels_size()))
 
28
  self.unet = UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
29
  down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
30
  up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
31
+ block_out_channels=(64, 128, 256), norm_num_groups=32)
32
 
33
  self.unet.to(device='cuda', dtype=torch.float16)
34
 
 
167
  colors3.append(c3)
168
 
169
  # now put those weights into a tensor
170
+ class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)
171
 
172
  # we need those class labels for the 12 channels
173
  #new_class_labels = torch.Tensor(size=(batch_size, 12, self.get_class_labels_size()))
train_model.py CHANGED
@@ -108,11 +108,14 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
108
  clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
109
  clean_images = torch.reshape(clean_images, (batch_size, 12, 256, 256))
110
 
111
- noise = torch.randn(clean_images.shape).to('cuda')
112
- timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_size, )).to('cuda')
 
113
  noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
114
  noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
115
- loss = F.mse_loss(noise_pred, noise)
 
 
116
  loss.backward()
117
 
118
  optimizer.step()
 
108
  clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
109
  clean_images = torch.reshape(clean_images, (batch_size, 12, 256, 256))
110
 
111
+ noise = torch.randn(clean_images.shape).to(device='cuda', dtype=torch.float16)
112
+ timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_size, ))
113
+ timesteps = timesteps.to(dtype=torch.int, device='cuda')
114
  noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
115
  noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
116
+
117
+ noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
118
+ loss = F.mse_loss(noise_pred, noise).to(device='cuda', dtype=torch.float16)
119
  loss.backward()
120
 
121
  optimizer.step()