lhoestq committed
Commit 6b58b86
1 Parent(s): e75be2d

smaller batch_size, higher temperature

Files changed (1):
  generate.py (+2 -2)
generate.py CHANGED
@@ -30,14 +30,14 @@ if torch.backends.mps.is_available():
 else:
     device = "cuda"
 model_id = "google/gemma-2b-it"
-batch_size = 10
+batch_size = 4
 
 model = models.transformers(model_id, device=device)
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 sampler = PenalizedMultinomialSampler()
 low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)
-high_temperature_sampler = PenalizedMultinomialSampler(temperature=1.1)
+high_temperature_sampler = PenalizedMultinomialSampler(temperature=1.5)
 empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id], skip_special_tokens=True).strip()]
 sampler.set_max_repeats(empty_tokens, 1)
 disallowed_patterns = [regex.compile(r"\p{Han}")] # focus on english for now
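For context: in multinomial sampling, the temperature divides the logits before the softmax, so raising the high-temperature sampler from 1.1 to 1.5 flattens the token distribution and produces more diverse generations, while the smaller batch_size of 4 trades throughput for lower memory use. Below is a minimal sketch in plain PyTorch (not the repo's PenalizedMultinomialSampler, whose internals are not shown in this diff) illustrating the temperature effect:

import torch

def sample_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Scale logits by 1/temperature, then draw one token id per row.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

logits = torch.tensor([[2.0, 1.0, 0.1]])
for t in (0.3, 1.1, 1.5):
    probs = torch.softmax(logits / t, dim=-1)
    print(t, probs)  # higher temperature -> flatter, more diverse distribution

Running the loop shows the probability mass spreading out as t grows, which is why 1.5 samples a wider range of tokens than the previous 1.1.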