cahya committed on
Commit
f39f10a
1 Parent(s): b5a7127

add repetition penalty

Browse files
Files changed (1) hide show
  1. app/app.py +22 -4
app/app.py CHANGED
@@ -110,12 +110,16 @@ def get_generator(model_name: str):
110
  # Disable the st.cache for this function due to issue on newer version of streamlit
111
  # @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
112
  def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
113
- temperature: float = 1.0, max_time: float = 120.0, seed=42):
114
  # st.write("Cache miss: process")
115
  set_seed(seed)
 
 
 
 
116
  result = text_generator(text, max_length=max_length, do_sample=do_sample,
117
  top_k=top_k, top_p=top_p, temperature=temperature,
118
- max_time=max_time)
119
  return result
120
 
121
 
@@ -164,7 +168,7 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
164
  "Temperature",
165
  value=0.9,
166
  min_value=0.0,
167
- max_value=5.0
168
  )
169
 
170
  do_sample = st.sidebar.checkbox(
@@ -194,6 +198,20 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
194
  help="The number used to initialize a pseudorandom number generator"
195
  )
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  for group_name in MODELS:
198
  if MODELS[group_name]["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
199
  MODELS[group_name]["text_generator"] = get_generator(MODELS[group_name]["name"])
@@ -206,7 +224,7 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
206
  # text_generator = MODELS[model]["text_generator"]
207
  result = process(MODELS[model]["text_generator"], text=session_state.text, max_length=int(max_length),
208
  temperature=temperature, do_sample=do_sample,
209
- top_k=int(top_k), top_p=float(top_p), seed=seed)
210
  time_end = time.time()
211
  time_diff = time_end-time_start
212
  result = result[0]["generated_text"]
 
110
  # Disable the st.cache for this function due to issue on newer version of streamlit
111
  # @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
112
  def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
113
+ temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0):
114
  # st.write("Cache miss: process")
115
  set_seed(seed)
116
+ if repetition_penalty == 0.0:
117
+ min_penalty = 1.05
118
+ max_penalty = 1.5
119
+ repetition_penalty = max(min_penalty + (1.0-temperature) * (max_penalty-min_penalty), 0.8)
120
  result = text_generator(text, max_length=max_length, do_sample=do_sample,
121
  top_k=top_k, top_p=top_p, temperature=temperature,
122
+ max_time=max_time, repetition_penalty=repetition_penalty)
123
  return result
124
 
125
 
 
168
  "Temperature",
169
  value=0.9,
170
  min_value=0.0,
171
+ max_value=2.0
172
  )
173
 
174
  do_sample = st.sidebar.checkbox(
 
198
  help="The number used to initialize a pseudorandom number generator"
199
  )
200
 
201
+ repetition_penalty = 0.0
202
+ automatic_repetition_penalty = st.sidebar.checkbox(
203
+ "Automatic Repetition Penalty",
204
+ value=True
205
+ )
206
+
207
+ if not automatic_repetition_penalty:
208
+ repetition_penalty = st.sidebar.slider(
209
+ "Repetition Penalty",
210
+ value=1.0,
211
+ min_value=1.0,
212
+ max_value=2.0
213
+ )
214
+
215
  for group_name in MODELS:
216
  if MODELS[group_name]["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
217
  MODELS[group_name]["text_generator"] = get_generator(MODELS[group_name]["name"])
 
224
  # text_generator = MODELS[model]["text_generator"]
225
  result = process(MODELS[model]["text_generator"], text=session_state.text, max_length=int(max_length),
226
  temperature=temperature, do_sample=do_sample,
227
+ top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
228
  time_end = time.time()
229
  time_diff = time_end-time_start
230
  result = result[0]["generated_text"]