yhavinga commited on
Commit
f5a218d
β€’
1 Parent(s): cda95dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -124,6 +124,9 @@ def main():
124
  length_penalty = st.sidebar.number_input(
125
  "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
126
  )
 
 
 
127
  st.sidebar.markdown(
128
  """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
129
  and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
@@ -132,6 +135,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
132
  params = {
133
  "num_beams": num_beams,
134
  "num_beam_groups": num_beam_groups,
 
135
  "length_penalty": length_penalty,
136
  "early_stopping": True,
137
  }
 
124
  length_penalty = st.sidebar.number_input(
125
  "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
126
  )
127
+ diversity_penalty = st.sidebar.number_input(
128
+ "Diversity penalty", min_value=0.0, max_value=2.0, value=0.0, step=0.1
129
+ )
130
  st.sidebar.markdown(
131
  """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
132
  and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
 
135
  params = {
136
  "num_beams": num_beams,
137
  "num_beam_groups": num_beam_groups,
138
+ "diversity_penalty": diversity_penalty if num_beam_groups > 1 else 0.0,
139
  "length_penalty": length_penalty,
140
  "early_stopping": True,
141
  }