AhmedSSabir committed on
Commit
78464e7
1 Parent(s): c505acd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -25
app.py CHANGED
@@ -59,35 +59,63 @@ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
59
  #tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
60
 
61
 
62
- def cloze_prob(text):
63
-
64
- whole_text_encoding = tokenizer.encode(text)
65
- text_list = text.split()
66
- stem = ' '.join(text_list[:-1])
67
- stem_encoding = tokenizer.encode(stem)
68
- cw_encoding = whole_text_encoding[len(stem_encoding):]
69
- tokens_tensor = torch.tensor([whole_text_encoding])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- with torch.no_grad():
72
- outputs = model(tokens_tensor)
73
- predictions = outputs[0]
74
-
75
- logprobs = []
76
- start = -1-len(cw_encoding)
77
- for j in range(start,-1,1):
78
- raw_output = []
79
- for i in predictions[-1][j]:
80
- raw_output.append(i.item())
81
 
82
- logprobs.append(np.log(softmax(raw_output)))
83
 
84
 
85
- conditional_probs = []
86
- for cw,prob in zip(cw_encoding,logprobs):
87
- conditional_probs.append(prob[cw])
88
 
89
 
90
- return np.exp(np.sum(conditional_probs))
91
 
92
 
93
 
@@ -117,8 +145,14 @@ def Visual_re_ranker(caption_man, caption_woman, context_label, context_prob):
117
  sim_w = get_sim(sim_w)
118
 
119
 
120
- LM_man = cloze_prob(caption_man)
121
- LM_woman = cloze_prob(caption_woman)
 
 
 
 
 
 
122
  #LM = scorer.sentence_score(caption, reduce="mean")
123
  score_man = pow(float(LM_man),pow((1-float(sim_m))/(1+ float(sim_m)),1-float(context_prob)))
124
  score_woman = pow(float(LM_woman),pow((1-float(sim_w))/(1+ float(sim_w)),1-float(context_prob)))
 
59
  #tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
60
 
61
 
62
+
63
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
64
+ model = GPT2LMHeadModel.from_pretrained('gpt2')
65
+
66
def sentence_prob_mean(text):
    """Return the mean per-token probability of *text* under the language model.

    The text is encoded with the module-level ``tokenizer`` and scored with
    the module-level ``model``. Each token from the second one onward is
    assigned the probability the model gives it based on the tokens before
    it; the first token is not scored because the logits/labels shift below
    leaves it without a prediction.

    Args:
        text: The sentence to score.

    Returns:
        The arithmetic mean of the scored token probabilities as a Python
        float. When *text* encodes to a single token there is nothing to
        score and the mean of the empty tensor (NaN) is returned.
    """
    # Encode to a tensor of token ids with a leading batch dimension.
    input_ids = tokenizer.encode(text, return_tensors='pt')

    # Inference only: no gradients are needed. The loss is never used, so
    # no labels are passed to the model.
    with torch.no_grad():
        logits = model(input_ids).logits

    # Align predictions with targets: the logits at position i score the
    # token at position i + 1, so drop the last logit and the first label.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous()

    # Use torch's softmax explicitly so the ``dim`` keyword is always valid,
    # whatever ``softmax`` helper the rest of the file has in scope.
    probs = torch.softmax(shift_logits, dim=-1)

    # Pick out the probability assigned to each token that actually occurred.
    gathered_probs = torch.gather(
        probs, 2, shift_labels.unsqueeze(-1)
    ).squeeze(-1)

    # Return the result: callers pass it straight to float().
    return torch.mean(gathered_probs).item()
87
+
88
+
89
+
90
+ # def cloze_prob(text):
91
+
92
+ # whole_text_encoding = tokenizer.encode(text)
93
+ # text_list = text.split()
94
+ # stem = ' '.join(text_list[:-1])
95
+ # stem_encoding = tokenizer.encode(stem)
96
+ # cw_encoding = whole_text_encoding[len(stem_encoding):]
97
+ # tokens_tensor = torch.tensor([whole_text_encoding])
98
 
99
+ # with torch.no_grad():
100
+ # outputs = model(tokens_tensor)
101
+ # predictions = outputs[0]
102
+
103
+ # logprobs = []
104
+ # start = -1-len(cw_encoding)
105
+ # for j in range(start,-1,1):
106
+ # raw_output = []
107
+ # for i in predictions[-1][j]:
108
+ # raw_output.append(i.item())
109
 
110
+ # logprobs.append(np.log(softmax(raw_output)))
111
 
112
 
113
+ # conditional_probs = []
114
+ # for cw,prob in zip(cw_encoding,logprobs):
115
+ # conditional_probs.append(prob[cw])
116
 
117
 
118
+ # return np.exp(np.sum(conditional_probs))
119
 
120
 
121
 
 
145
  sim_w = get_sim(sim_w)
146
 
147
 
148
+ LM_man = sentence_prob_mean(caption_man)
149
+ LM_woman = sentence_prob_mean(caption_woman)
150
+
151
+ # LM_man = cloze_prob(caption_man)
152
+ # LM_woman = cloze_prob(caption_woman)
153
+
154
+ )
155
+
156
  #LM = scorer.sentence_score(caption, reduce="mean")
157
  score_man = pow(float(LM_man),pow((1-float(sim_m))/(1+ float(sim_m)),1-float(context_prob)))
158
  score_woman = pow(float(LM_woman),pow((1-float(sim_w))/(1+ float(sim_w)),1-float(context_prob)))