lhoestq HF staff commited on
Commit
a563465
1 Parent(s): 7375eb9
Files changed (1) hide show
  1. generate.py +7 -5
generate.py CHANGED
@@ -20,12 +20,14 @@ logger = logging.getLogger(__name__)
20
 
21
 
22
  logger.warning("Loading model...")
23
- model_id = "google/gemma-2b-it"
24
- # model_id = "Qwen/Qwen1.5-0.5B-Chat"
25
  if torch.backends.mps.is_available():
26
- model = models.transformers(model_id, device="mps")
 
27
  else:
28
- model = models.transformers(model_id, device="cuda")
 
29
 
30
  tokenizer = AutoTokenizer.from_pretrained(model_id)
31
  sampler = PenalizedMultinomialSampler()
@@ -69,7 +71,7 @@ def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
69
  fsm=fsm,
70
  model=samples_generator_template.model,
71
  sampler=samples_generator_template.sampler,
72
- device=samples_generator_template.device
73
  )
74
 
75
 
 
20
 
21
 
22
  logger.warning("Loading model...")
23
+ # model_id = "google/gemma-2b-it"
24
+ model_id = "Qwen/Qwen1.5-0.5B-Chat"
25
  if torch.backends.mps.is_available():
26
+ device = "mps"
27
+ model = models.transformers(model_id, device=device)
28
  else:
29
+ device = "cuda"
30
+ model = models.transformers(model_id, device=device)
31
 
32
  tokenizer = AutoTokenizer.from_pretrained(model_id)
33
  sampler = PenalizedMultinomialSampler()
 
71
  fsm=fsm,
72
  model=samples_generator_template.model,
73
  sampler=samples_generator_template.sampler,
74
+ device=device
75
  )
76
 
77