Yeb Havinga commited on
Commit
58040ab
1 Parent(s): 833d73c

Make generation parameters configurable

Browse files
Files changed (1) hide show
  1. app.py +51 -24
app.py CHANGED
@@ -17,26 +17,6 @@ Met een snelheid van 319 Tb/s zouden 57.000 films per seconde kunnen worden gedo
17
  Voorlopig lijkt de techniek overigens niet meer te zijn dan een experiment, vanwege de hoge kosten die ermee zijn gemoeid. In Nederland zijn providers bezig met het aanleggen van internet met een hoge downloadsnelheid. Deze glasvezelnetwerken bereiken snelheden tot 1 gigabit per seconde.
18
  """
19
 
20
- generator_kwargs_beam = {
21
- "max_length": 142,
22
- "min_length": 75,
23
- "no_repeat_ngram_size": 2,
24
- "early_stopping": True,
25
- "num_beams": 5,
26
- "length_penalty": 1.5,
27
- "num_return_sequences": 1,
28
- }
29
-
30
- generator_kwargs_top_k = {
31
- "max_length": 142,
32
- "min_length": 75,
33
- "no_repeat_ngram_size": 2,
34
- "do_sample": True,
35
- "top_k": 60,
36
- "top_p": 0.95,
37
- "num_return_sequences": 1,
38
- }
39
-
40
 
41
  class TextSummarizer:
42
  def __init__(self):
@@ -103,16 +83,63 @@ def main():
103
  transformer model. Please refer to the [model page](https://huggingface.co/flax-community/t5-base-dutch-demo) for more information.
104
  """
105
  )
106
- st.sidebar.title("Mode:")
 
 
 
 
 
 
 
 
 
 
107
 
108
  if sampling_mode := st.sidebar.selectbox(
109
  "select a Mode", index=0, options=["Beam Search", "Top-k Sampling"]
110
  ):
111
  if sampling_mode == "Beam Search":
112
- params = generator_kwargs_beam
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  else:
114
- params = generator_kwargs_top_k
115
- st.sidebar.json(params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  input_text = st.text_area("Enter a Dutch news text", DEFAULT_TEXT, height=500)
118
 
 
17
  Voorlopig lijkt de techniek overigens niet meer te zijn dan een experiment, vanwege de hoge kosten die ermee zijn gemoeid. In Nederland zijn providers bezig met het aanleggen van internet met een hoge downloadsnelheid. Deze glasvezelnetwerken bereiken snelheden tot 1 gigabit per seconde.
18
  """
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  class TextSummarizer:
22
  def __init__(self):
 
83
  transformer model. Please refer to the [model page](https://huggingface.co/flax-community/t5-base-dutch-demo) for more information.
84
  """
85
  )
86
+ st.sidebar.title("Parameters:")
87
+
88
+ min_length = st.sidebar.number_input(
89
+ "Min length", min_value=10, max_value=150, value=30
90
+ )
91
+ max_length = st.sidebar.number_input(
92
+ "Max length", min_value=50, max_value=250, value=142
93
+ )
94
+ no_repeat_ngram_size = st.sidebar.number_input(
95
+ "No repeat NGram size", min_value=1, max_value=5, value=3
96
+ )
97
 
98
  if sampling_mode := st.sidebar.selectbox(
99
  "select a Mode", index=0, options=["Beam Search", "Top-k Sampling"]
100
  ):
101
  if sampling_mode == "Beam Search":
102
+ num_beams = st.sidebar.number_input(
103
+ "Num beams", min_value=1, max_value=10, value=4
104
+ )
105
+ length_penalty = st.sidebar.number_input(
106
+ "Length penalty", min_value=0.0, max_value=5.0, value=1.5, step=0.1
107
+ )
108
+ params = {
109
+ "min_length": min_length,
110
+ "max_length": max_length,
111
+ "no_repeat_ngram_size": no_repeat_ngram_size,
112
+ "num_beams": num_beams,
113
+ "early_stopping": True,
114
+ "length_penalty": length_penalty,
115
+ "num_return_sequences": 1,
116
+ }
117
  else:
118
+ top_k = st.sidebar.number_input(
119
+ "Top K", min_value=0, max_value=100, value=50
120
+ )
121
+ top_p = st.sidebar.number_input(
122
+ "Top P", min_value=0.0, max_value=1.0, value=0.9, step=0.05
123
+ )
124
+ temperature = st.sidebar.number_input(
125
+ "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.05
126
+ )
127
+ params = {
128
+ "min_length": min_length,
129
+ "max_length": max_length,
130
+ "no_repeat_ngram_size": no_repeat_ngram_size,
131
+ "do_sample": True,
132
+ "top_k": top_k,
133
+ "top_p": top_p,
134
+ "temperature": temperature,
135
+ "num_return_sequences": 1,
136
+ }
137
+
138
+ st.sidebar.markdown(
139
+ """For an explanation of the parameters, please to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
140
+ and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
141
+ """
142
+ )
143
 
144
  input_text = st.text_area("Enter a Dutch news text", DEFAULT_TEXT, height=500)
145