# Reformer / attention_extractor.py
# Author: XLS — "Upload 3 files" (commit b830975, verified)
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import transformers as T
import argparse
import tqdm
from utils import *
def apc(x):
    """Return ``x`` with the average product correction (APC) removed.

    APC is used for contact prediction.  Working on the last two axes
    of ``x``, it subtracts the outer product of the row sums and the
    column sums divided by the grand total:

        x - (rowsum * colsum) / total

    Any leading axes are treated as batch axes.  ``x`` itself is not
    modified; a new tensor is returned.
    """
    row_totals = x.sum(-1, keepdims=True)
    col_totals = x.sum(-2, keepdims=True)
    grand_total = x.sum((-1, -2), keepdims=True)
    expected = row_totals * col_totals
    # Divide in place on the freshly created product to avoid one more
    # full-size temporary.
    expected.div_(grand_total)
    return x - expected
def attention_extractor(model, data_loader, device, n_sample,
                        n_layers=12, n_heads=12, seq_len=512):
    """Collect APC-corrected attention rows at the model's peak positions.

    For each batch from ``data_loader``, the model is called as
    ``model(x, output_attentions=True)`` and must return
    ``(output, attentions)``.  The per-layer attention tensors are
    stacked along dim 1, cast to float16 and passed through ``apc``.
    For each sample, the positions where ``output[i] >= 1`` ("peak
    positions") select slices along the second-to-last attention axis,
    and those slices are summed.  This gives one
    ``(n_layers, n_heads, seq_len)`` array per sample.  Samples with no
    peak position are left as zeros.

    NOTE(review): assumes each per-layer attention tensor is
    (batch, heads, seq, seq) and that ``output[i]`` is indexed by
    sequence position along its first axis — confirm for this model.

    Args:
        model: network returning ``(output, attentions)``.
        data_loader: iterable of input tensors (batches), in dataset order.
        device: device the inputs are moved to before the forward pass.
        n_sample: total number of samples the loader will yield.
        n_layers: number of attention layers (default 12).
        n_heads: number of attention heads per layer (default 12).
        seq_len: attention key length (default 512).

    Returns:
        float16 ndarray of shape ``(n_sample, n_layers, n_heads, seq_len)``.

    Raises:
        ValueError: if the loader yields a number of samples other than
            ``n_sample``.
    """
    model.eval()
    # Allocate float16 up front.  Every value written below is already
    # float16, so this matches the former float64-then-cast result
    # exactly while using a quarter of the memory and no final copy.
    all_attentions = np.zeros((n_sample, n_layers, n_heads, seq_len),
                              dtype=np.float16)
    ns = 0
    with torch.no_grad():
        for x in tqdm.tqdm(data_loader):
            x = x.to(device)
            output, attn = model(x, output_attentions=True)
            # Tuple of per-layer maps -> one tensor with the layer axis at dim 1.
            attn = apc(torch.stack(attn, dim=1).to(torch.float16))
            if attn.dim() < 4:
                attn = attn.unsqueeze(0)
            for i, sample_attn in enumerate(attn):
                if ns >= n_sample:
                    raise ValueError(
                        f"data_loader yielded more than n_sample={n_sample} samples"
                    )
                # Peak positions: indices where the model output is >= 1.
                peaks = torch.where(output[i] >= 1)[0]
                if peaks.numel() > 0:
                    # Index the per-sample tensor, so no integer and index
                    # tensor are mixed in one subscript (NumPy and PyTorch
                    # order the resulting axes differently in that case).
                    # Summing over the peak axis also handles the
                    # single-peak case without squeeze(), which would drop
                    # any size-1 layer/head axis as well.
                    peak_sum = sample_attn[:, :, peaks, :].sum(dim=2)
                    all_attentions[ns] = peak_sum.cpu().numpy()
                ns += 1
    if ns != n_sample:
        raise ValueError(
            f"data_loader yielded {ns} samples, expected n_sample={n_sample}"
        )
    return all_attentions
def main(args):
    """Extract APC-corrected attention for every sample in ``args.file_path``.

    Loads the tokenizer from ``./model/`` and the model via ``load_model``.
    It then runs ``attention_extractor`` over a ``SequenceDataset`` built
    from ``args.file_path``.  The resulting array is written to
    ``args.save_file`` with ``np.savez_compressed`` under the key ``attn``.

    Args:
        args: parsed CLI namespace with ``file_path``, ``model_path`` and
            ``save_file`` attributes.
    """
    # Fall back to CPU so the script still runs on hosts without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # BertTokenizer is the pure-Python ("slow") tokenizer class.
    tokenizer = T.BertTokenizer.from_pretrained("./model/")
    model = load_model(tokenizer, args.model_path)
    print(model)
    model.to(device)
    model.eval()
    val_dataset = SequenceDataset(args.file_path, tokenizer)
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=False,     # keep dataset order: output row k is sample k
        drop_last=False,   # every sample must be seen; the count is checked
        num_workers=2,
    )
    val_attn = attention_extractor(model, val_loader, device, len(val_dataset))
    np.savez_compressed(args.save_file, attn=val_attn)
# Command-line interface.  All three paths are required: previously a
# missing flag surfaced only as an obscure failure deep inside main().
parser = argparse.ArgumentParser(
    description="Extract APC-corrected attention at the model's peak positions."
)
parser.add_argument('--file-path', type=str, required=True,
                    help="input file passed to SequenceDataset")
parser.add_argument('--model-path', type=str, required=True,
                    help="model path passed to load_model")
parser.add_argument('--save-file', type=str, required=True,
                    help="output path for np.savez_compressed (array key 'attn')")
if __name__ == "__main__":
    args = parser.parse_args()
    main(args)