EvanTHU commited on
Commit
1319952
1 Parent(s): abf246c

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +8 -0
models/unet.py CHANGED
@@ -759,24 +759,32 @@ class CondUnet1D(nn.Module):
759
  cond,
760
  cond_indices,
761
  ):
 
762
  temb = self.time_mlp(t)
763
 
764
  h = []
765
  for block1, block2, downsample in self.downs:
 
 
766
  x = block1(x, temb, cond, cond_indices)
767
  x = block2(x, temb, cond, cond_indices)
768
  h.append(x)
769
  x = downsample(x)
770
 
 
 
771
  x = self.mid_block1(x, temb, cond, cond_indices)
772
  x = self.mid_block2(x, temb, cond, cond_indices)
773
 
774
  for upsample, block1, block2 in self.ups:
775
  x = upsample(x)
776
  x = torch.cat((x, h.pop()), dim=1)
 
 
777
  x = block1(x, temb, cond, cond_indices)
778
  x = block2(x, temb, cond, cond_indices)
779
 
 
780
  x = self.final_conv(x)
781
  return x
782
 
 
759
  cond,
760
  cond_indices,
761
  ):
762
+ self.time_mlp = self.time_mlp.cuda()
763
  temb = self.time_mlp(t)
764
 
765
  h = []
766
  for block1, block2, downsample in self.downs:
767
+ block1 = block1.cuda()
768
+ block2 = block2.cuda()
769
  x = block1(x, temb, cond, cond_indices)
770
  x = block2(x, temb, cond, cond_indices)
771
  h.append(x)
772
  x = downsample(x)
773
 
774
+ self.mid_block1 = self.mid_block1.cuda()
775
+ self.mid_block2 = self.mid_block2.cuda()
776
  x = self.mid_block1(x, temb, cond, cond_indices)
777
  x = self.mid_block2(x, temb, cond, cond_indices)
778
 
779
  for upsample, block1, block2 in self.ups:
780
  x = upsample(x)
781
  x = torch.cat((x, h.pop()), dim=1)
782
+ block1 = block1.cuda()
783
+ block2 = block2.cuda()
784
  x = block1(x, temb, cond, cond_indices)
785
  x = block2(x, temb, cond, cond_indices)
786
 
787
+ self.final_conv = self.final_conv.cuda()
788
  x = self.final_conv(x)
789
  return x
790