adymaharana commited on
Commit
77e955b
1 Parent(s): da4893d

cuda device

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. dalle/models/__init__.py +2 -2
app.py CHANGED
@@ -66,7 +66,8 @@ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
66
 
67
 
68
  def main(args):
69
- device = 'cuda:0'
 
70
 
71
  model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
72
 
 
66
 
67
 
68
  def main(args):
69
+ #device = 'cuda:0'
70
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
71
 
72
  model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
73
 
dalle/models/__init__.py CHANGED
@@ -1193,9 +1193,9 @@ class StoryDalle(Dalle):
1193
  print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
1194
  # model.from_ckpt(args.model_name_or_path)
1195
  try:
1196
- model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
1197
  except KeyError:
1198
- model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
1199
  else:
1200
  model = cls(config_update)
1201
  print(model.cross_attention_idxs)
 
1193
  print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
1194
  # model.from_ckpt(args.model_name_or_path)
1195
  try:
1196
+ model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['state_dict'])
1197
  except KeyError:
1198
+ model.load_state_dict(torch.load(args.model_name_or_path, map_location=torch.device('cpu'))['model_state_dict'])
1199
  else:
1200
  model = cls(config_update)
1201
  print(model.cross_attention_idxs)