akiFQC committed on
Commit
34aa338
1 Parent(s): f9018ff
Files changed (1)
  1. app.py +4 -3
app.py CHANGED
@@ -9,12 +9,13 @@ model = GPT2LMHeadModel.from_pretrained(model_name)
 
 
 class DialogGPT:
-    def __init__(self, tokenizer, model, n_candidate=4, param_lambda=0.1):
+    def __init__(self, tokenizer, model, n_candidate=4, param_lambda=0.10):
         self.tokenizer = tokenizer
         self.model = model
         self.model.eval()
         self.n_candidate = n_candidate
         self.param_lambda = param_lambda
+        self.param_gamma: int = 2
 
     def _calc_single_scores(self, token_ids):
         with torch.inference_mode():
@@ -33,7 +34,7 @@ class DialogGPT:
             # log_likelihood (b, l)
             log_likelihood = logit_at_target
             log_likelihood.masked_fill_(mask_at_pad, 0.0)
-            log_likelihood_per_candidate = log_likelihood.sum(dim=1)
+            log_likelihood_per_candidate = log_likelihood[:, self.param_gamma:].sum(dim=1)
             # normalize by length
             # log_likelihood_per_candidate = log_likelihood_per_candidate / (candidate_token_ids.shape[1] - mask_at_pad.sum(dim=1))
             return log_likelihood_per_candidate
@@ -85,7 +86,7 @@ class DialogGPT:
             max_time=10,
             num_return_sequences=self.n_candidate,
             max_length=512,
-            min_length=2,
+            min_length=4,
             forced_eos_token_id=self.tokenizer.pad_token_id,
             return_dict_in_generate=True,
             output_scores=True,
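
What the scoring change does, as a minimal standalone sketch (not part of the commit; the tensor values below are made up for illustration): with param_gamma = 2, the per-candidate score now sums per-token log-likelihoods only from index 2 onward, so the first two tokens of each candidate no longer count toward reranking.

import torch

param_gamma = 2  # value set in __init__ by this commit

# Hypothetical per-token log-likelihoods for 2 candidates of length 5,
# with pad positions already zeroed, as _calc_single_scores does via masked_fill_.
log_likelihood = torch.tensor([
    [-0.5, -1.0, -0.2, -0.3,  0.0],  # trailing 0.0 = masked pad
    [-0.1, -0.4, -0.6, -0.2, -0.9],
])

old_score = log_likelihood.sum(dim=1)                   # before this commit: all tokens
new_score = log_likelihood[:, param_gamma:].sum(dim=1)  # after this commit: skip first 2

print(old_score)  # tensor([-2.0000, -2.2000])
print(new_score)  # tensor([-0.5000, -1.7000])

The min_length bump from 2 to 4 plausibly pairs with this: it keeps every generated candidate long enough to have tokens left inside the scored slice after the first param_gamma = 2 are excluded.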