Update modeling_prot2text.py
Browse files- modeling_prot2text.py +2 -18
modeling_prot2text.py
CHANGED
@@ -323,8 +323,8 @@ class Prot2TextModel(PreTrainedModel):
|
|
323 |
tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
|
324 |
encoder_outputs=encoder_state,
|
325 |
use_cache=True,
|
326 |
-
output_attentions=
|
327 |
-
output_scores=
|
328 |
return_dict_in_generate=True,
|
329 |
encoder_attention_mask=inputs['attention_mask'],
|
330 |
length_penalty=1.0,
|
@@ -333,22 +333,6 @@ class Prot2TextModel(PreTrainedModel):
|
|
333 |
num_beams=1)
|
334 |
|
335 |
generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
|
336 |
-
print(tok_ids.get('scores')[0].size())
|
337 |
-
m = torch.nn.Softmax()
|
338 |
-
att_w = []
|
339 |
-
print(len(gpdb.sequence[0]))
|
340 |
-
score = 0
|
341 |
-
for i in range(len(tok_ids.get('cross_attentions'))):
|
342 |
-
att_w.append(torch.mul(tok_ids.get('cross_attentions')[i][-1].squeeze().mean(dim=0), inputs['attention_mask'][-1].squeeze())[:len(gpdb.sequence[0])].tolist())
|
343 |
-
score += np.log(torch.max(m(tok_ids.get('scores')[i]).squeeze()).item())
|
344 |
-
score = score / len(tok_ids.get('cross_attentions'))
|
345 |
-
# print(str(score))
|
346 |
-
|
347 |
-
# import seaborn as sns
|
348 |
-
# import matplotlib.pylab as plt
|
349 |
-
# plt.figure().set_figwidth(150)
|
350 |
-
# ax = sns.heatmap(att_w, cmap="YlGnBu", robust=True, xticklabels=gpdb.sequence[0])#, yticklabels=generated[0])
|
351 |
-
# plt.savefig("seaborn_plot.png")
|
352 |
|
353 |
os.remove(structure_filename)
|
354 |
os.remove(graph_filename)
|
|
|
323 |
tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
|
324 |
encoder_outputs=encoder_state,
|
325 |
use_cache=True,
|
326 |
+
output_attentions=False,
|
327 |
+
output_scores=False,
|
328 |
return_dict_in_generate=True,
|
329 |
encoder_attention_mask=inputs['attention_mask'],
|
330 |
length_penalty=1.0,
|
|
|
333 |
num_beams=1)
|
334 |
|
335 |
generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
|
337 |
os.remove(structure_filename)
|
338 |
os.remove(graph_filename)
|