Балаганский Никита Николаевич committed on
Commit f57bdfa
1 Parent(s): 895c44e
Files changed (2)
  1. app.py +11 -7
  2. sampling.py +17 -7
app.py CHANGED
@@ -24,14 +24,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 ATTRIBUTE_MODELS = {
     "Russian": (
         "cointegrated/rubert-tiny-toxicity",
-        'tinkoff-ai/response-quality-classifier-tiny',
-        'tinkoff-ai/response-quality-classifier-base',
-        'tinkoff-ai/response-quality-classifier-large',
-        "SkolkovoInstitute/roberta_toxicity_classifier",
         "SkolkovoInstitute/russian_toxicity_classifier"
     ),
     "English": (
         "unitary/toxic-bert",
+        "distilbert-base-uncased-finetuned-sst-2-english"
     )
 }

@@ -72,7 +69,7 @@ WARNING_TEXT = {
     "English": """
     **Warning!**

-    If you are clicking checkbox bellow positive""" + r"$\alpha$" + """ values for CAIF sampling become available.
+    If you are clicking checkbox bellow positive """ + r"$\alpha$" + """ values for CAIF sampling become available.
     It means that language model will be forced to produce toxic or/and abusive text.
     This space is only a demonstration of our method for controllable text generation
     and we are not responsible for the content produced by this method.
@@ -128,11 +125,17 @@ def main():
         label2id = cls_model_config.label2id
         label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
         target_label_id = label2id[label_key]
-    else:
+        act_type = "sigmoid"
+    elif cls_model_config.problem_type == "single_label_classification":
         label2id = cls_model_config.label2id
-        print(list(label2id.keys()))
         label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
         target_label_id = 1
+        act_type = "sigmoid"
+    else:
+        label2id = cls_model_config.label2id
+        label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
+        target_label_id = label2id[label_key]
+        act_type = "softmax"
     st.write(WARNING_TEXT[language])
     show_pos_alpha = st.checkbox("Show positive alphas", value=False)
     prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
@@ -168,6 +171,7 @@ def main():
         target_label_id=target_label_id,
         entropy_threshold=entropy_threshold,
         fp16=fp16,
+        act_type=act_type
     )
     st.subheader("Generated text:")
     st.write(text)
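In short, app.py now offers distilbert-base-uncased-finetuned-sst-2-english as an additional English attribute model, drops the tinkoff-ai and roberta toxicity classifiers from the Russian list, and derives an act_type flag from the classifier config: the first two branches (the second checks problem_type == "single_label_classification") set "sigmoid", the final else falls back to "softmax", and the flag is forwarded to the sampler as act_type=act_type. The sketch below mirrors that selection outside Streamlit; it is not part of the commit, the helper name pick_act_type is illustrative, and it assumes the condition of the first (unshown) branch tests for multi_label_classification.

# Illustrative sketch, not part of the commit: derive an act_type from a classifier
# config's problem_type, mirroring the branching added to app.py. The first condition
# is an assumption, since it sits above the hunk context shown in the diff.
from transformers import AutoConfig

def pick_act_type(model_name: str) -> str:
    cfg = AutoConfig.from_pretrained(model_name)
    if cfg.problem_type in ("multi_label_classification", "single_label_classification"):
        # Per-label heads: score the target label independently with a sigmoid.
        return "sigmoid"
    # Mutually exclusive labels (e.g. SST-2 sentiment): normalise with a softmax.
    return "softmax"

# Many configs leave problem_type unset (None), so the newly added SST-2 checkpoint
# would fall through to "softmax" under this assumption.
print(pick_act_type("distilbert-base-uncased-finetuned-sst-2-english"))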
sampling.py CHANGED
@@ -53,6 +53,7 @@ class CAIFSampler:
         **kwargs
     ):
         target_cls_id = kwargs["target_cls_id"]
+        act_type = kwargs["act_type"]
         next_token_logits = output_logis[:, -1]
         next_token_log_probs = F.log_softmax(
             next_token_logits, dim=-1
@@ -83,6 +84,7 @@ class CAIFSampler:
         top_k_classifier,
         classifier_weight,
         target_cls_id: int = 0,
+        act_type: str = "sigmoid",
         caif_tokens_num=None
     ):

@@ -107,12 +109,15 @@ class CAIFSampler:
         if self.invert_cls_probs:
             classifier_log_probs = torch.log(
                 1 - self.get_classifier_probs(
-                    classifier_input, caif_tokens_num=caif_tokens_num
+                    classifier_input, caif_tokens_num=caif_tokens_num, target_cls_id=target_cls_id
                 ).view(-1, top_k_classifier)
             )
         else:
             classifier_log_probs = self.get_classifier_log_probs(
-                classifier_input, caif_tokens_num=caif_tokens_num, target_cls_id=target_cls_id,
+                classifier_input,
+                caif_tokens_num=caif_tokens_num,
+                target_cls_id=target_cls_id,
+                act_type=act_type,
             ).view(-1, top_k_classifier)

         next_token_probs = torch.exp(
@@ -121,7 +126,7 @@ class CAIFSampler:
         )
         return next_token_probs, top_next_token_log_probs[1]

-    def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0):
+    def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0, act_type: str = "sigmoid"):
         input_ids = self.classifier_tokenizer(
             input, padding=True, return_tensors="pt"
         ).to(self.device)
@@ -131,10 +136,15 @@ class CAIFSampler:
             input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
         if "token_type_ids" in input_ids.keys():
             input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
-        logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
-        return torch.log(torch.sigmoid(logits))

-    def get_classifier_probs(self, input, caif_tokens_num=None):
+        if act_type == "sigmoid":
+            logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
+            return F.logsigmoid(logits)
+        if act_type == "softmax":
+            logits = F.log_softmax(self.classifier_model(**input_ids).logits)[:, target_cls_id].squeeze(-1)
+            return logits
+
+    def get_classifier_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0):
         input_ids = self.classifier_tokenizer(
             input, padding=True, return_tensors="pt"
         ).to(self.device)
@@ -142,5 +152,5 @@ class CAIFSampler:
         input_ids["input_ids"] = input_ids["input_ids"][-caif_tokens_num:]
         if "attention_mask" in input_ids.keys():
             input_ids["attention_mask"] = input_ids["attention_mask"][-caif_tokens_num:]
-        logits = self.classifier_model(**input_ids).logits[:, 0].squeeze(-1)
+        logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
         return torch.sigmoid(logits)
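On the sampling side, act_type is read from kwargs, passed down to get_classifier_log_probs, and decides how classifier logits become log-probabilities: F.logsigmoid of the target column for "sigmoid", or the target column of F.log_softmax over all labels for "softmax". In both cases the result is scaled by classifier_weight and added to the language model's next-token log-probabilities before exponentiating, as in the surrounding CAIF code. The toy snippet below is not from the repository (shapes and numbers are made up, and dim=-1 is written out explicitly); it only contrasts the two activations.

# Toy comparison (not from the repository) of the two act_type scoring modes
# used by get_classifier_log_probs after this commit.
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0], [0.5, 0.3]])  # (batch, num_labels) classifier logits
target_cls_id = 0

# "sigmoid": treat the target label as an independent score.
sigmoid_log_probs = F.logsigmoid(logits[:, target_cls_id])

# "softmax": normalise across labels first, then take the target column.
softmax_log_probs = F.log_softmax(logits, dim=-1)[:, target_cls_id]

print(sigmoid_log_probs)  # log sigmoid(logits[:, target]) per example
print(softmax_log_probs)  # log softmax(logits)[:, target] per example

Sigmoid scoring fits multi-label toxicity heads whose classes are not mutually exclusive, while softmax scoring fits single-head classifiers such as the SST-2 sentiment model added in app.py, which is presumably why "sigmoid" remains the default.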