Балаганский Никита Николаевич commited on
Commit
c3ce809
1 Parent(s): 53cf441

add languages

Browse files
Files changed (1) hide show
  1. app.py +56 -15
app.py CHANGED
@@ -15,35 +15,76 @@ from generator import Generator
15
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def main():
20
  st.header("CAIF")
 
21
  cls_model_name = st.selectbox(
22
- 'Выберите модель классификации',
23
- (
24
- 'tinkoff-ai/response-quality-classifier-tiny',
25
- 'tinkoff-ai/response-quality-classifier-base',
26
- 'tinkoff-ai/response-quality-classifier-large',
27
- "SkolkovoInstitute/roberta_toxicity_classifier"
28
- )
29
  )
30
  lm_model_name = st.selectbox(
31
- 'Выберите языковую модель',
32
- ('sberbank-ai/rugpt3small_based_on_gpt2',)
33
  )
34
  cls_model_config = AutoConfig.from_pretrained(cls_model_name)
35
  if cls_model_config.problem_type == "multi_label_classification":
36
  label2id = cls_model_config.label2id
37
- label_key = st.selectbox("Веберите нужный атрибут текста", label2id.keys())
38
  target_label_id = label2id[label_key]
39
  else:
40
  label2id = cls_model_config.label2id
41
  print(list(label2id.keys()))
42
- label_key = st.selectbox("Веберите нужный атрибут текста", [list(label2id.keys())[-1]])
43
- target_label_id = 1
44
- prompt = st.text_input("Начало текста:", "Привет")
45
- alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1, value=0)
46
- entropy_threshold = st.slider("Entropy Threshold:", min_value=0., max_value=5., step=.1, value=0.)
47
  auth_token = os.environ.get('TOKEN') or True
48
  with st.spinner('Running inference...'):
49
  text = inference(
 
15
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
+ ATTRIBUTE_MODELS = {
19
+ "Russian": (
20
+ "cointegrated/rubert-tiny-toxicity",
21
+ 'tinkoff-ai/response-quality-classifier-tiny',
22
+ 'tinkoff-ai/response-quality-classifier-base',
23
+ 'tinkoff-ai/response-quality-classifier-large',
24
+ "SkolkovoInstitute/roberta_toxicity_classifier",
25
+ "SkolkovoInstitute/russian_toxicity_classifier"
26
+ ),
27
+ "English": (
28
+ "unitary/toxic-bert",
29
+ )
30
+ }
31
+
32
+ LANGUAGE_MODELS = {
33
+ "Russian": ('sberbank-ai/rugpt3small_based_on_gpt2',),
34
+ "Eanglish": ("distilgpt2")
35
+ }
36
+
37
+ ATTRIBUTE_MODEL_LABEL = {
38
+ "Russian": 'Выберите модель классификации',
39
+ "English": "Choose attribute model"
40
+ }
41
+
42
+ LM_LABEL = {
43
+ "English": "Choose language model",
44
+ "Russian": "Выберите языковую модель"
45
+ }
46
+
47
+ ATTRIBUTE_LABEL = {
48
+ "Russian": "Веберите нужный атрибут текста",
49
+ "English": "Choose desired attribute",
50
+ }
51
+
52
+ TEXT_PROMPT_LABEL = {
53
+ "English": "Text prompt",
54
+ "Russian": "Начало текста"
55
+ }
56
+
57
+ PROMPT_EXAMPLE = {
58
+ "English": "Hello, today I",
59
+ "Russian": "Привет, сегодня я"
60
+ }
61
+
62
 
63
  def main():
64
  st.header("CAIF")
65
+ language = st.selectbox("Language", ("English", "Russian"))
66
  cls_model_name = st.selectbox(
67
+ ATTRIBUTE_MODEL_LABEL[language],
68
+ ATTRIBUTE_MODELS[language]
69
+
 
 
 
 
70
  )
71
  lm_model_name = st.selectbox(
72
+ LM_LABEL[language],
73
+ LANGUAGE_MODELS[language]
74
  )
75
  cls_model_config = AutoConfig.from_pretrained(cls_model_name)
76
  if cls_model_config.problem_type == "multi_label_classification":
77
  label2id = cls_model_config.label2id
78
+ label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
79
  target_label_id = label2id[label_key]
80
  else:
81
  label2id = cls_model_config.label2id
82
  print(list(label2id.keys()))
83
+ label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
84
+ target_label_id = 0
85
+ prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
86
+ alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
87
+ entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=5., step=.1, value=0.)
88
  auth_token = os.environ.get('TOKEN') or True
89
  with st.spinner('Running inference...'):
90
  text = inference(