EvanTHU commited on
Commit
abf246c
1 Parent(s): 0ef5446

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +1 -0
models/unet.py CHANGED
@@ -263,6 +263,7 @@ class TimestepEmbedder(nn.Module):
263
  self.register_buffer("pe", pe)
264
 
265
  def forward(self, x):
 
266
  return self.pe[x]
267
 
268
 
 
263
  self.register_buffer("pe", pe)
264
 
265
  def forward(self, x):
266
+ self.pe = self.pe.cuda()
267
  return self.pe[x]
268
 
269