AhmedSSabir committed on
Commit
5ea9bf1
1 Parent(s): 0b41e50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -10
app.py CHANGED
@@ -20,8 +20,8 @@ from sentence_transformers import SentenceTransformer, util
20
 
21
  #model_sts = gr.Interface.load('huggingface/sentence-transformers/stsb-distilbert-base')
22
 
23
- #model_sts = SentenceTransformer('stsb-distilbert-base')
24
- model_sts = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
25
  #batch_size = 1
26
  #scorer = LMScorer.from_pretrained('gpt2' , device=device, batch_size=batch_size)
27
 
@@ -72,11 +72,7 @@ def cloze_prob(text):
72
  text_list = text.split()
73
  stem = ' '.join(text_list[:-1])
74
  stem_encoding = tokenizer.encode(stem)
75
- # cw_encoding is just the difference between whole_text_encoding and stem_encoding
76
- # note: this might not correspond exactly to the word itself
77
  cw_encoding = whole_text_encoding[len(stem_encoding):]
78
- # Run the entire sentence through the model. Then go "back in time" to look at what the model predicted for each token, starting at the stem.
79
- # Put the whole text encoding into a tensor, and get the model's comprehensive output
80
  tokens_tensor = torch.tensor([whole_text_encoding])
81
 
82
  with torch.no_grad():
@@ -93,10 +89,7 @@ def cloze_prob(text):
93
 
94
  logprobs.append(np.log(softmax(raw_output)))
95
 
96
- # if the critical word is three tokens long, the raw_probabilities should look something like this:
97
- # [ [0.412, 0.001, ... ] ,[0.213, 0.004, ...], [0.002,0.001, 0.93 ...]]
98
- # Then for the i'th token we want to find its associated probability
99
- # this is just: raw_probabilities[i][token_index]
100
  conditional_probs = []
101
  for cw,prob in zip(cw_encoding,logprobs):
102
  conditional_probs.append(prob[cw])
 
20
 
21
  #model_sts = gr.Interface.load('huggingface/sentence-transformers/stsb-distilbert-base')
22
 
23
+ model_sts = SentenceTransformer('stsb-distilbert-base')
24
+ #model_sts = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
25
  #batch_size = 1
26
  #scorer = LMScorer.from_pretrained('gpt2' , device=device, batch_size=batch_size)
27
 
 
72
  text_list = text.split()
73
  stem = ' '.join(text_list[:-1])
74
  stem_encoding = tokenizer.encode(stem)
 
 
75
  cw_encoding = whole_text_encoding[len(stem_encoding):]
 
 
76
  tokens_tensor = torch.tensor([whole_text_encoding])
77
 
78
  with torch.no_grad():
 
89
 
90
  logprobs.append(np.log(softmax(raw_output)))
91
 
92
+
 
 
 
93
  conditional_probs = []
94
  for cw,prob in zip(cw_encoding,logprobs):
95
  conditional_probs.append(prob[cw])