Rodrigo_Cobo commited on
Commit
272c5b4
1 Parent(s): 83c98ec

add the option to work in CPU

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. WiggleGAN.py +6 -2
.gitignore CHANGED
@@ -5,4 +5,5 @@ Lib/*
5
  logs/*
6
  WiggleGAN_mod.py
7
  WiggleGAN_noCycle.py
8
- pyvenv.cfg
 
5
  logs/*
6
  WiggleGAN_mod.py
7
  WiggleGAN_noCycle.py
8
+ pyvenv.cfg
9
+ py
WiggleGAN.py CHANGED
@@ -783,9 +783,13 @@ class WiggleGAN(object):
783
  def load(self):
784
  save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
785
 
786
- self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_G.pkl')))
 
 
 
 
787
  if not self.wiggle:
788
- self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_D.pkl')))
789
 
790
  def wiggleEf(self):
791
  seed, epoch = self.seed_load.split('_')
783
  def load(self):
784
  save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
785
 
786
+ map_loc=None
787
+ if not torch.cuda.is_available():
788
+ map_loc='cpu'
789
+
790
+ self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_G.pkl'), map_location=map_loc))
791
  if not self.wiggle:
792
+ self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_D.pkl'), map_location=map_loc))
793
 
794
  def wiggleEf(self):
795
  seed, epoch = self.seed_load.split('_')