lhoestq HF staff committed on
Commit
3642076
1 Parent(s): 6b58b86

default temperature

Browse files
Files changed (1) hide show
  1. generate.py +1 -2
generate.py CHANGED
@@ -37,7 +37,6 @@ model = models.transformers(model_id, device=device)
37
  tokenizer = AutoTokenizer.from_pretrained(model_id)
38
  sampler = PenalizedMultinomialSampler()
39
  low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)
40
- high_temperature_sampler = PenalizedMultinomialSampler(temperature=1.5)
41
  empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id], skip_special_tokens=True).strip()]
42
  sampler.set_max_repeats(empty_tokens, 1)
43
  disallowed_patterns = [regex.compile(r"\p{Han}")] # focus on english for now
@@ -61,7 +60,7 @@ class Dataset(BaseModel):
61
  data: conlist(Sample, min_length=2, max_length=3) # type: ignore
62
 
63
 
64
- samples_generator_template = generate.json(model, Dataset, sampler=high_temperature_sampler)
65
 
66
  class Columns(BaseModel):
67
  columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore
 
37
  tokenizer = AutoTokenizer.from_pretrained(model_id)
38
  sampler = PenalizedMultinomialSampler()
39
  low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)
 
40
  empty_tokens = [token_id for token_id in range(tokenizer.vocab_size) if not tokenizer.decode([token_id], skip_special_tokens=True).strip()]
41
  sampler.set_max_repeats(empty_tokens, 1)
42
  disallowed_patterns = [regex.compile(r"\p{Han}")] # focus on english for now
 
60
  data: conlist(Sample, min_length=2, max_length=3) # type: ignore
61
 
62
 
63
+ samples_generator_template = generate.json(model, Dataset, sampler=sampler)
64
 
65
  class Columns(BaseModel):
66
  columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1) # type: ignore