|
import os
|
|
import re
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from lavis.models import load_model_and_preprocess
|
|
from lavis.processors import load_processor
|
|
from lavis.common.registry import registry
|
|
from torch.nn import functional as F
|
|
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
|
|
import numpy as np
|
|
import pandas as pd
|
|
import time
|
|
from fuzzywuzzy import process
|
|
from multiprocessing import Pool, Queue, Process
|
|
import difflib
|
|
import Levenshtein
|
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available; otherwise fall back to the plain "cpu" string.
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
def txt_map(x, txt_dict):
    """Map each item of ``x`` through ``txt_dict``, keeping unmapped items as-is.

    Args:
        x: An iterable of items, or a string containing the Python source of
            one (for example ``"['a', 'b']"``), which is evaluated first.
        txt_dict: Lookup from an item to its replacement.

    Returns:
        A new list with the same length and order as ``x``; items found in
        ``txt_dict`` are replaced by their mapped value, the rest are copied.
    """
    # isinstance (not an exact type comparison) so that str subclasses are
    # parsed as well instead of being iterated character by character.
    if isinstance(x, str):
        # SECURITY NOTE(review): eval() executes arbitrary code. This is only
        # safe when the string comes from trusted data; if the inputs are
        # always plain literals, ast.literal_eval would be a safer choice.
        x = eval(x)
    return [txt_dict[i] if i in txt_dict else i for i in x]
|
|
|
|
|
|
def levenshtein_sim(text, label):
    """Score every entry of ``label`` by its closest match in ``text``.

    For each label string, ``Levenshtein.ratio`` is computed against every
    string in ``text`` and the highest value is kept (0 when ``text`` is
    empty). The scores are rounded to three decimals and returned in the
    order of ``label``.
    """
    scores = []
    for target in label:
        # Seeding with 0 covers an empty ``text``; max() keeps the first of
        # equal values, which matches a strict ">" running-maximum update.
        best = max([0] + [Levenshtein.ratio(target, candidate) for candidate in text])
        scores.append(round(best, 3))
    return scores
|
|
|
|
|
|
def func(text, label):
    """Score every entry of ``text`` by its closest match in ``label``.

    Mirror image of a label-driven similarity: here the outer iteration is
    over ``text``. For each text string the best ``Levenshtein.ratio``
    against all ``label`` strings is kept (0 when ``label`` is empty), then
    rounded to three decimals. Results follow the order of ``text``.
    """

    def best_ratio(query):
        # Running maximum starting at 0; on ties the earlier value is kept.
        top = 0
        for candidate in label:
            top = max(top, Levenshtein.ratio(query, candidate))
        return top

    return [round(best_ratio(query), 3) for query in text]
