Балаганский Никита Николаевич committed on
Commit
47f9ff2
1 Parent(s): d320fdd

add entropy_threshold

Files changed (1):
  app.py +25 -8
app.py CHANGED
@@ -19,18 +19,24 @@ def main():
     cls_model_name = st.selectbox(
         'Выберите модель классификации',
         ('tinkoff-ai/response-quality-classifier-tiny', 'tinkoff-ai/response-quality-classifier-base',
-         'tinkoff-ai/response-quality-classifier-large')
+         'tinkoff-ai/response-quality-classifier-large', "SkolkovoInstitute/roberta_toxicity_classifier")
     )
     lm_model_name = st.selectbox(
         'Выберите языковую модель',
         ('sberbank-ai/rugpt3small_based_on_gpt2',)
     )
     cls_model_config = transformers.AutoConfig.from_pretrained(cls_model_name)
-    label2id = cls_model_config.label2id
-    label_key = st.selectbox("Веберите нужный атрибут текста", label2id.keys())
-    target_label_id = label2id[label_key]
+    if cls_model_config.problem_type == "multi_label_classification":
+        label2id = cls_model_config.label2id
+        label_key = st.selectbox("Веберите нужный атрибут текста", label2id.keys())
+        target_label_id = label2id[label_key]
+    else:
+        label2id = cls_model_config.label2id
+        label_key = st.selectbox("Веберите нужный атрибут текста", list(label2id.keys())[-1])
+        target_label_id = 0
     prompt = st.text_input("Начало текста:", "Привет")
-    alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1)
+    alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1, value=0)
+    entropy_threshold = st.slider("Entropy Threshold:", min_value=0., max_value=5., step=.1, value=0.)
     auth_token = os.environ.get('TOKEN') or True
     with st.spinner('Running inference...'):
         text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt, alpha=alpha)
@@ -55,11 +61,22 @@ def load_sampler(cls_model_name, lm_tokenizer):
 
 @st.cache
 def inference(
-    lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True, alpha: float = 5, target_label_id: int = 0
+    lm_model_name: str,
+    cls_model_name: str,
+    prompt: str,
+    fp16: bool = True,
+    alpha: float = 5,
+    target_label_id: int = 0,
+    entropy_threshold: float = 0
 ) -> str:
     generator = load_generator(lm_model_name=lm_model_name)
     lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
-    caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
+    if alpha != 0:
+        caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
+    else:
+        caif_sampler = None
+    if entropy_threshold < 0.05:
+        entropy_threshold = None
     generator.set_caif_sampler(caif_sampler)
     ordinary_sampler = TopKWithTemperatureSampler()
     kwargs = {
@@ -81,7 +98,7 @@ def inference(
         input_prompt=prompt,
         max_length=20,
         caif_period=1,
-        entropy=None,
+        entropy=entropy_threshold,
         **kwargs
     )
     print(f"Output for prompt: {sequences}")