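"""Extract attention maps from a fine-tuned BERT-style sequence model, apply the
average product correction (APC), aggregate them over predicted peak positions,
and save the result as a compressed .npz archive.

Example invocation (the script name and paths are illustrative placeholders):

    python extract_attention.py \
        --file-path data/sequences.txt \
        --model-path checkpoints/model.pt \
        --save-file output/attention.npz
"""
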
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import transformers as T
import argparse
import tqdm
from utils import *

def apc(x):
    """Apply the average product correction (APC), as used for contact prediction."""
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized
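
# In matrix form, apc() computes, for each (layer, head) attention map F:
#   apc(F)[i, j] = F[i, j] - F[i, :].sum() * F[:, j].sum() / F.sum()
# i.e. it subtracts the value expected under independent row/column margins.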

def attention_extractor(model, data_loader, device, n_sample):
    model.eval()
    # (n_sample, n_layers, n_heads, seq_len); assumes a 12-layer, 12-head model
    # with 512-token inputs (BERT-base-style)
    all_attentions = np.zeros([n_sample, 12, 12, 512])
    ns = 0
    with torch.no_grad():
        for x in tqdm.tqdm(data_loader):
            x = x.to(device)
            # the model is expected to return (predictions, attentions) when
            # output_attentions=True
            output, attn = model(x, output_attentions=True)
            attn = torch.stack(attn, dim=1)  # (batch, n_layers, n_heads, seq_len, seq_len)
            attn = attn.to(torch.float16)
            attn = apc(attn)
            if len(attn.shape) < 4:
                attn = attn.unsqueeze(0)  # restore the batch dimension if it was dropped

            for i in range(len(attn)):
                # peak positions: indices where the model output is >= 1
                ind = torch.where(output[i] >= 1)[0].detach().cpu().numpy().tolist()
                if len(ind) > 0:
                    # index attn[i] first so the layer/head axes stay in place,
                    # then sum the attention maps over the peak query positions
                    attn_sum = attn[i][:, :, ind, :].sum(dim=2)  # (n_layers, n_heads, seq_len)
                    all_attentions[ns] = attn_sum.detach().cpu().numpy()
                ns += 1
                
        assert ns == n_sample
    all_attentions = all_attentions.astype("float16")
    
    return all_attentions

def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_path = args.model_path
    fast_tokenizer = T.BertTokenizer.from_pretrained("./model/") 
    model = load_model(fast_tokenizer, load_path)
    print(model)
    model.to(device)
    model.eval()
    
    val_dataset = SequenceDataset(args.file_path, fast_tokenizer)        
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle = False,
        drop_last=False,
        num_workers=2
        )
    
    val_attn = attention_extractor(model, val_loader, device, len(val_dataset))
    np.savez_compressed(args.save_file, attn = val_attn)
    
    # np.savez_compressed(f"{args.save_path}/{args.prefix}_apc_attention.npz",attn = val_attn)
        
parser = argparse.ArgumentParser()
parser.add_argument('--file-path', type=str, help="path to the input sequence file")
parser.add_argument('--model-path', type=str, help="path to the trained model checkpoint")
parser.add_argument('--save-file', type=str, help="output .npz file for the extracted attention maps")

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
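
# To inspect the saved archive later (illustrative snippet; the file name is a placeholder):
#   data = np.load("attention.npz")
#   attn = data["attn"]  # float16, shape (n_sample, n_layers, n_heads, seq_len)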