EvanTHU commited on
Commit
df42902
1 Parent(s): eee8228

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +2 -0
models/unet.py CHANGED
@@ -273,6 +273,7 @@ class Downsample1d(nn.Module):
273
  self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
274
 
275
  def forward(self, x):
 
276
  return self.conv(x)
277
 
278
 
@@ -283,6 +284,7 @@ class Upsample1d(nn.Module):
283
  self.conv = nn.ConvTranspose1d(dim_in, dim_out, 4, 2, 1)
284
 
285
  def forward(self, x):
 
286
  return self.conv(x)
287
 
288
 
 
273
  self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
274
 
275
  def forward(self, x):
276
+ self.conv = self.conv.cuda()
277
  return self.conv(x)
278
 
279
 
 
284
  self.conv = nn.ConvTranspose1d(dim_in, dim_out, 4, 2, 1)
285
 
286
  def forward(self, x):
287
+ self.conv = self.conv.cuda()
288
  return self.conv(x)
289
 
290