ksuzuki01 commited on
Commit
8bbc9d2
1 Parent(s): 1a6b2d0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +12 -4
README.md CHANGED
@@ -65,7 +65,17 @@ def preprocess(text):
65
 
66
  def postprocess(text):
67
  return text.replace("<LB>", "\n")
68
-
 
 
 
 
 
 
 
 
 
 
69
  input_text += "<SEP>"
70
  input_text = preprocess(input_text)
71
 
@@ -74,9 +84,7 @@ with torch.no_grad():
74
 
75
  output_ids = model.generate(
76
  token_ids.to(model.device),
77
- max_new_tokens=256,
78
- pad_token_id=tokenizer.pad_token_id,
79
- eos_token_id=tokenizer.eos_token_id,
80
  )
81
  output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=True)
82
  output = postprocess(output)
 
65
 
66
  def postprocess(text):
67
  return text.replace("<LB>", "\n")
68
+
69
+ generation_config = {
70
+ "max_new_tokens": 256,
71
+ "num_beams": 3,
72
+ "num_return_sequences": 1,
73
+ "early_stopping": True,
74
+ "eos_token_id": tokenizer.eos_token_id,
75
+ "pad_token_id": tokenizer.pad_token_id,
76
+ "repetition_penalty": 3.0
77
+ }
78
+
79
  input_text += "<SEP>"
80
  input_text = preprocess(input_text)
81
 
 
84
 
85
  output_ids = model.generate(
86
  token_ids.to(model.device),
87
+ **generation_config
 
 
88
  )
89
  output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1) :], skip_special_tokens=True)
90
  output = postprocess(output)