Spaces:
Build error
Build error
add repetition penalty
Browse files- app/app.py +22 -4
app/app.py
CHANGED
@@ -110,12 +110,16 @@ def get_generator(model_name: str):
|
|
110 |
# Disable the st.cache for this function due to issue on newer version of streamlit
|
111 |
# @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
|
112 |
def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
|
113 |
-
temperature: float = 1.0, max_time: float = 120.0, seed=42):
|
114 |
# st.write("Cache miss: process")
|
115 |
set_seed(seed)
|
|
|
|
|
|
|
|
|
116 |
result = text_generator(text, max_length=max_length, do_sample=do_sample,
|
117 |
top_k=top_k, top_p=top_p, temperature=temperature,
|
118 |
-
max_time=max_time)
|
119 |
return result
|
120 |
|
121 |
|
@@ -164,7 +168,7 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
|
|
164 |
"Temperature",
|
165 |
value=0.9,
|
166 |
min_value=0.0,
|
167 |
-
max_value=
|
168 |
)
|
169 |
|
170 |
do_sample = st.sidebar.checkbox(
|
@@ -194,6 +198,20 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
|
|
194 |
help="The number used to initialize a pseudorandom number generator"
|
195 |
)
|
196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
for group_name in MODELS:
|
198 |
if MODELS[group_name]["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
|
199 |
MODELS[group_name]["text_generator"] = get_generator(MODELS[group_name]["name"])
|
@@ -206,7 +224,7 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
|
|
206 |
# text_generator = MODELS[model]["text_generator"]
|
207 |
result = process(MODELS[model]["text_generator"], text=session_state.text, max_length=int(max_length),
|
208 |
temperature=temperature, do_sample=do_sample,
|
209 |
-
top_k=int(top_k), top_p=float(top_p), seed=seed)
|
210 |
time_end = time.time()
|
211 |
time_diff = time_end-time_start
|
212 |
result = result[0]["generated_text"]
|
|
|
# Disable the st.cache for this function due to issue on newer version of streamlit
# @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty: float = 1.0):
    """Run the text-generation pipeline on *text* and return its raw result list.

    Parameters
    ----------
    text_generator : a HF `pipeline("text-generation")`-style callable.
    text : the prompt to continue.
    max_length, do_sample, top_k, top_p, temperature, max_time :
        forwarded unchanged to the generator.
    seed : passed to `set_seed` so sampling is reproducible.
    repetition_penalty : 1.0 disables the penalty; the sentinel 0.0 asks this
        function to derive a penalty automatically from the temperature
        (lower temperature -> stronger penalty).

    Returns the generator's result (a list of dicts with "generated_text").
    """
    # st.write("Cache miss: process")
    set_seed(seed)
    if repetition_penalty == 0.0:
        # Automatic mode: interpolate linearly so that temperature 0.0 maps to
        # max_penalty and temperature 1.0 maps to min_penalty.
        min_penalty = 1.05
        max_penalty = 1.5
        # NOTE(review): the original floored this at 0.8, but a repetition
        # penalty below 1.0 *encourages* repetition (HF generate semantics),
        # which the Temperature slider's 2.0 maximum could trigger. Floor at
        # min_penalty so the automatic value never rewards repetition.
        repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), min_penalty)
    result = text_generator(text, max_length=max_length, do_sample=do_sample,
                            top_k=top_k, top_p=top_p, temperature=temperature,
                            max_time=max_time, repetition_penalty=repetition_penalty)
    return result
124 |
|
125 |
|
|
|
168 |
"Temperature",
|
169 |
value=0.9,
|
170 |
min_value=0.0,
|
171 |
+
max_value=2.0
|
172 |
)
|
173 |
|
174 |
do_sample = st.sidebar.checkbox(
|
|
|
198 |
help="The number used to initialize a pseudorandom number generator"
|
199 |
)
|
200 |
|
# Repetition-penalty controls: by default the penalty is chosen automatically
# (sentinel 0.0 tells process() to derive it from the temperature); unticking
# the checkbox exposes a manual slider instead.
automatic_repetition_penalty = st.sidebar.checkbox(
    "Automatic Repetition Penalty",
    value=True
)

if automatic_repetition_penalty:
    repetition_penalty = 0.0
else:
    repetition_penalty = st.sidebar.slider(
        "Repetition Penalty",
        value=1.0,
        min_value=1.0,
        max_value=2.0
    )
215 |
for group_name in MODELS:
|
216 |
if MODELS[group_name]["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
|
217 |
MODELS[group_name]["text_generator"] = get_generator(MODELS[group_name]["name"])
|
|
|
# text_generator = MODELS[model]["text_generator"]
# Generate a continuation with the selected model and measure wall-clock time.
selected_generator = MODELS[model]["text_generator"]
result = process(
    selected_generator,
    text=session_state.text,
    max_length=int(max_length),
    temperature=temperature,
    do_sample=do_sample,
    top_k=int(top_k),
    top_p=float(top_p),
    seed=seed,
    repetition_penalty=repetition_penalty,
)
time_end = time.time()
time_diff = time_end - time_start
# The pipeline returns a list of candidates; keep the first one's text.
result = result[0]["generated_text"]