lhoestq HF staff committed on
Commit
028b74d
1 Parent(s): 4f83ec0
Files changed (1) hide show
  1. generate.py +5 -8
generate.py CHANGED
@@ -19,17 +19,14 @@ from utils import StringIteratorIO
19
  logger = logging.getLogger(__name__)
20
 
21
 
22
- if torch.cuda.is_available():
23
- device = "cuda"
24
- elif torch.backends.mps.is_available():
25
- device = "mps"
26
- else:
27
- raise RuntimeError("couldn't find cuda or mps")
28
-
29
  logger.warning("Loading model...")
30
  model_id = "google/gemma-2b-it"
31
  # model_id = "Qwen/Qwen1.5-0.5B-Chat"
32
- model = models.transformers(model_id, device=device)
 
 
 
 
33
  tokenizer = AutoTokenizer.from_pretrained(model_id)
34
  sampler = PenalizedMultinomialSampler()
35
  empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]
 
19
  logger = logging.getLogger(__name__)
20
 
21
 
 
 
 
 
 
 
 
22
  logger.warning("Loading model...")
23
  model_id = "google/gemma-2b-it"
24
  # model_id = "Qwen/Qwen1.5-0.5B-Chat"
25
+ if torch.backends.mps.is_available():
26
+ model = models.transformers(model_id, device="mps")
27
+ else:
28
+ model = models.transformers(model_id, device="cuda")
29
+
30
  tokenizer = AutoTokenizer.from_pretrained(model_id)
31
  sampler = PenalizedMultinomialSampler()
32
  empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id]).strip()]