|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
import transformers as T |
|
import argparse |
|
import tqdm |
|
from utils import * |
|
|
|
def apc(x):
    """Apply the average product correction (APC) over the last two axes of ``x``.

    Every entry has ``row_sum * col_sum / total_sum`` subtracted from it,
    where the sums run over the last axis, the second-to-last axis and both
    of them respectively. Used for contact prediction.

    The division happens in place on an intermediate tensor; ``x`` itself is
    left untouched and a new tensor is returned.
    """
    last_axis_sum = x.sum(-1, keepdims=True)
    prev_axis_sum = x.sum(-2, keepdims=True)
    both_axes_sum = x.sum((-1, -2), keepdims=True)

    # Product of the two marginal sums, scaled by the grand total.
    expected = (last_axis_sum * prev_axis_sum).div_(both_axes_sum)
    return x - expected
|
|
|
def attention_extractor(model, data_loader, device, n_sample,
                        attn_shape=(12, 12, 512)):
    """Collect APC-corrected attention summed over each sample's flagged positions.

    For every batch the model is run with ``output_attentions=True``; the
    returned attention tensors are stacked along axis 1, cast to float16 and
    passed through ``apc``. For each sample, the positions where the model's
    first output is ``>= 1`` are looked up and the attention slices at those
    positions are summed. Samples without any such position keep an all-zero
    entry.

    Args:
        model: Callable as ``model(x, output_attentions=True)`` returning an
            ``(output, attentions)`` pair, where ``attentions`` is a sequence
            of tensors that ``torch.stack`` can combine.
        data_loader: Iterable of input batches that support ``.to(device)``.
        device: Device the batches are moved to before the forward pass.
        n_sample: Total number of samples the loader is expected to yield.
        attn_shape: Per-sample shape of the stored result. The default
            ``(12, 12, 512)`` presumably stands for (layers, heads, sequence
            length) -- TODO confirm against the model configuration.

    Returns:
        A float16 NumPy array of shape ``(n_sample, *attn_shape)``.

    Raises:
        AssertionError: If the number of processed samples is not ``n_sample``.
    """
    model.eval()

    # Allocate directly in the final dtype: the values written below are
    # already float16, so a float64 staging buffer would only cost 4x the
    # memory plus a full copy at the end.
    all_attentions = np.zeros((n_sample, *attn_shape), dtype=np.float16)
    ns = 0  # row of `all_attentions` belonging to the current sample

    with torch.no_grad():
        # Passing the loader itself lets tqdm read its length for the bar.
        for x in tqdm.tqdm(data_loader):
            x = x.to(device)
            output, attn = model(x, output_attentions=True)
            attn = apc(torch.stack(attn, dim=1).to(torch.float16))
            if attn.dim() < 4:
                attn = attn.unsqueeze(0)

            for i, sample_attn in enumerate(attn):
                ind = torch.where(output[i] >= 1)[0].tolist()
                if ind:
                    # The position list is the only advanced index here, so
                    # the selected positions stay on axis 2 and are summed
                    # out. This covers one or many positions alike and never
                    # squeezes away an unrelated size-1 axis.
                    picked = sample_attn[:, :, ind, :]
                    all_attentions[ns] = picked.sum(dim=2).cpu().numpy()
                # Advance even when nothing was flagged so rows stay aligned
                # with the dataset order.
                ns += 1

    assert ns == n_sample, f"processed {ns} samples, expected {n_sample}"
    return all_attentions
|
|
|
def main(args):
    """Load the model, extract attention for the input file and save the result.

    Args:
        args: Parsed command-line namespace providing ``model_path``,
            ``file_path`` and ``save_file``.
    """
    # Fall back to the CPU so the script does not crash on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    fast_tokenizer = T.BertTokenizer.from_pretrained("./model/")
    model = load_model(fast_tokenizer, args.model_path)
    print(model)
    model.to(device)
    model.eval()

    val_dataset = SequenceDataset(args.file_path, fast_tokenizer)
    # No shuffling and no dropped batch: the extractor writes one row per
    # sample in iteration order and checks the count against the dataset size.
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=False,
        drop_last=False,
        num_workers=2,
    )

    val_attn = attention_extractor(model, val_loader, device, len(val_dataset))
    np.savez_compressed(args.save_file, attn=val_attn)
|
|
|
|
|
|
|
# Command-line interface. main() reads all three values, so each one is
# required here rather than failing later on a None path.
parser = argparse.ArgumentParser(
    description="Extract attention maps and save them as a compressed NumPy archive."
)
parser.add_argument(
    "--file-path",
    type=str,
    required=True,
    help="input data file handed to SequenceDataset",
)
parser.add_argument(
    "--model-path",
    type=str,
    required=True,
    help="path the model is loaded from",
)
parser.add_argument(
    "--save-file",
    type=str,
    required=True,
    help="output file written with numpy.savez_compressed",
)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
|
|