Балаганский Никита Николаевич
commited on
Commit
•
b5beaeb
1
Parent(s):
3fc6ce8
remove russian language
Browse files
app.py
CHANGED
@@ -123,7 +123,7 @@ def main():
|
|
123 |
"template": "plotly_white",
|
124 |
})
|
125 |
|
126 |
-
language =
|
127 |
cls_model_name = st.selectbox(
|
128 |
ATTRIBUTE_MODEL_LABEL[language],
|
129 |
ATTRIBUTE_MODELS[language]
|
@@ -136,15 +136,7 @@ def main():
|
|
136 |
cls_model_config = AutoConfig.from_pretrained(cls_model_name)
|
137 |
if cls_model_config.problem_type == "multi_label_classification":
|
138 |
label2id = cls_model_config.label2id
|
139 |
-
|
140 |
-
idx = 0
|
141 |
-
for i, k in enumerate(label2id.keys()):
|
142 |
-
if k == 'threat':
|
143 |
-
idx = i
|
144 |
-
|
145 |
-
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys(), index=idx)
|
146 |
-
else:
|
147 |
-
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
|
148 |
target_label_id = label2id[label_key]
|
149 |
act_type = "sigmoid"
|
150 |
elif cls_model_config.problem_type == "single_label_classification":
|
@@ -154,20 +146,17 @@ def main():
|
|
154 |
act_type = "sigmoid"
|
155 |
else:
|
156 |
label2id = cls_model_config.label2id
|
|
|
157 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
|
158 |
target_label_id = label2id[label_key]
|
159 |
act_type = "softmax"
|
160 |
st.write(WARNING_TEXT[language])
|
161 |
show_pos_alpha = st.checkbox("Show positive alphas", value=False)
|
162 |
-
if "sst" in cls_model_name:
|
163 |
-
prompt = st.text_input(TEXT_PROMPT_LABEL[language], "The movie")
|
164 |
-
else:
|
165 |
-
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
166 |
st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
|
167 |
if act_type == "softmax":
|
168 |
alpha = st.slider("α", min_value=-40, max_value=40 if show_pos_alpha else 0, step=1, value=0)
|
169 |
else:
|
170 |
-
alpha = st.slider("α", min_value=-
|
171 |
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=10., step=.1, value=2.)
|
172 |
plot_idx = np.argmin(np.abs(entropy_threshold - x_s))
|
173 |
scatter_tip = go.Scatter({
|
@@ -190,6 +179,10 @@ def main():
|
|
190 |
auth_token = os.environ.get('TOKEN') or True
|
191 |
fp16 = st.checkbox("FP16", value=True)
|
192 |
st.session_state["generated_text"] = None
|
|
|
|
|
|
|
|
|
193 |
st.subheader("Generated text:")
|
194 |
|
195 |
def generate():
|
|
|
123 |
"template": "plotly_white",
|
124 |
})
|
125 |
|
126 |
+
language = "English"
|
127 |
cls_model_name = st.selectbox(
|
128 |
ATTRIBUTE_MODEL_LABEL[language],
|
129 |
ATTRIBUTE_MODELS[language]
|
|
|
136 |
cls_model_config = AutoConfig.from_pretrained(cls_model_name)
|
137 |
if cls_model_config.problem_type == "multi_label_classification":
|
138 |
label2id = cls_model_config.label2id
|
139 |
+
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
target_label_id = label2id[label_key]
|
141 |
act_type = "sigmoid"
|
142 |
elif cls_model_config.problem_type == "single_label_classification":
|
|
|
146 |
act_type = "sigmoid"
|
147 |
else:
|
148 |
label2id = cls_model_config.label2id
|
149 |
+
filtered_label2id = {k: v for k, v in label2id.items() if "negative" in k}
|
150 |
label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
|
151 |
target_label_id = label2id[label_key]
|
152 |
act_type = "softmax"
|
153 |
st.write(WARNING_TEXT[language])
|
154 |
show_pos_alpha = st.checkbox("Show positive alphas", value=False)
|
|
|
|
|
|
|
|
|
155 |
st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
|
156 |
if act_type == "softmax":
|
157 |
alpha = st.slider("α", min_value=-40, max_value=40 if show_pos_alpha else 0, step=1, value=0)
|
158 |
else:
|
159 |
+
alpha = st.slider("α", min_value=-5, max_value=5 if show_pos_alpha else 0, step=1, value=0)
|
160 |
entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=10., step=.1, value=2.)
|
161 |
plot_idx = np.argmin(np.abs(entropy_threshold - x_s))
|
162 |
scatter_tip = go.Scatter({
|
|
|
179 |
auth_token = os.environ.get('TOKEN') or True
|
180 |
fp16 = st.checkbox("FP16", value=True)
|
181 |
st.session_state["generated_text"] = None
|
182 |
+
if "sst" in cls_model_name:
|
183 |
+
prompt = st.text_input(TEXT_PROMPT_LABEL[language], "The movie")
|
184 |
+
else:
|
185 |
+
prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
|
186 |
st.subheader("Generated text:")
|
187 |
|
188 |
def generate():
|