zen21 commited on
Commit
b9015b6
1 Parent(s): f4d1685

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -119,9 +119,9 @@ class UNet(nn.Module):
119
  x = self.up3(x, x1, t)
120
  output = self.outc(x)
121
  return output
122
- device = 'cpu'
123
  model = UNet(device = device).to(device)
124
- model.load_state_dict(torch.load('Model_Saved_States/diffusion_64.pth'))
125
  img_size = 64
126
  class Diffusion():
127
  def __init__(self, time_steps = 500, beta_start = 0.0001, beta_stop = 0.02, image_size = 64, device = device):
 
119
  x = self.up3(x, x1, t)
120
  output = self.outc(x)
121
  return output
122
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
123
  model = UNet(device = device).to(device)
124
+ model.load_state_dict(torch.load('Model_Saved_States/diffusion_64.pth', map_location=torch.device(device)))
125
  img_size = 64
126
  class Diffusion():
127
  def __init__(self, time_steps = 500, beta_start = 0.0001, beta_stop = 0.02, image_size = 64, device = device):