Spaces:
Running
on
Zero
Running
on
Zero
Update models/unet.py
Browse files- models/unet.py +8 -0
models/unet.py
CHANGED
@@ -759,24 +759,32 @@ class CondUnet1D(nn.Module):
|
|
759 |
cond,
|
760 |
cond_indices,
|
761 |
):
|
|
|
762 |
temb = self.time_mlp(t)
|
763 |
|
764 |
h = []
|
765 |
for block1, block2, downsample in self.downs:
|
|
|
|
|
766 |
x = block1(x, temb, cond, cond_indices)
|
767 |
x = block2(x, temb, cond, cond_indices)
|
768 |
h.append(x)
|
769 |
x = downsample(x)
|
770 |
|
|
|
|
|
771 |
x = self.mid_block1(x, temb, cond, cond_indices)
|
772 |
x = self.mid_block2(x, temb, cond, cond_indices)
|
773 |
|
774 |
for upsample, block1, block2 in self.ups:
|
775 |
x = upsample(x)
|
776 |
x = torch.cat((x, h.pop()), dim=1)
|
|
|
|
|
777 |
x = block1(x, temb, cond, cond_indices)
|
778 |
x = block2(x, temb, cond, cond_indices)
|
779 |
|
|
|
780 |
x = self.final_conv(x)
|
781 |
return x
|
782 |
|
|
|
759 |
cond,
|
760 |
cond_indices,
|
761 |
):
|
762 |
+
self.time_mlp = self.time_mlp.cuda()
|
763 |
temb = self.time_mlp(t)
|
764 |
|
765 |
h = []
|
766 |
for block1, block2, downsample in self.downs:
|
767 |
+
block1 = block1.cuda()
|
768 |
+
block2 = block2.cuda()
|
769 |
x = block1(x, temb, cond, cond_indices)
|
770 |
x = block2(x, temb, cond, cond_indices)
|
771 |
h.append(x)
|
772 |
x = downsample(x)
|
773 |
|
774 |
+
self.mid_block1 = self.mid_block1.cuda()
|
775 |
+
self.mid_block2 = self.mid_block2.cuda()
|
776 |
x = self.mid_block1(x, temb, cond, cond_indices)
|
777 |
x = self.mid_block2(x, temb, cond, cond_indices)
|
778 |
|
779 |
for upsample, block1, block2 in self.ups:
|
780 |
x = upsample(x)
|
781 |
x = torch.cat((x, h.pop()), dim=1)
|
782 |
+
block1 = block1.cuda()
|
783 |
+
block2 = block2.cuda()
|
784 |
x = block1(x, temb, cond, cond_indices)
|
785 |
x = block2(x, temb, cond, cond_indices)
|
786 |
|
787 |
+
self.final_conv = self.final_conv.cuda()
|
788 |
x = self.final_conv(x)
|
789 |
return x
|
790 |
|