"""Predict plastic degradation for protein/polymer pairs listed in a CSV file.

The input CSV must provide a ``sequence`` column (protein sequence) and a
``polymer`` column (polymer description passed to ``PolyEncoder``).  For every
row the script embeds both inputs, runs the predictor and writes the input
table back out with three extra columns:

* ``pred`` -- ``'Yes'`` when the class-1 probability is >= 0.5, else ``'No'``
* ``prob`` -- softmax probability of class 1
* ``time`` -- wall-clock seconds spent on that row

With ``--attn`` the two weight objects returned by the model are saved per row
as ``<output stem>.attn/<row index>.pt``.
"""
import argparse
import time
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from tqdm import tqdm

from models.plm import EsmModelInfo, get_model
from models.polybert import PolyEncoder
from models.training import BaseModel
from models.utils import Config, decrypt_checkpoint, load_private_key_from_file

# Width of the polymer embedding expected by the predictor.
PBERT_DIM = 600
# Probability at or above which a row is labelled 'Yes'.
DECISION_THRESHOLD = 0.5
# Prefix carried by predictor weights inside the checkpoint's state dict.
_STATE_PREFIX = 'model.'
# Columns the input CSV has to contain.
_REQUIRED_COLUMNS = ('sequence', 'polymer')


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    # fmt: off
    parser = argparse.ArgumentParser(description="Predict plastic degradation")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to the model checkpoint")
    parser.add_argument("--plm", type=str, help="Protein language model to use", default='esm2_t33_650M_UR50D')
    parser.add_argument("--csv", type=str, required=True, help="Path to the CSV file with test data")
    parser.add_argument("--output", '-o', type=str, help="Path to the output file", default='predictions.csv')
    parser.add_argument("--attn", action='store_true', help="Save attention weights to files")
    # fmt: on
    return parser.parse_args()


def _load_predictor(ckpt, plm_name, dev):
    """Create the predictor on *dev* and load its decrypted checkpoint weights.

    Only checkpoint entries whose key starts with ``'model.'`` are loaded, with
    that leading prefix removed.  Returns the model switched to eval mode.
    """
    info = EsmModelInfo(plm_name)
    plm_dim = info['dim'] * 2
    model = BaseModel(plm_dim, PBERT_DIM, n_classes=2).to(dev)

    # load weights
    private_key = load_private_key_from_file()
    checkpoint = decrypt_checkpoint(ckpt, private_key)
    # Strip only the leading prefix: str.replace would also mangle keys that
    # contain 'model.' further inside (e.g. 'model.submodel.weight').
    state_dict = {
        key[len(_STATE_PREFIX):]: value
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(_STATE_PREFIX)
    }
    model.load_state_dict(state_dict)
    model.eval()
    return model


def main():
    """Run prediction over every row of the input CSV and save the results."""
    args = _parse_args()

    # Validate the input table before the (expensive) model loading starts.
    df = pd.read_csv(args.csv)
    missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise SystemExit(
            f'{args.csv} is missing required column(s): {", ".join(missing)}')

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_predictor(args.ckpt, args.plm, dev)
    print(f'Load predictor from {args.ckpt}')

    # Follow the selected device instead of always requesting 'cuda', so the
    # script also runs on CPU-only hosts.
    # NOTE(review): assumes get_model accepts a device-type string such as
    # 'cpu' the same way it accepts 'cuda' -- confirm against models.plm.
    plm_func = get_model(args.plm, dev.type)
    print(f'Loaded PLM model {args.plm}')
    polybert_func = PolyEncoder()
    print('Loaded PolyEncoder model')

    outfile = Path(args.output)
    attn_dir = outfile.with_suffix('.attn')
    if args.attn:
        attn_dir.mkdir(parents=True, exist_ok=True)

    probs = []
    running_time = []
    with torch.inference_mode():
        for i, row in tqdm(df.iterrows(), total=len(df)):
            start_time = time.perf_counter()
            seq = row['sequence'].upper()
            poly = row['polymer']

            # get protein embedding
            seq_emb = plm_func([seq]).to(dev)
            # Flatten the last two axes, then add a leading axis.
            seq_emb = rearrange(seq_emb, 'b l d -> b (l d)').unsqueeze(0)
            poly_emb = polybert_func([poly]).to(dev)

            logits, p_weights, l_weights = model((seq_emb, poly_emb))
            prob = F.softmax(logits, dim=-1)[:, 1].item()

            if args.attn:
                torch.save((p_weights, l_weights), attn_dir / f'{i}.pt')

            probs.append(prob)
            running_time.append(time.perf_counter() - start_time)

    df['prob'] = probs
    df['pred'] = ['Yes' if p >= DECISION_THRESHOLD else 'No' for p in probs]
    df['time'] = running_time
    # move pred and prob to the front
    front = ['pred', 'prob']
    df = df[front + [col for col in df.columns if col not in front]]

    # Make sure the destination exists so finished predictions are not lost.
    outfile.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(outfile, index=False)
    print(f'Predictions saved to {outfile}')
    if args.attn:
        print(f'Attention weights saved to {attn_dir} as .pt files')


if __name__ == "__main__":
    main()