Балаганский Никита Николаевич
commited on
Commit
•
01f8fc1
1
Parent(s):
b5beaeb
add num_tokens arg
Browse files
app.py
CHANGED
@@ -146,9 +146,9 @@ def main():
|
|
146 |
act_type = "sigmoid"
|
147 |
else:
|
148 |
label2id = cls_model_config.label2id
|
149 |
-
filtered_label2id = {k: v
|
150 |
-
label_key = st.selectbox(ATTRIBUTE_LABEL[language],
|
151 |
-
target_label_id =
|
152 |
act_type = "softmax"
|
153 |
st.write(WARNING_TEXT[language])
|
154 |
show_pos_alpha = st.checkbox("Show positive alphas", value=False)
|
@@ -183,6 +183,8 @@ def main():
|
|
183 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], "The movie")
|
184 |
else:
|
185 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
|
|
|
|
186 |
st.subheader("Generated text:")
|
187 |
|
188 |
def generate():
|
@@ -194,7 +196,8 @@ def main():
|
|
194 |
target_label_id=target_label_id,
|
195 |
entropy_threshold=entropy_threshold,
|
196 |
fp16=fp16,
|
197 |
-
act_type=act_type
|
|
|
198 |
)
|
199 |
|
200 |
st.button("Generate new", on_click=generate())
|
@@ -225,7 +228,8 @@ def inference(
|
|
225 |
alpha: float = 5,
|
226 |
target_label_id: int = 0,
|
227 |
entropy_threshold: float = 0,
|
228 |
-
act_type: str = "sigmoid"
|
|
|
229 |
) -> str:
|
230 |
torch.set_grad_enabled(False)
|
231 |
generator = load_generator(lm_model_name=lm_model_name)
|
@@ -259,7 +263,7 @@ def inference(
|
|
259 |
sequences, tokens = generator.sample_sequences(
|
260 |
num_samples=1,
|
261 |
input_prompt=prompt,
|
262 |
-
max_length=
|
263 |
caif_period=1,
|
264 |
entropy=entropy_threshold,
|
265 |
progress_bar=progress_bar,
|
|
|
146 |
act_type = "sigmoid"
|
147 |
else:
|
148 |
label2id = cls_model_config.label2id
|
149 |
+
filtered_label2id = {k: v for k, v in label2id.items() if "negative" in k.lower()}
|
150 |
+
label_key = st.selectbox(ATTRIBUTE_LABEL[language], filtered_label2id.keys())
|
151 |
+
target_label_id = filtered_label2id[label_key]
|
152 |
act_type = "softmax"
|
153 |
st.write(WARNING_TEXT[language])
|
154 |
show_pos_alpha = st.checkbox("Show positive alphas", value=False)
|
|
|
183 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], "The movie")
|
184 |
else:
|
185 |
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
186 |
+
num_tokens = st.slider("# tokens to be generated", min_value=5, max_value=40, step=1, value=20)
|
187 |
+
num_tokens = int(num_tokens)
|
188 |
st.subheader("Generated text:")
|
189 |
|
190 |
def generate():
|
|
|
196 |
target_label_id=target_label_id,
|
197 |
entropy_threshold=entropy_threshold,
|
198 |
fp16=fp16,
|
199 |
+
act_type=act_type,
|
200 |
+
num_tokens=num_tokens
|
201 |
)
|
202 |
|
203 |
st.button("Generate new", on_click=generate())
|
|
|
228 |
alpha: float = 5,
|
229 |
target_label_id: int = 0,
|
230 |
entropy_threshold: float = 0,
|
231 |
+
act_type: str = "sigmoid",
|
232 |
+
num_tokens=10,
|
233 |
) -> str:
|
234 |
torch.set_grad_enabled(False)
|
235 |
generator = load_generator(lm_model_name=lm_model_name)
|
|
|
263 |
sequences, tokens = generator.sample_sequences(
|
264 |
num_samples=1,
|
265 |
input_prompt=prompt,
|
266 |
+
max_length=num_tokens,
|
267 |
caif_period=1,
|
268 |
entropy=entropy_threshold,
|
269 |
progress_bar=progress_bar,
|