jiwan-chung committed
Commit 5a61cb9
Parent: 61a945a

running on cpu

Files changed (3)
  1. arguments.py +0 -7
  2. load.py +1 -1
  3. run.py +2 -3
arguments.py CHANGED
@@ -37,8 +37,6 @@ def get_args():
         '--infer_no_repeat_size', type=int, default=2, help="no repeat ngram size for inference")
     parser.add_argument(
         '--response-length', type=int, default=20, help='number of tokens to generate for each prompt.')
-    parser.add_argument(
-        '--num-gpus', type=int, default=None, help='number of gpus. use all available if none')
     parser.add_argument(
         '--port', type=int, default=None, help="port for the demo server")
 
@@ -47,11 +45,6 @@ def get_args():
 
     if args.use_label_prefix:
         log.info(f'using label prefix')
-    num_gpus = torch.cuda.device_count()
-    if args.num_gpus is None:
-        args.num_gpus = num_gpus
-    else:
-        args.num_gpus = min(num_gpus, args.num_gpus)
 
     if args.checkpoint is not None:
        args.checkpoint = str(Path(args.checkpoint).resolve())
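For context, the removed clamp silently assumed a GPU machine: torch.cuda.device_count() returns 0 on a CPU-only host, so the old fallback would resolve args.num_gpus to 0 either way. A minimal sketch of that failure mode (the requested value stands in for a hypothetical --num-gpus input):

    import torch

    # On a CPU-only machine torch reports zero CUDA devices.
    num_gpus = torch.cuda.device_count()  # 0 without a visible GPU

    requested = 2                         # hypothetical --num-gpus value
    effective = min(num_gpus, requested)  # the removed clamp: resolves to 0
    print(effective)                      # 0 -- no usable device for code expecting >= 1

Dropping the flag lets device selection rest solely on torch.cuda.is_available(), as the run.py change below does.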
load.py CHANGED
@@ -42,7 +42,7 @@ def load_model(args, device, finetune=False):
         use_transformer_mapper=args.use_transformer_mapper,
         model_weight='None', use_label_prefix=args.use_label_prefix)
     ckpt = args.checkpoint + '.ckpt'
-    state = torch.load(ckpt)
+    state = torch.load(ckpt, map_location=torch.device('cpu'))
     policy_key = 'policy_model'
     if policy_key in state:
         policy.model.load_state_dict(state[policy_key])
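By default torch.load restores tensors onto the devices they were saved from, so a checkpoint serialized on a GPU raises a RuntimeError on a CPU-only machine; map_location remaps those storages at load time. A minimal sketch of the pattern (the path is a placeholder, not the repo's checkpoint):

    import torch

    ckpt = 'model.ckpt'  # placeholder path; assume the file was saved on a CUDA device

    # Remap all storages to CPU so the load succeeds without a GPU.
    state = torch.load(ckpt, map_location=torch.device('cpu'))

    # A plain string works too:
    # state = torch.load(ckpt, map_location='cpu')

Tensors remapped to CPU this way can still be moved back with .to('cuda') on machines that have one.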
run.py CHANGED
@@ -22,16 +22,15 @@ log = logging.getLogger(__name__)
 
 
 def prepare(args):
-    num_gpus = torch.cuda.device_count()
-    log.info(f'Detect {num_gpus} GPUS')
     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    log.info(f'Device: {device}')
     args = load_model_args(args)
 
     def load_style(args, checkpoint):
         model = AutoModelForCausalLM.from_pretrained(args.init_model)
         if checkpoint is not None and Path(checkpoint).is_file():
             log.info("joint model: loading pretrained style generator")
-            state = torch.load(checkpoint)
+            state = torch.load(checkpoint, map_location=torch.device('cpu'))
             if 'global_step' in state:
                 step = state['global_step']
                 log.info(f'trained for {step} steps')
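The prepare change follows the standard device-agnostic PyTorch pattern: resolve the device once from torch.cuda.is_available(), log it, and place modules and inputs on it explicitly. A minimal self-contained sketch of that pattern (the module and input are placeholders, not the repo's models):

    import torch
    import torch.nn as nn

    # Pick GPU 0 when present, otherwise fall back to CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f'Device: {device}')

    model = nn.Linear(4, 2).to(device)    # placeholder module moved to the device
    x = torch.randn(1, 4, device=device)  # placeholder input created on the same device
    y = model(x)                          # runs on GPU when available, on CPU otherwise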