Балаганский Никита Николаевич
committed on
Commit
•
47f9ff2
1
Parent(s):
d320fdd
add entropy_threshold
Browse files
app.py
CHANGED
@@ -19,18 +19,24 @@ def main():
|
|
19 |
cls_model_name = st.selectbox(
|
20 |
'Выберите модель классификации',
|
21 |
('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base',
|
22 |
-
'tinkoff-ai/response-quality-classifier-large')
|
23 |
)
|
24 |
lm_model_name = st.selectbox(
|
25 |
'Выберите языковую модель',
|
26 |
('sberbank-ai/rugpt3small_based_on_gpt2',)
|
27 |
)
|
28 |
cls_model_config = transformers.AutoConfig.from_pretrained(cls_model_name)
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
32 |
prompt = st.text_input("Начало текста:", "Привет")
|
33 |
-
alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1)
|
|
|
34 |
auth_token = os.environ.get('TOKEN') or True
|
35 |
with st.spinner('Running inference...'):
|
36 |
text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt, alpha=alpha)
|
@@ -55,11 +61,22 @@ def load_sampler(cls_model_name, lm_tokenizer):
|
|
55 |
|
56 |
@st.cache
|
57 |
def inference(
|
58 |
-
lm_model_name: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
) -> str:
|
60 |
generator = load_generator(lm_model_name=lm_model_name)
|
61 |
lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
63 |
generator.set_caif_sampler(caif_sampler)
|
64 |
ordinary_sampler = TopKWithTemperatureSampler()
|
65 |
kwargs = {
|
@@ -81,7 +98,7 @@ def inference(
|
|
81 |
input_prompt=prompt,
|
82 |
max_length=20,
|
83 |
caif_period=1,
|
84 |
-
entropy=
|
85 |
**kwargs
|
86 |
)
|
87 |
print(f"Output for prompt: {sequences}")
|
|
|
19 |
cls_model_name = st.selectbox(
|
20 |
'Выберите модель классификации',
|
21 |
('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base',
|
22 |
+
'tinkoff-ai/response-quality-classifier-large', "SkolkovoInstitute/roberta_toxicity_classifier")
|
23 |
)
|
24 |
lm_model_name = st.selectbox(
|
25 |
'Выберите языковую модель',
|
26 |
('sberbank-ai/rugpt3small_based_on_gpt2',)
|
27 |
)
|
28 |
cls_model_config = transformers.AutoConfig.from_pretrained(cls_model_name)
|
29 |
+
if cls_model_config.problem_type == "multi_label_classification":
|
30 |
+
label2id = cls_model_config.label2id
|
31 |
+
label_key = st.selectbox("Веберите нужный атрибут текста", label2id.keys())
|
32 |
+
target_label_id = label2id[label_key]
|
33 |
+
else:
|
34 |
+
label2id = cls_model_config.label2id
|
35 |
+
label_key = st.selectbox("Веберите нужный атрибут текста", list(label2id.keys())[-1])
|
36 |
+
target_label_id = 0
|
37 |
prompt = st.text_input("Начало текста:", "Привет")
|
38 |
+
alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1, value=0)
|
39 |
+
entropy_threshold = st.slider("Entropy Threshold:", min_value=0., max_value=5., step=.1, value=0.)
|
40 |
auth_token = os.environ.get('TOKEN') or True
|
41 |
with st.spinner('Running inference...'):
|
42 |
text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt, alpha=alpha)
|
|
|
61 |
|
62 |
@st.cache
|
63 |
def inference(
|
64 |
+
lm_model_name: str,
|
65 |
+
cls_model_name: str,
|
66 |
+
prompt: str,
|
67 |
+
fp16: bool = True,
|
68 |
+
alpha: float = 5,
|
69 |
+
target_label_id: int = 0,
|
70 |
+
entropy_threshold: float = 0
|
71 |
) -> str:
|
72 |
generator = load_generator(lm_model_name=lm_model_name)
|
73 |
lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
|
74 |
+
if alpha != 0:
|
75 |
+
caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
|
76 |
+
else:
|
77 |
+
caif_sampler = None
|
78 |
+
if entropy_threshold < 0.05:
|
79 |
+
entropy_threshold = None
|
80 |
generator.set_caif_sampler(caif_sampler)
|
81 |
ordinary_sampler = TopKWithTemperatureSampler()
|
82 |
kwargs = {
|
|
|
98 |
input_prompt=prompt,
|
99 |
max_length=20,
|
100 |
caif_period=1,
|
101 |
+
entropy=entropy_threshold,
|
102 |
**kwargs
|
103 |
)
|
104 |
print(f"Output for prompt: {sequences}")
|