|
|
|
|
|
|
def stage2_output(df_test, return_num_txt=1):
    """Generate text for every protein in ``df_test`` and append it to a file.

    Builds the ``blip2_protein_opt`` model from the hard-coded config below,
    feeds the proteins through it in batches of 8 and appends one
    ``<protein>|<generated text>`` line per protein to
    ``output_concat_{split}{fix}{type_fix}.txt``.

    Args:
        df_test: Table with a ``protein`` column; only
            ``df_test['protein'].tolist()`` is used.
        return_num_txt: Number of generated sequences returned per protein.
            They are joined with ``;`` on the output line. Beam search runs
            with 5 beams, so presumably this must not exceed 5 -- TODO
            confirm against the generate() implementation.

    Returns:
        None. Results are only written to the output file.

    Note:
        The output file name is built from the module-level names ``split``,
        ``fix`` and ``type_fix``, which this file only assigns in its
        ``__main__`` section; calling this function without them defined
        raises ``NameError``.
    """
    config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
              'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20231029182/checkpoint_0.pth',
              'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
              'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
              'max_protein_len': 600,
              'max_txt_len': 256}

    model_cls = registry.get_model_class(config['arch'])
    model = model_cls.from_config(config)
    model.to(device)
    model.eval()

    images = df_test['protein'].tolist()
    n = len(images)
    if n == 0:
        # Nothing to predict; avoids pushing an empty batch through the model.
        return

    bsz = 8
    num_beams = 5
    # Loop-invariant: resolved once so a missing global fails before any
    # expensive generation work is done.
    output_path = '/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix)
    model.opt_tokenizer.padding_side = "right"

    # Pure inference: the results are only decoded to text, so gradient
    # tracking is disabled for the whole loop.
    with torch.no_grad():
        # Stepping by bsz never produces an empty trailing batch, unlike
        # iterating ``n // bsz + 1`` times when n is a multiple of bsz.
        for start in range(0, n, bsz):
            # The encoder is fed (name, sequence) pairs; names restart at
            # protein0 in every batch.
            image = [('protein{}'.format(j), seq)
                     for j, seq in enumerate(images[start:start + bsz])]

            with model.maybe_autocast():
                _, _, batch_tokens = model.visual_encoder(image)
                image_embeds = model.ln_vision(
                    batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True
                )["representations"][model.vis_layers].contiguous()

            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = model.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            inputs_opt = model.opt_proj(query_output.last_hidden_state)
            atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)

            # Empty text prompt for every protein in the batch; the projected
            # query outputs are prepended to its token embeddings.
            text = [''] * len(image)
            opt_tokens = model.opt_tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=model.max_txt_len,
            ).to(device)
            inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

            with model.maybe_autocast():
                # NOTE(review): eos_token_id=50118 is presumably the newline
                # token of the OPT vocabulary -- confirm against the tokenizer.
                outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=1,
                                                   max_length=256,
                                                   repetition_penalty=1., num_beams=num_beams, eos_token_id=50118,
                                                   length_penalty=1., num_return_sequences=return_num_txt, temperature=1.)
            output_text = model.opt_tokenizer.batch_decode(outputs)

            # Drop tabs and surrounding whitespace from every decoded sequence.
            output_text = [re.sub('\t', '', str(x)).strip() for x in output_text]

            # Decoded sequences come back flat; each protein owns a run of
            # return_num_txt consecutive entries, joined with ';'.
            lines = []
            for k, (_, seq) in enumerate(image):
                joined = ';'.join(output_text[k * return_num_txt:(k + 1) * return_num_txt])
                lines.append(seq + "|" + joined + '\n')

            # Append after every batch so finished work survives a later
            # failure; the context manager closes the file even on error.
            with open(output_path, 'a+') as f:
                f.writelines(lines)
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # NOTE: split, fix and type_fix must remain module-level names:
    # stage2_output() reads them as globals to build its output file path,
    # so this section cannot simply be moved into a function.
    split = 'test'
    cat = ''
    type_fix = '_pretrain'
    # Only the 'bp' and 'cc' categories get a file-name suffix.
    fix = {'bp': '_bp', 'cc': '_cc'}.get(cat, '')

    print(device)
    return_num_txt = 1

    print("reading file ...")
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/{}_sample10000.csv'.format(split),
                       usecols=['name', 'protein', 'function'], sep='|')
    # Rename by name rather than by position: usecols does not control the
    # column order of the result, so a positional rename would mislabel the
    # columns if the file stores them in a different order.
    test = test.rename(columns={'function': 'label'})

    # Same path that stage2_output() appends to; built once and reused.
    output_path = '/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix)

    # stage2_output() appends, so start from a clean file on every run.
    if os.path.exists(output_path):
        os.remove(output_path)

    print("stage 2 predict starting")
    # Pass the setting explicitly so changing return_num_txt above actually
    # takes effect (it was previously assigned but never used).
    stage2_output(test, return_num_txt)
    print("stage 2 predict completed")

    df_pred = pd.read_csv(output_path, sep='|', header=None, on_bad_lines='warn')
    df_pred.columns = ['protein', 'pred']
    df_pred = df_pred.drop_duplicates()

    # Attach name/label to each prediction and drop predictions whose protein
    # has no label in the test table.
    data = pd.merge(df_pred, test, on='protein', how='left')
    data = data[data['label'].notnull()]

    # NOTE(review): this file name uses ``cat`` whereas the intermediate file
    # uses ``fix`` (e.g. 'bp' vs '_bp'); kept as-is -- confirm it is intended.
    data[['name', 'label', 'pred']].to_csv(
        '/cluster/home/wenkai/LAVIS/output/predict_concat_{}{}{}.csv'.format(split, cat, type_fix), index=False, sep='|')
|
|
|
|
|
|
|
|
|
|
|
|
|