ksuzuki01 committed
Commit 0c074a8 · verified · 1 parent: 0ff4d0c

Update app.py

Files changed (1)
  1. app.py  +11 -3
app.py CHANGED
@@ -17,6 +17,16 @@ def preprocess(text):
 def postprocess(text):
     return text.replace("<LB>", "\n")
 
+generation_config = {
+    "max_new_tokens": 256,
+    "num_beams": 3,
+    "num_return_sequences": 1,
+    "early_stopping": True,
+    "eos_token_id": tokenizer.eos_token_id,
+    "pad_token_id": tokenizer.pad_token_id,
+    "repetition_penalty": 3.0
+}
+
 def generate(input_text):
     input_text += tokenizer.eos_token
     input_text = preprocess(input_text)
@@ -26,9 +36,7 @@ def generate(input_text):
 
     output_ids = model.generate(
         token_ids.to(model.device),
-        max_new_tokens=256,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id,
+        **generation_config
     )
     output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=True)
     return postprocess(output)
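
For reference, the same settings can also be expressed with transformers' GenerationConfig class instead of a plain dict, which groups the decoding fields in one typed object. A minimal sketch of the equivalent call (an alternative to what this commit does, assuming the tokenizer, model, and token_ids names used in the code above):

from transformers import GenerationConfig

# Same decoding settings as the generation_config dict added in this commit:
# beam search with 3 beams, one returned sequence, early stopping, and a
# strong repetition penalty.
gen_cfg = GenerationConfig(
    max_new_tokens=256,
    num_beams=3,
    num_return_sequences=1,
    early_stopping=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    repetition_penalty=3.0,
)

output_ids = model.generate(
    token_ids.to(model.device),
    generation_config=gen_cfg,
)

Either form drives the same beam-search decoding; the dict-unpacking version in the commit is simply the lighter-weight option.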