Балаганский Никита Николаевич committed on
Commit
d320fdd
1 Parent(s): 6331a08

add target_label_id

Browse files
Files changed (2) hide show
  1. app.py +8 -2
  2. sampling.py +8 -5
app.py CHANGED
@@ -25,6 +25,10 @@ def main():
25
  'Выберите языковую модель',
26
  ('sberbank-ai/rugpt3small_based_on_gpt2',)
27
  )
 
 
 
 
28
  prompt = st.text_input("Начало текста:", "Привет")
29
  alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1)
30
  auth_token = os.environ.get('TOKEN') or True
@@ -50,7 +54,9 @@ def load_sampler(cls_model_name, lm_tokenizer):
50
 
51
 
52
  @st.cache
53
- def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True, alpha: float = 5) -> str:
 
 
54
  generator = load_generator(lm_model_name=lm_model_name)
55
  lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
56
  caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
@@ -61,6 +67,7 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
61
  "temperature": 1.0,
62
  "top_k_classifier": 100,
63
  "classifier_weight": alpha,
 
64
  }
65
  generator.set_ordinary_sampler(ordinary_sampler)
66
  if device == "cpu":
@@ -74,7 +81,6 @@ def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool =
74
  input_prompt=prompt,
75
  max_length=20,
76
  caif_period=1,
77
- caif_tokens_num=100,
78
  entropy=None,
79
  **kwargs
80
  )
 
25
  'Выберите языковую модель',
26
  ('sberbank-ai/rugpt3small_based_on_gpt2',)
27
  )
28
+ cls_model_config = transformers.AutoConfig.from_pretrained(cls_model_name)
29
+ label2id = cls_model_config.label2id
30
+ label_key = st.selectbox("Веберите нужный атрибут текста", label2id.keys())
31
+ target_label_id = label2id[label_key]
32
  prompt = st.text_input("Начало текста:", "Привет")
33
  alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1)
34
  auth_token = os.environ.get('TOKEN') or True
 
54
 
55
 
56
  @st.cache
57
+ def inference(
58
+ lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True, alpha: float = 5, target_label_id: int = 0
59
+ ) -> str:
60
  generator = load_generator(lm_model_name=lm_model_name)
61
  lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
62
  caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
 
67
  "temperature": 1.0,
68
  "top_k_classifier": 100,
69
  "classifier_weight": alpha,
70
+ "target_cls_id": target_label_id
71
  }
72
  generator.set_ordinary_sampler(ordinary_sampler)
73
  if device == "cpu":
 
81
  input_prompt=prompt,
82
  max_length=20,
83
  caif_period=1,
 
84
  entropy=None,
85
  **kwargs
86
  )
sampling.py CHANGED
@@ -49,10 +49,11 @@ class CAIFSampler:
49
  top_k_classifier,
50
  classifier_weight,
51
  caif_tokens_num=None,
 
52
  **kwargs
53
  ):
 
54
  next_token_logits = output_logis[:, -1]
55
-
56
  next_token_log_probs = F.log_softmax(
57
  next_token_logits, dim=-1
58
  )
@@ -63,7 +64,8 @@ class CAIFSampler:
63
  temperature,
64
  top_k_classifier,
65
  classifier_weight,
66
- caif_tokens_num=caif_tokens_num
 
67
  )
68
  topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
69
  next_tokens = sample_from_values(
@@ -80,6 +82,7 @@ class CAIFSampler:
80
  temperature,
81
  top_k_classifier,
82
  classifier_weight,
 
83
  caif_tokens_num=None
84
  ):
85
 
@@ -109,7 +112,7 @@ class CAIFSampler:
109
  )
110
  else:
111
  classifier_log_probs = self.get_classifier_log_probs(
112
- classifier_input, caif_tokens_num=caif_tokens_num
113
  ).view(-1, top_k_classifier)
114
 
115
  next_token_probs = torch.exp(
@@ -118,7 +121,7 @@ class CAIFSampler:
118
  )
119
  return next_token_probs, top_next_token_log_probs[1]
120
 
121
- def get_classifier_log_probs(self, input, caif_tokens_num=None):
122
  input_ids = self.classifier_tokenizer(
123
  input, padding=True, return_tensors="pt"
124
  ).to(self.device)
@@ -128,7 +131,7 @@ class CAIFSampler:
128
  input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
129
  if "token_type_ids" in input_ids.keys():
130
  input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
131
- logits = self.classifier_model(**input_ids).logits[:, 0].squeeze(-1)
132
  return torch.log(torch.sigmoid(logits))
133
 
134
  def get_classifier_probs(self, input, caif_tokens_num=None):
 
49
  top_k_classifier,
50
  classifier_weight,
51
  caif_tokens_num=None,
52
+ act_type: str = "softmax",
53
  **kwargs
54
  ):
55
+ target_cls_id = kwargs["target_cls_id"]
56
  next_token_logits = output_logis[:, -1]
 
57
  next_token_log_probs = F.log_softmax(
58
  next_token_logits, dim=-1
59
  )
 
64
  temperature,
65
  top_k_classifier,
66
  classifier_weight,
67
+ caif_tokens_num=caif_tokens_num,
68
+ target_cls_id=target_cls_id
69
  )
70
  topk_probs = next_token_unnormalized_probs.topk(top_k, -1)
71
  next_tokens = sample_from_values(
 
82
  temperature,
83
  top_k_classifier,
84
  classifier_weight,
85
+ target_cls_id: int = 0,
86
  caif_tokens_num=None
87
  ):
88
 
 
112
  )
113
  else:
114
  classifier_log_probs = self.get_classifier_log_probs(
115
+ classifier_input, caif_tokens_num=caif_tokens_num, target_cls_id=target_cls_id,
116
  ).view(-1, top_k_classifier)
117
 
118
  next_token_probs = torch.exp(
 
121
  )
122
  return next_token_probs, top_next_token_log_probs[1]
123
 
124
+ def get_classifier_log_probs(self, input, caif_tokens_num=None, target_cls_id: int = 0):
125
  input_ids = self.classifier_tokenizer(
126
  input, padding=True, return_tensors="pt"
127
  ).to(self.device)
 
131
  input_ids["attention_mask"] = input_ids["attention_mask"][:, -caif_tokens_num:]
132
  if "token_type_ids" in input_ids.keys():
133
  input_ids["token_type_ids"] = input_ids["token_type_ids"][:, -caif_tokens_num:]
134
+ logits = self.classifier_model(**input_ids).logits[:, target_cls_id].squeeze(-1)
135
  return torch.log(torch.sigmoid(logits))
136
 
137
  def get_classifier_probs(self, input, caif_tokens_num=None):