| import torch.nn as nn |
| import torch |
| import os |
| import mini3di |
| import numpy as np |
| import torch.nn.functional as F |
| from rdkit import Chem |
| from rdkit.Chem import AllChem |
| from src.data.protein_dataset import dynamic_pad |
| from transformers import AutoTokenizer, AutoModelForMaskedLM, LlamaForCausalLM, LlamaTokenizer, T5Tokenizer, T5EncoderModel, AutoModelForCausalLM, AutoModel |
| from src.data.esm.sdk.api import LogitsConfig |
| from model_zoom.procyon.model.model_unified import UnifiedProCyon |
| from model_zoom.progen2.modeling_progen import ProGenForCausalLM |
| from tokenizers import Tokenizer |
| from model_zoom.GearNet.data.protein import Protein |
| from model_zoom.GearNet.data.transform import ProteinView |
| from model_zoom.GearNet.data.transform import Compose |
| from model_zoom.GearNet.data.geo_graph import GraphConstruction |
| from model_zoom.GearNet.data.function import AlphaCarbonNode, SpatialEdge, KNNEdge, SequentialEdge |
| from model_zoom.GearNet.gearnet import GeometryAwareRelationalGraphNeuralNetwork |
| from model_zoom.esm.utils.sampling import _BatchedESMProteinTensor |
| from model_zoom.ProTrek.model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel |
| from vplm import TransformerForMaskedLM, TransformerConfig |
| from vplm import VPLMTokenizer |
| from peft import TaskType, get_peft_model |
| from peft import IA3Config, AdaLoraConfig, LoraConfig |
| from vplm import VPLMTokenizer |
|
|
|
|
| MODEL_ZOOM_PATH = '/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom' |
|
|
| class BaseProteinModel(nn.Module): |
| def __init__(self, device, **kwargs): |
| super().__init__() |
| self.device = device |
| |
| def construct_batch(self, data, batch_size, task_name=None): |
| raise NotImplementedError |
| |
| def setup_peft(self, peft_type): |
| raise NotImplementedError |
| |
| def get_tokenizer(self): |
| raise NotImplementedError |
| |
| def forward(self, batch): |
| raise NotImplementedError |
|
|
| class UtilsModel: |
| def __init__(self): |
| super().__init__() |
|
|
| def post_process_cpu(self, batch, embeddings, attention_masks, start, ends, task_type='binary_classification'): |
|
|
| |
| results = [] |
| for i, end in enumerate(ends): |
| end = int(end.item()) |
| embedding = embeddings[i][start:end].cpu() |
| name = batch['name'][i] |
| attention_mask = attention_masks[i][start:end].cpu() |
| label = torch.tensor(batch['label'][i]) |
| |
| results.append({'name': name, |
| 'embedding': embedding, |
| 'attention_mask': attention_mask.bool(), |
| 'label': label} ) |
| return results |
|
|
| def pad_data(self, data, dim=0, pad_value=0, max_length=1022): |
| if data.shape[dim] < max_length: |
| data = self.dynamic_pad(data, [0, max_length-data.shape[dim]], dim=dim, pad_value=pad_value) |
| else: |
| |
| |
| start = 0 |
| end = start + max_length |
| |
| slices = [slice(None)] * data.ndim |
| slices[dim] = slice(start, end) |
| data = data[tuple(slices)] |
| return data |
|
|
| def dynamic_pad(self, tensor, pad_size, dim=0, pad_value=0): |
| |
| shape = list(tensor.shape) |
| num_dims = len(shape) |
| |
| |
| pad = [0] * (2 * num_dims) |
| prev_pad_size, post_pad_size = pad_size |
| pad_index = 2 * (num_dims - dim - 1) |
| pad[pad_index] = prev_pad_size |
| pad[pad_index + 1] = post_pad_size |
|
|
| |
| padded_tensor = F.pad(tensor, pad, mode="constant", value=pad_value) |
| return padded_tensor |
| |
| class ESM2Model(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, model_path = 'esm2_650m',**kwargs): |
| super().__init__(device) |
| from transformers import AutoTokenizer, AutoModelForMaskedLM |
| self.model = AutoModelForMaskedLM.from_pretrained(f"{MODEL_ZOOM_PATH}/{model_path}").to(self.device) |
| self.tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_ZOOM_PATH}/{model_path}") |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
| |
| def setup_peft(self, peft_type="lora", **kwargs): |
| if peft_type == "freeze": |
| for param in self.model.parameters(): |
| param.requires_grad = False |
| else: |
| if peft_type == "lora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = LoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| inference_mode=False, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| target_modules=["query", "value"], |
| ) |
| elif peft_type == "ia3": |
| peft_config = IA3Config( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| target_modules=["query", "value", "dense"], |
| feedforward_modules=["dense"], |
| ) |
| elif peft_type == "dora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = LoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| use_dora=True, |
| inference_mode=False, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| target_modules=["query", "value"], |
| ) |
| elif peft_type == "adalora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = AdaLoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| target_r=4, |
| init_r=12, |
| beta1=0.85, beta2=0.85, |
| tinit=200, |
| tfinal=1000, |
| deltaT=10, |
| target_modules=["query", "value"], |
| ) |
| self.model = get_peft_model(self.model, peft_config) |
| |
| def construct_batch(self, batch): |
| MAXLEN = self.max_length |
| max_length_batch = min(max([len(sample['seq']) for sample in batch]) + 2, self.max_length + 2) |
| result = { |
| 'name': [], |
| 'seq': [], |
| 'attention_mask': [], |
| 'label': [] |
| } |
| for sample in batch: |
| seq_token = torch.tensor(self.tokenizer.encode(sample['seq']))[:MAXLEN] |
| attention_mask = torch.zeros(max_length_batch) |
| attention_mask[:len(seq_token)] = 1 |
| seq_token = self.pad_data(seq_token, dim=0, max_length=max_length_batch) |
| result['name'].append(sample['name']) |
| result['seq'].append(seq_token) |
| result['attention_mask'].append(attention_mask) |
| result['label'].append(sample['label']) |
|
|
| result['seq'] = torch.stack(result['seq'], dim=0).to(self.device) |
| result['attention_mask'] = torch.stack(result['attention_mask'], dim=0).to(self.device) |
| |
| return result |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_prob=False, return_logits=False, **kwargs): |
| attention_mask = batch['attention_mask'] |
| outputs = self.model.esm( |
| batch['seq'], |
| attention_mask=attention_mask, |
| return_dict=True, |
| ) |
|
|
| if return_logits: |
| logits = self.model.lm_head(outputs.last_hidden_state) |
| return logits |
| |
| if return_prob: |
| logits = self.model.lm_head(outputs.last_hidden_state) |
| probs = F.softmax(logits, dim=-1) |
| return probs |
| |
| embeddings = outputs.last_hidden_state |
| ends = attention_mask.sum(dim=-1)-1 |
| start = 1 |
| if post_process: |
| result = self.post_process_cpu(batch, embeddings, attention_mask, start, ends, task_type=task_type) |
| else: |
| result = embeddings |
| return result |
|
|
| class SmilesModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, **kwargs): |
| super().__init__(device) |
|
|
|
|
| def construct_batch(self, batch): |
| result = {'smiles': []} |
| for sample in batch: |
| mol = Chem.MolFromSmiles(sample['smiles']) |
| if mol is not None: |
| fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048) |
| smiles = torch.tensor([int(ele) for ele in list(fp.ToBitString())]).float() |
| else: |
| smiles = torch.tensor([0]*2048).float() |
| result['smiles'].append(smiles) |
| return result |
| |
| def forward(self, batch, post_process=True, task_type='binary_classification'): |
| return batch |
| |
| class ESM3Model(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, sequence_only=False): |
| super().__init__(device) |
| from model_zoom.esm.models.esm3 import ESM3 |
| self.model = ESM3.from_pretrained("esm3-sm-open-v1").to(self.device) |
| self.sequence_only = sequence_only |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.model.tokenizers.sequence |
|
|
| def construct_batch(self, batch): |
| from model_zoom.esm.utils import encoding |
| from model_zoom.esm.utils.misc import stack_variable_length_tensors |
| addBOS = 1 |
| pad_id = self.model.tokenizers.sequence.pad_token_id |
| max_len = min(max([len(s['seq']) for s in batch]) + 2*addBOS, self.max_length + 2*addBOS) |
| names, prot_tensors, labels, masks = [], [], [], [] |
| sequence_list, coordinates_list, structure_tokens_batch = [], [], [] |
| for sample in batch: |
| seq, coords = sample['seq'], sample['X'] |
| seq_tokenizer = self.model.tokenizers.sequence |
| struct_tokenizer = self.model.tokenizers.structure |
| |
| seq_tok = encoding.tokenize_sequence(seq, seq_tokenizer, add_special_tokens=True) |
| with torch.no_grad(): |
| coords_tok, _plddt, struct_tok = encoding.tokenize_structure( |
| np.array(coords), |
| self.model.get_structure_encoder(), |
| struct_tokenizer, |
| add_special_tokens=True |
| ) |
| coords_tok, struct_tok = torch.tensor(coords_tok), torch.tensor(struct_tok) |
| mask = torch.zeros(max_len) |
| mask[:seq_tok.shape[0]] = 1 |
| seq_tok = self.pad_data(seq_tok, dim=0, pad_value=pad_id, max_length=max_len) |
| struct_tok = self.pad_data(struct_tok, dim=0, pad_value=pad_id, max_length=max_len) |
| coords_tok = dynamic_pad(coords_tok, [addBOS, addBOS], dim=0, pad_value=0) |
| coords_tok = self.pad_data(coords_tok, dim=0, max_length=max_len) |
| |
| sequence_list.append(seq_tok) |
| coordinates_list.append(coords_tok) |
| structure_tokens_batch.append(struct_tok) |
| names.append(sample['name']) |
| masks.append(mask) |
| labels.append(sample['label']) |
| |
| sequence_tokens = stack_variable_length_tensors( |
| sequence_list, |
| constant_value=pad_id, |
| ).to(self.device) |
| |
| structure_tokens_batch = stack_variable_length_tensors( |
| structure_tokens_batch, |
| constant_value=pad_id, |
| ).to(self.device) |
| |
| coordinates_batch = stack_variable_length_tensors( |
| coordinates_list, |
| constant_value=pad_id, |
| ).to(self.device) |
| protein_tensor = _BatchedESMProteinTensor(sequence=sequence_tokens, |
| structure=structure_tokens_batch, coordinates=coordinates_batch).to(self.device) |
| return { |
| 'name': names, |
| 'seq': protein_tensor, |
| 'attention_mask': torch.stack([m.bool() for m in masks]).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| tens, mask = batch['seq'], batch['attention_mask'] |
| if return_logits: |
| out = self.model.logits( |
| tens, LogitsConfig( |
| sequence=True, structure=False, secondary_structure=False, |
| sasa=False, function=False, residue_annotations=False, return_embeddings=True |
| ) |
| ) |
| logits = out.logits.sequence |
| return logits |
|
|
| out = self.model.logits( |
| tens, LogitsConfig( |
| sequence=True, structure=True, secondary_structure=True, |
| sasa=True, function=True, residue_annotations=True, return_embeddings=True |
| ) |
| ) |
|
|
| embeddings = out.embeddings |
| ends = mask.sum(dim=-1) - 1 |
| start = 1 |
| if post_process: |
| return self.post_process_cpu(batch, embeddings, mask, start, ends, task_type) |
| return embeddings |
| |
| class ESMC600MModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| from model_zoom.esm.models.esmc import ESMC |
| self.model = ESMC.from_pretrained("esmc_600m").to(self.device) |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.model.tokenizer |
|
|
| def construct_batch(self, batch): |
| from model_zoom.esm.utils.misc import stack_variable_length_tensors |
| addBOS = 1 |
| pad_id = self.model.tokenizer.pad_token_id |
| max_len = min(max([self.model._tokenize([s['seq']]).shape[1]-2 for s in batch]) + 2*addBOS, self.max_length+2*addBOS) |
| names, prots, masks, labels = [], [], [], [] |
| token_ids_list = [] |
| for sample in batch: |
| seq = sample['seq'] |
| token_ids = self.model._tokenize([seq]).flatten() |
| mask = torch.zeros(max_len) |
| mask[:len(token_ids)] = 1 |
| token_ids = self.pad_data(token_ids, dim=0, pad_value=pad_id, max_length=max_len) |
| token_ids_list.append(token_ids) |
| names.append(sample['name']) |
| masks.append(mask) |
| labels.append(sample['label']) |
| sequence_tokens = stack_variable_length_tensors( |
| token_ids_list, |
| constant_value=self.model.tokenizer.pad_token_id, |
| ) |
| protein_tensor = _BatchedESMProteinTensor(sequence=sequence_tokens).to(self.device) |
| return { |
| 'name': names, |
| 'seq': protein_tensor, |
| 'attention_mask': torch.stack([m.bool() for m in masks]).to(self.device), |
| 'label': labels |
| } |
| |
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| tens, mask = batch['seq'], batch['attention_mask'] |
| outputs = self.model.logits(tens, LogitsConfig(sequence=True, return_embeddings=True)) |
|
|
| if return_logits: |
| return outputs.logits.sequence |
| embeddings = outputs.embeddings |
| ends = mask.sum(dim=-1) - 1 |
| start = 1 |
| if post_process: |
| return self.post_process_cpu(batch, embeddings, mask, start, ends, task_type) |
| return embeddings |
| |
|
|
| |
| class ProCyonModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, sequence_only=False): |
| super().__init__(device) |
| protein_view_transform = ProteinView(view='residue') |
| self.transform = Compose([protein_view_transform]) |
| self.graph_construction_model = GraphConstruction( |
| node_layers=[AlphaCarbonNode()], |
| edge_layers=[SpatialEdge(radius=10.0, min_distance=5), |
| KNNEdge(k=10, min_distance=5), |
| SequentialEdge(max_distance=2)], |
| edge_feature="gearnet" |
| ) |
| self.gearnet_edge = GeometryAwareRelationalGraphNeuralNetwork(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512], |
| num_relation=7, edge_input_dim=59, num_angle_bin=8, |
| batch_norm=True, concat_hidden=True, short_cut=True, readout="sum" |
| ) |
| ckpt = torch.load("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/GearNet/mc_gearnet_edge.pth") |
| self.gearnet_edge.load_state_dict(ckpt) |
| self.gearnet_edge = self.gearnet_edge.to(self.device) |
| self.gearnet_edge.eval() |
| os.environ["HOME_DIR"] = MODEL_ZOOM_PATH |
| os.environ["DATA_DIR"] = "/nfs_beijing/wanghao/2025-onesystem/vllm/ProCyon-Instruct" |
| os.environ["LLAMA3_PATH"] = "/nfs_beijing/wanghao/2025-onesystem/vllm/Meta-Llama-3-8B" |
| procyon_ckpt = '/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom/procyon/model_weights/ProCyon-Full' |
| self.esm_pretrain_model = AutoModelForMaskedLM.from_pretrained(f"{MODEL_ZOOM_PATH}/esm2_3b").to(self.device) |
| self.esm_pretrain_model.eval() |
| self.esm_tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_ZOOM_PATH}/esm2_3b") |
| |
| self.model, _ = UnifiedProCyon.from_pretrained( |
| pretrained_weights_dir=procyon_ckpt, |
| checkpoint_dir=procyon_ckpt |
| ) |
| self.model = self.model.to(self.device) |
|
|
| self.max_length = max_length |
| self.sequence_only = sequence_only |
|
|
| def get_tokenizer(self): |
| return self.esm_tokenizer |
|
|
| def construct_batch(self, batch): |
| names, seqs, structs, labels = [], [], [], [] |
| for sample in batch: |
| try: |
| seqs_list = sample['seq'] if isinstance(sample['seq'], list) else [sample['seq']] |
| pdbs = sample['pdb_path'] if isinstance(sample['pdb_path'], list) else [sample['pdb_path']] |
| seq_embs, struct_embs = [], [] |
| for s, p in zip(seqs_list, pdbs): |
| toks = self.esm_tokenizer([s], return_tensors='pt', padding=True, max_length=self.max_length, truncation=True) |
| out = self.esm_pretrain_model.esm(toks.input_ids.to(self.device), attention_mask=toks.attention_mask.to(self.device), return_dict=True) |
| seq_embs.append(out.last_hidden_state.squeeze(0).mean(0)) |
| prot = Protein.from_pdb(p, bond_feature="length", residue_feature="symbol") |
| prot = self.transform({"graph": prot})["graph"] |
| packed = Protein.pack([prot]) |
| protein = self.graph_construction_model(packed).to(self.device) |
| with torch.no_grad(): |
| gea = self.gearnet_edge(protein, protein.node_feature.float()) |
| struct_embs.append(gea["graph_feature"].flatten()) |
| names.append(sample['name']) |
| seqs.append(torch.cat(seq_embs, dim=-1)) |
| structs.append(torch.cat(struct_embs, dim=-1)) |
| labels.append(sample['label']) |
| except Exception as e: |
| print(f"Error processing sample {sample['name']}: {e}") |
| continue |
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'X': torch.stack(structs).unsqueeze(1).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', **kwargs): |
| seq_emb, struct_emb = batch['seq'], batch['X'] |
| aaseq = self.model.token_projectors['aaseq'](seq_emb) |
| struct_proj = self.model.token_projectors['prot_structure'](struct_emb) |
| B = aaseq.shape[0] |
| instr = ["Describe the following protein with functions: <|protein|> <|struct|>"] * B |
| input_ids, attn = self.model._prepare_text_inputs_and_tokenize(instr, [[]]*B, no_pad=True) |
| input_ids, attn = input_ids.to(self.device), attn.to(self.device) |
| if self.sequence_only: |
| embeds, _ = self.model._prepare_input_embeddings(input_ids, protein_soft_tokens=aaseq) |
| else: |
| embeds, _ = self.model._prepare_input_embeddings(input_ids, protein_soft_tokens=aaseq, protein_struct_tokens=struct_proj) |
| mask = ~(input_ids == self.model.tokenizer.pad_token_id) |
| out = self.model.text_encoder(input_embeds=embeds, attn_masks=attn) |
| h = out.hidden_states[-1] |
| ends = mask.sum(dim=-1) |
| start = 0 |
| if post_process: |
| return self.post_process_cpu(batch, h, mask, start, ends, task_type) |
| return h |
|
|
| |
| class GearNetModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| pv = ProteinView(view='residue') |
| self.transform = Compose([pv]) |
| self.graph_construction = GraphConstruction( |
| node_layers=[AlphaCarbonNode()], edge_layers=[SpatialEdge(radius=10.0, min_distance=5), KNNEdge(k=10, min_distance=5), SequentialEdge(max_distance=2)], edge_feature="gearnet" |
| ) |
| self.gearnet = GeometryAwareRelationalGraphNeuralNetwork( |
| input_dim=21, hidden_dims=[512]*6, num_relation=7, edge_input_dim=59, num_angle_bin=8, |
| batch_norm=True, concat_hidden=True, short_cut=True, readout="sum" |
| ) |
| ckpt = torch.load(f"{MODEL_ZOOM_PATH}/GearNet/mc_gearnet_edge.pth") |
| self.gearnet.load_state_dict(ckpt) |
| self.gearnet = self.gearnet.to(self.device).eval() |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return None |
| |
| def construct_batch(self, batch): |
| names, embeddings, attention_masks, labels = [], [], [], [] |
| for sample in batch: |
| try: |
| pdbs = sample['pdb_path'] if isinstance(sample['pdb_path'], list) else [sample['pdb_path']] |
| prots = [] |
| for p in pdbs: |
| pr = Protein.from_pdb(p, bond_feature="length", residue_feature="symbol") |
| prots.append(self.transform({"graph": pr})["graph"]) |
|
|
| pack = Protein.pack(prots) |
| max_res = pack.num_residues.max().item() |
| gc = self.graph_construction(pack.to(self.device)) |
| node = self.gearnet(gc.to(self.device), gc.node_feature.float().to(self.device))["node_feature"] |
| splits = torch.cumsum(F.pad(pack.num_residues, (1,0)), dim=0) |
| attention_mask = torch.zeros(len(splits)-1, max_res).to(self.device) |
| embeddings_temp = [] |
| for i in range(len(splits)-1): |
| start, end = splits[i], splits[i+1] |
| embedding = node[start:end] |
| attention_mask[i, :embedding.shape[0]] = 1 |
| embedding = self.pad_data(embedding, dim=0, max_length=max_res) |
| embeddings_temp.append(embedding) |
| embeddings_temp = torch.stack(embeddings_temp) |
| embeddings.append(embeddings_temp) |
| attention_masks.append(attention_mask) |
| labels.append(sample['label']) |
| names.append(sample['name']) |
| except Exception as e: |
| print(f"Error processing sample {sample['name']}: {e}") |
| continue |
| |
| max_len = max([one.shape[1] for one in embeddings]) |
| |
| |
| embeddings = torch.stack([F.pad(one[0], (0,0,0, max_len-one.shape[1])) for one in embeddings], dim=0) |
| attention_masks = torch.stack([F.pad(one[0], (0, max_len-one.shape[1])) for one in attention_masks], dim=0) |
| return { |
| 'name': names, |
| 'X': embeddings, |
| 'attention_mask': attention_masks, |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', **kwargs): |
| emb = batch['X']; mask = batch['attention_mask'] |
| ends = mask.sum(dim=-1) |
| start = 0 |
| if post_process: |
| return self.post_process_cpu(batch, emb, mask, start, ends, task_type) |
| return emb |
|
|
| |
| class ProLLAMAModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| llama_path = f"{MODEL_ZOOM_PATH}/ProLLaMA" |
| self.model = LlamaForCausalLM.from_pretrained(llama_path).to(self.device) |
| self.tokenizer = LlamaTokenizer.from_pretrained(llama_path) |
| self.max_length = max_length |
| |
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| max_len = min( |
| max(len(s) for sample in batch for s in (sample['seq'] if isinstance(sample['seq'], list) else [sample['seq']])) + 2, |
| self.max_length + 2 |
| ) |
| names, seqs, masks, labels = [], [], [], [] |
| for sample in batch: |
| seqs_list = sample['seq'] if isinstance(sample['seq'], list) else [sample['seq']] |
| tok_ids, m = [], [] |
| for s in seqs_list: |
| s2 = f"[Determine superfamily] Seq=<{s}>" |
| tid = torch.tensor(self.tokenizer.encode(s2)) |
| mask = torch.zeros(max_len, dtype=torch.bool); mask[:len(tid)] = True |
| tid = self.pad_data(tid, dim=0, max_length=max_len) |
| tok_ids.append(tid); m.append(mask) |
| names.append(sample['name']); seqs.append(torch.hstack(tok_ids)); masks.append(torch.hstack(m)); labels.append(sample['label']) |
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', **kwargs): |
| seq, mask = batch['seq'], batch['attention_mask'] |
| out = self.model(input_ids=seq, attention_mask=mask, output_hidden_states=True) |
| emb = out.hidden_states[-1].float() |
| ends = mask.sum(dim=-1) ; start = 0 |
| if post_process: |
| return self.post_process_cpu(batch, emb, mask, start, ends, task_type) |
| return emb |
|
|
| |
| class ProSTModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| self.prost = AutoModel.from_pretrained( |
| f"{MODEL_ZOOM_PATH}/protst", |
| trust_remote_code=True, |
| torch_dtype=torch.bfloat16 |
| ).to(self.device) |
| self.model = self.prost.protein_model |
| self.tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_ZOOM_PATH}/esm1b_650m") |
| self.max_length = max_length |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| names, seqs, masks, labels = [], [], [], [] |
| max_length_batch = min( |
| max([len(self.tokenizer.encode(sample['seq'], add_special_tokens=False)) for sample in batch]) + 2, |
| self.max_length + 2 |
| ) |
| for sample in batch: |
| seq = sample['seq'][:max_length_batch] |
| tid = torch.tensor(self.tokenizer.encode(seq)) |
| mask = torch.zeros(max_length_batch, dtype=torch.bool); mask[:len(tid)] = True |
| tid = self.pad_data(tid, dim=0, max_length=max_length_batch) |
| names.append(sample['name']); seqs.append(tid); masks.append(mask); labels.append(sample['label']) |
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| out = self.model(input_ids=batch['seq'], attention_mask=batch['attention_mask'], return_dict=True) |
| emb = out.residue_feature |
| if return_logits: |
| |
| |
| |
| return emb |
| |
| ends = batch['attention_mask'].sum(dim=-1) - 1; start = 1 |
| if post_process: |
| return self.post_process_cpu(batch, emb, batch['attention_mask'], start, ends, task_type) |
| return emb |
|
|
| |
| class ProGen2Model(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| self.model = ProGenForCausalLM.from_pretrained(f"{MODEL_ZOOM_PATH}/progen2").to(self.device) |
| def create_tokenizer_custom(file): |
| with open(file, 'r') as f: |
| return Tokenizer.from_str(f.read()) |
| self.tokenizer = create_tokenizer_custom(file=f"{MODEL_ZOOM_PATH}/progen2/tokenizer.json") |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| max_len = max(len(self.tokenizer.encode(s).ids) for sample in batch for s in ([sample['seq']] if not isinstance(sample['seq'], list) else sample['seq'])) |
| names, seqs, masks, labels = [], [], [], [] |
| for sample in batch: |
| seqs_list = sample['seq'] if isinstance(sample['seq'], list) else [sample['seq']] |
| tids, m = [], [] |
| for s in seqs_list: |
| tok = torch.tensor(self.tokenizer.encode(s).ids) |
| mask = torch.zeros(max_len, dtype=torch.bool); mask[:len(tok)] = True |
| tok = self.pad_data(tok, dim=0, max_length=max_len) |
| mask = self.pad_data(mask, dim=0, max_length=max_len) |
| tids.append(tok); m.append(mask) |
| stacked = torch.hstack(tids)[:self.max_length]; mstack = torch.hstack(m)[:self.max_length] |
| names.append(sample['name']); seqs.append(stacked); masks.append(mstack); labels.append(sample['label']) |
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', **kwargs): |
| out = self.model.transformer(batch['seq'], return_dict=True) |
| emb = out.last_hidden_state |
| ends = batch['attention_mask'].sum(dim=-1) - 1; start = 0 |
| if post_process: |
| return self.post_process_cpu(batch, emb, batch['attention_mask'], start, ends, task_type) |
| return emb |
|
|
| |
| class ProstT5Model(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, sequence_only=False): |
| super().__init__(device) |
| self.tokenizer = T5Tokenizer.from_pretrained(f"{MODEL_ZOOM_PATH}/ProstT5", do_lower_case=False, legacy=False) |
| self.model = T5EncoderModel.from_pretrained(f"{MODEL_ZOOM_PATH}/ProstT5").to(self.device) |
| self.encoder_3di = mini3di.Encoder() |
| self.max_length = max_length |
| self.sequence_only = sequence_only |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
| |
| def setup_peft(self, peft_type="lora", **kwargs): |
| if peft_type == "freeze": |
| for param in self.model.parameters(): |
| param.requires_grad = False |
| else: |
| if peft_type == "lora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = LoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| inference_mode=False, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| target_modules=["q", "v"], |
| ) |
| elif peft_type == "ia3": |
| peft_config = IA3Config( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| target_modules=["q", "v", "wi", "wo"], |
| feedforward_modules=["wi", "wo"], |
| ) |
| elif peft_type == "dora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = LoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| use_dora=True, |
| inference_mode=False, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| target_modules=["q", "v"], |
| ) |
| elif peft_type == "adalora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = AdaLoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| target_r=4, |
| init_r=12, |
| beta1=0.85, beta2=0.85, |
| tinit=200, |
| tfinal=1000, |
| deltaT=10, |
| target_modules=["q", "v"], |
| ) |
| self.model = get_peft_model(self.model, peft_config) |
|
|
| def construct_batch(self, batch): |
| import re |
| max_length_batch = min( |
| max([len(sample['seq']) for sample in batch]) + 2, |
| self.max_length + 2 |
| ) |
| names, seqs, masks, labels = [], [], [], [] |
| seq_tokens, attention_masks = [], [] |
| for sample in batch: |
| seq = sample['seq']; X = sample['X'] |
| N, CA, C, CB = X[:,0], X[:,1], X[:,2], X[:,3] |
| attention_mask = torch.zeros(2, max_length_batch, device=self.device) |
| states = self.encoder_3di.encode_atoms( |
| ca=CA.float().cpu().numpy(), |
| cb=CB.float().cpu().numpy(), |
| n=N.float().cpu().numpy(), |
| c=C.float().cpu().numpy(), |
| ) |
| struct_seq = self.encoder_3di.build_sequence(states).lower() |
| if self.sequence_only: |
| sequence_examples = [seq, seq] |
| sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples] |
| sequence_examples = [ "<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s |
| for s in sequence_examples |
| ] |
| else: |
| sequence_examples = [seq, struct_seq] |
| sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples] |
| sequence_examples = [ "<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s |
| for s in sequence_examples |
| ] |
| seq_token = self.tokenizer.batch_encode_plus(sequence_examples, |
| add_special_tokens=True, |
| padding="longest", |
| return_tensors='pt').to(self.device) |
| attention_mask[:, :seq_token.input_ids.shape[1]] = 1 |
| seq_token = self.pad_data(seq_token.input_ids, dim=1, max_length=max_length_batch) |
| attention_mask = self.pad_data(attention_mask, dim=1, max_length=max_length_batch) |
| seq_tokens.append(seq_token) |
| attention_masks.append(attention_mask) |
| names.append(sample['name']) |
| labels.append(sample['label']) |
|
|
| return { |
| 'name': names, |
| 'seq': torch.cat(seq_tokens, dim=0), |
| 'attention_mask': torch.cat(attention_masks, dim=0), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| seq, attention_mask = batch['seq'], batch['attention_mask'] |
| embedding_repr = self.model( |
| seq, |
| attention_mask=attention_mask |
| ) |
| last = embedding_repr.last_hidden_state |
| |
| |
| |
| B2, L, H = last.size() |
| b = len(batch['name']) |
| last = last.view(b, 2, L, H) |
| emb = torch.cat([last[:,0], last[:,1]], dim=-1) |
| if return_logits: |
| return emb |
| mask = batch['attention_mask'][::2] |
| ends = mask.sum(dim=-1) - 1; start = 1 |
| if post_process: |
| return self.post_process_cpu(batch, emb, mask, start, ends, task_type) |
| return emb |
| |
| |
| class ProtGPT2Model(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| self.tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_ZOOM_PATH}/ProtGPT2") |
| self.model = AutoModelForCausalLM.from_pretrained(f"{MODEL_ZOOM_PATH}/ProtGPT2").to(self.device) |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| max_len = max(len(self.tokenizer.encode(s)) for sample in batch for s in ([sample['seq']] if not isinstance(sample['seq'], list) else sample['seq'])) |
| names, seqs, masks, labels = [], [], [], [] |
| for sample in batch: |
| seqs_list = sample['seq'] if isinstance(sample['seq'], list) else [sample['seq']] |
| tids, m = [], [] |
| for s in seqs_list: |
| tok = torch.tensor(self.tokenizer.encode(s)) |
| mask = torch.zeros(max_len, dtype=torch.bool); mask[:len(tok)] = True |
| tok = self.pad_data(tok, dim=0, max_length=max_len) |
| mask = self.pad_data(mask, dim=0, max_length=max_len) |
| tids.append(tok); m.append(mask) |
| names.append(sample['name']); seqs.append(torch.hstack(tids)); masks.append(torch.hstack(m)); labels.append(sample['label']) |
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', **kwargs): |
| out = self.model(input_ids=batch['seq'], attention_mask=batch['attention_mask'], output_hidden_states=True) |
| emb = out.hidden_states[-1] |
| ends = batch['attention_mask'].sum(dim=-1); start = 0 |
| if post_process: |
| return self.post_process_cpu(batch, emb, batch['attention_mask'], start, ends, task_type) |
| return emb |
|
|
| |
| class ProTrekModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, model_path='protrek_650m'): |
| super().__init__(device) |
| if model_path == 'protrek_650m': |
| config = { |
| "protein_config": f"{MODEL_ZOOM_PATH}/ProTrek/ProTrek_650M_UniRef50/esm2_t33_650M_UR50D", |
| "text_config": f"{MODEL_ZOOM_PATH}/ProTrek/ProTrek_650M_UniRef50/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", |
| "structure_config": f"{MODEL_ZOOM_PATH}/ProTrek/ProTrek_650M_UniRef50/foldseek_t30_150M", |
| "load_protein_pretrained": False, |
| "load_text_pretrained": False, |
| "from_checkpoint": f"{MODEL_ZOOM_PATH}/ProTrek/ProTrek_650M_UniRef50/ProTrek_650M_UniRef50.pt" |
| } |
| if model_path == 'protrek_35m': |
| config = { |
| "protein_config": f"{MODEL_ZOOM_PATH}/protrek_35m/esm2_t12_35M_UR50D", |
| "text_config": f"{MODEL_ZOOM_PATH}/protrek_35m/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", |
| "structure_config": f"{MODEL_ZOOM_PATH}/protrek_35m/foldseek_t12_35M", |
| "load_protein_pretrained": False, |
| "load_text_pretrained": False, |
| "from_checkpoint": f"{MODEL_ZOOM_PATH}/protrek_35m/ProTrek_35M_UniRef50.pt" |
| } |
| self.model = ProTrekTrimodalModel(**config).to(self.device) |
| self.encoder_3di = mini3di.Encoder() |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.model.protein_encoder.tokenizer |
|
|
| def setup_peft(self, peft_type="lora", **kwargs): |
| if peft_type == "freeze": |
| for param in self.model.parameters(): |
| param.requires_grad = False |
| else: |
| if peft_type == "lora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = LoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| inference_mode=False, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| target_modules=["query", "value"], |
| ) |
| elif peft_type == "ia3": |
| peft_config = IA3Config( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| target_modules=["query", "value", "dense"], |
| feedforward_modules=["dense"], |
| ) |
| elif peft_type == "dora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = LoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| inference_mode=False, |
| use_dora=True, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| target_modules=["query", "value"], |
| ) |
| elif peft_type == "adalora": |
| lora_r, lora_alpha, lora_dropout = kwargs.get("lora_r", 8), \ |
| kwargs.get("lora_alpha", 16), \ |
| kwargs.get("lora_dropout", 0.1) |
| peft_config = AdaLoraConfig( |
| task_type=TaskType.FEATURE_EXTRACTION, |
| r=lora_r, |
| lora_alpha=lora_alpha, |
| target_r=4, |
| init_r=12, |
| beta1=0.85, beta2=0.85, |
| tinit=200, |
| tfinal=1000, |
| deltaT=10, |
| target_modules=["query", "value"], |
| ) |
| protein_encoder = self.model.protein_encoder.model.esm |
| structure_encoder = self.model.structure_encoder.model.esm |
| protein_encoder = get_peft_model(protein_encoder, peft_config) |
| structure_encoder = get_peft_model(structure_encoder, peft_config) |
| self.model.protein_encoder.model.esm = protein_encoder |
| self.model.structure_encoder.model.esm = structure_encoder |
| |
|
|
| def construct_batch(self, batch): |
| names, seqs, structs, masks, labels = [], [], [], [], [] |
| max_length_batch = min( |
| max([len(sample['seq']) for sample in batch]) + 2, |
| self.max_length + 2 |
| ) |
| for sample in batch: |
| seq = sample['seq']; X = sample['X'] |
| if X is None: continue |
| N, CA, C, CB = X[:,0], X[:,1], X[:,2], X[:,3] |
| states = self.encoder_3di.encode_atoms( |
| ca=CA.float().cpu().numpy(), |
| cb=CB.float().cpu().numpy(), |
| n=N.float().cpu().numpy(), |
| c=C.float().cpu().numpy(), |
| ) |
| struct_seq = self.encoder_3di.build_sequence(states).lower() |
| |
| mask = torch.zeros(max_length_batch, dtype=torch.bool) |
| mask[:len(seq)+2] = True |
| names.append(sample['name']) |
| seqs.append(seq) |
| structs.append(struct_seq) |
| masks.append(mask) |
| labels.append(sample['label']) |
| return { |
| 'name': names, |
| 'seq': seqs, |
| 'struct': structs, |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| |
| if return_logits: |
| seq_tokens = self.model.protein_encoder.tokenizer.batch_encode_plus(batch['seq'], return_tensors="pt", padding=True) |
| seq_tokens["input_ids"], seq_tokens["attention_mask"] = seq_tokens["input_ids"].to(self.device), seq_tokens["attention_mask"].to(self.device) |
| seq_logits = self.model.protein_encoder(seq_tokens, get_mask_logits=True)[-1] |
|
|
| |
| |
| |
| return seq_logits |
| |
| prot = self.model.get_protein_repr(batch['seq']) |
| struct = self.model.get_structure_repr(batch['struct']) |
| emb = torch.cat([prot, struct], dim=-1) |
| mask = batch['attention_mask'] |
| ends = mask.sum(dim=-1) - 1; start = 0 |
| if post_process: |
| return self.post_process_cpu(batch, emb, mask, start, ends, task_type) |
| return emb |
|
|
| |
| class SaPortModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, model_path="SaPort/ckpt", sequence_only=False): |
| super().__init__(device) |
| from transformers import EsmTokenizer, EsmForMaskedLM |
| self.encoder_3di = mini3di.Encoder() |
| self.tokenizer = EsmTokenizer.from_pretrained(f'{MODEL_ZOOM_PATH}/{model_path}') |
| self.model = EsmForMaskedLM.from_pretrained(f'{MODEL_ZOOM_PATH}/{model_path}').to(self.device) |
| self.max_length = max_length |
| self.sequence_only = sequence_only |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| names, seqs, masks, labels = [], [], [], [] |
| max_len = min( |
| max([len(s['seq']) for s in batch]) + 2, |
| self.max_length + 2 |
| ) |
| for sample in batch: |
| seq, X = sample['seq'], sample['X'] |
| N, CA, C, CB = X[:,0], X[:,1], X[:,2], X[:,3] |
| states = self.encoder_3di.encode_atoms( |
| ca=CA.float().cpu().numpy(), |
| cb=CB.float().cpu().numpy(), |
| n=N.float().cpu().numpy(), |
| c=C.float().cpu().numpy(), |
| ) |
| struct_seq = self.encoder_3di.build_sequence(states).lower() |
| merged = ''.join(a + b.lower() for a, b in zip(seq, struct_seq)) |
| tid = torch.tensor(self.tokenizer(merged, return_tensors='pt').input_ids[0]) |
| mask = torch.zeros(max_len, dtype=torch.bool) |
| mask[:len(tid)] = True |
| tid = self.pad_data(tid, dim=0, max_length=max_len) |
| mask = self.pad_data(mask, dim=0, max_length=max_len) |
|
|
| names.append(sample['name']) |
| seqs.append(tid) |
| masks.append(mask) |
| labels.append(torch.tensor(sample['label'])) |
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels, |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| seq, mask = batch['seq'], batch['attention_mask'] |
| if return_logits: |
| out = self.model(input_ids=seq, attention_mask=mask, return_dict=True) |
| return out.logits |
| |
| out = self.model.esm(input_ids=seq, attention_mask=mask, return_dict=True) |
| emb = out.last_hidden_state |
| start = 0 |
| ends = mask.sum(dim=-1) - 1 |
| if post_process: |
| return self.post_process_cpu(batch, emb, mask, start, ends, task_type) |
| return emb |
|
|
|
|
| |
| class VenusPLMModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| config = TransformerConfig.from_pretrained(MODEL_ZOOM_PATH + '/venusplm', attn_impl="sdpa") |
| self.model = TransformerForMaskedLM.from_pretrained(MODEL_ZOOM_PATH + '/venusplm', config=config).to(self.device) |
| self.tokenizer = VPLMTokenizer.from_pretrained(MODEL_ZOOM_PATH + '/venusplm') |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| names, seqs, masks, labels = [], [], [], [] |
| max_len = min( |
| max([len(s['seq']) for s in batch]) + 2, |
| self.max_length + 2 |
| ) |
| for sample in batch: |
| seq = sample['seq'] |
| seq_tokens = torch.tensor(self.tokenizer.encode(seq))[:self.max_length] |
| attention_mask = torch.zeros(max_len, dtype=torch.bool) |
| attention_mask[:len(seq_tokens)] = True |
| seq_tokens = self.pad_data(seq_tokens, dim=0, max_length=max_len) |
| seq_tokens = self.pad_data(seq_tokens, dim=0, max_length=max_len) |
|
|
| names.append(sample['name']) |
| seqs.append(seq_tokens) |
| masks.append(attention_mask) |
| labels.append(torch.tensor(sample['label'])) |
|
|
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| out = self.model(input_ids=batch['seq'], attention_mask=batch['attention_mask'], output_hidden_states=True) |
| if return_logits: |
| return out.logits |
| emb = out.hidden_states[-1] |
| start = 0 |
| ends = batch['attention_mask'].sum(dim=-1) - 1 |
| if post_process: |
| return self.post_process_cpu(batch, emb, batch['attention_mask'], start, ends, task_type) |
| return emb |
| |
|
|
| |
| class ProSST2048Model(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| from model_zoom.ProSST.prosst.structure.quantizer import PdbQuantizer |
| weight_path = f"{MODEL_ZOOM_PATH}/ProSST/prosst_2048_weight" |
| self.tokenizer = AutoTokenizer.from_pretrained(weight_path, trust_remote_code=True) |
| self.quantizer = PdbQuantizer(structure_vocab_size=2048) |
| self.model = AutoModelForMaskedLM.from_pretrained(weight_path, trust_remote_code=True).to(self.device) |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| max_len = min(max([len(s['seq']) for s in batch]) + 2, self.max_length + 2) |
| names, seqs, xs, masks, labels = [], [], [], [], [] |
| for sample in batch: |
| seq = sample['seq'] |
| pdb_path = sample['pdb_path'] |
| pdb_name = os.path.basename(pdb_path) |
| seq_tokens = torch.tensor(self.tokenizer.encode(seq))[:self.max_length] |
| struct = self.quantizer(pdb_path, return_residue_seq=False)['2048'][pdb_name]["struct"] |
| struct_seq = [i + 3 for i in struct] |
| struct_seq = [1] + struct_seq + [2] |
| struct_tokens = torch.tensor(struct_seq)[:self.max_length] |
| attention_mask = torch.zeros(max_len, dtype=torch.bool) |
| attention_mask[:len(seq_tokens)] = True |
| seq_tokens = self.pad_data(seq_tokens, dim=0, max_length=max_len) |
| struct_tokens = self.pad_data(struct_tokens, dim=0, max_length=max_len) |
| names.append(sample['name']) |
| seqs.append(seq_tokens) |
| xs.append(struct_tokens) |
| masks.append(attention_mask) |
| labels.append(torch.tensor(sample['label'])) |
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'X': torch.stack(xs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', **kwargs): |
| outputs = self.model( |
| input_ids=batch['seq'], |
| attention_mask=batch['attention_mask'], |
| output_hidden_states=True, |
| ss_input_ids=batch['X'] |
| ) |
| embeddings = outputs.hidden_states[-1] |
| ends = batch['attention_mask'].sum(dim=-1) - 1 |
| start = 1 |
| if post_process: |
| return self.post_process_cpu(batch, embeddings, batch['attention_mask'], start, ends, task_type) |
| return embeddings |
|
|
|
|
| |
| class ProtT5(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| from transformers import T5Tokenizer, T5EncoderModel |
| weight_path = f"{MODEL_ZOOM_PATH}/ProtT5" |
| self.tokenizer = T5Tokenizer.from_pretrained(weight_path, do_lower_case=False) |
| self.model = T5EncoderModel.from_pretrained(weight_path).to(device) |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| max_len = min(max([len(s['seq']) for s in batch]) + 2, self.max_length + 2) |
| |
| names, seqs, masks, labels = [], [], [], [] |
| for sample in batch: |
| seq = " ".join(list(sample['seq'][:self.max_length])) |
| seq_tokens = torch.tensor(self.tokenizer.encode(seq, add_special_tokens=False))[:self.max_length] |
| attention_mask = torch.zeros(max_len, dtype=torch.bool) |
| attention_mask[:len(seq_tokens)] = True |
| seq_tokens = self.pad_data(seq_tokens, dim=0, max_length=max_len) |
| names.append(sample['name']) |
| seqs.append(seq_tokens) |
| masks.append(attention_mask) |
| labels.append(torch.tensor(sample['label'])) |
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| embedding_repr = self.model( |
| input_ids=batch['seq'], |
| attention_mask=batch['attention_mask'], |
| ) |
| embeddings = embedding_repr.last_hidden_state |
| if return_logits: |
| return embeddings |
| ends = batch['attention_mask'].sum(dim=-1) - 1 |
| start = 1 |
| if post_process: |
| return self.post_process_cpu(batch, embeddings, batch['attention_mask'], start, ends, task_type) |
| return embeddings |
|
|
| class DPLMModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, model_path="dplm_650m", **kwargs): |
| super().__init__(device) |
| from transformers import AutoTokenizer, AutoModelForMaskedLM |
| self.model = AutoModelForMaskedLM.from_pretrained(f"{MODEL_ZOOM_PATH}/esm2_650m").to(self.device) |
| params = torch.load(f'{MODEL_ZOOM_PATH}/{model_path}/pytorch_model.bin') |
| self.model.load_state_dict(params, strict=True) |
| self.tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_ZOOM_PATH}/esm2_650m") |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| MAXLEN = self.max_length |
| max_length_batch = min( |
| max([len(sample['seq']) for sample in batch]) + 2, |
| self.max_length + 2 |
| ) |
| result = { |
| 'name': [], |
| 'seq': [], |
| 'attention_mask': [], |
| 'label': [] |
| } |
| for sample in batch: |
| seq_token = torch.tensor(self.tokenizer.encode(sample['seq']))[:MAXLEN] |
| attention_mask = torch.zeros(max_length_batch) |
| attention_mask[:len(seq_token)] = 1 |
| seq_token = self.pad_data(seq_token, dim=0, max_length=max_length_batch) |
| result['name'].append(sample['name']) |
| result['seq'].append(seq_token) |
| result['attention_mask'].append(attention_mask) |
| result['label'].append(sample['label']) |
|
|
| result['seq'] = torch.stack(result['seq'], dim=0).to(self.device) |
| result['attention_mask'] = torch.stack(result['attention_mask'], dim=0).to(self.device) |
| |
| return result |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_prob=False, return_logits=False, **kwargs): |
| attention_mask = batch['attention_mask'] |
| outputs = self.model.esm( |
| batch['seq'], |
| attention_mask=attention_mask, |
| return_dict=True, |
| ) |
| if return_prob or return_logits: |
| if return_prob and return_logits: return_logits = False |
| logits = self.model.lm_head(outputs.last_hidden_state) |
| if return_logits: |
| return logits |
| probs = F.softmax(logits, dim=-1) |
| return probs |
| |
| embeddings = outputs.last_hidden_state |
| ends = attention_mask.sum(dim=-1)-1 |
| start = 1 |
| if post_process: |
| result = self.post_process_cpu(batch, embeddings, attention_mask, start, ends, task_type=task_type) |
| else: |
| result = embeddings |
| return result |
|
|
| class OntoProteinModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, **kwargs): |
| super().__init__(device) |
| from transformers import AutoTokenizer, AutoModelForMaskedLM |
| self.tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_ZOOM_PATH}/OntoProtein") |
| self.model = AutoModelForMaskedLM.from_pretrained(f"{MODEL_ZOOM_PATH}/OntoProtein").to(self.device) |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| import re |
| MAXLEN = self.max_length |
| max_length_batch = min( |
| max([len(sample['seq']) for sample in batch]) + 2, |
| MAXLEN + 2 |
| ) |
| result = { |
| 'name': [], |
| 'seq': [], |
| 'attention_mask': [], |
| 'token_type_ids': [], |
| 'label': [] |
| } |
| for sample in batch: |
| sequence_Example = ' '.join(sample['seq']) |
| sequence_Example = re.sub(r"[UZOB]", "X", sequence_Example) |
| encoded_input = self.tokenizer(sequence_Example, return_tensors='pt') |
| |
| input_ids = self.pad_data(encoded_input['input_ids'][0], dim=0, max_length=max_length_batch) |
| attention_mask = self.pad_data(encoded_input['attention_mask'][0], dim=0, max_length=max_length_batch) |
| token_type_ids = self.pad_data(encoded_input['token_type_ids'][0], dim=0, max_length=max_length_batch) |
| |
| |
| result['name'].append(sample['name']) |
| result['seq'].append(input_ids) |
| result['attention_mask'].append(attention_mask) |
| result['token_type_ids'].append(token_type_ids) |
| result['label'].append(sample['label']) |
|
|
| result['seq'] = torch.stack(result['seq'], dim=0).to(self.device) |
| result['attention_mask'] = torch.stack(result['attention_mask'], dim=0).to(self.device) |
| result['token_type_ids'] = torch.stack(result['token_type_ids'], dim=0).to(self.device) |
| |
| return result |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_prob=False, **kwargs): |
| output = self.model.bert( |
| input_ids=batch['seq'], |
| attention_mask=batch['attention_mask'], |
| token_type_ids=batch['token_type_ids'], |
| ) |
| if return_prob: |
| logits = self.model.cls(output.last_hidden_state) |
| probs = F.softmax(logits, dim=-1) |
| return probs |
| |
| attention_mask = batch['attention_mask'] |
| embeddings = output.last_hidden_state |
| ends = attention_mask.sum(dim=-1)-1 |
| start = 1 |
| if post_process: |
| result = self.post_process_cpu(batch, embeddings, attention_mask, start, ends, task_type=task_type) |
| else: |
| result = embeddings |
| return result |
|
|
|
|
| class ANKHBase(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022): |
| super().__init__(device) |
| from transformers import AutoTokenizer, T5EncoderModel |
| weight_path = f"{MODEL_ZOOM_PATH}/ankh_base" |
| self.tokenizer = AutoTokenizer.from_pretrained(weight_path) |
| self.model = T5EncoderModel.from_pretrained(weight_path).to(device) |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| max_len = min(max([len(s['seq']) for s in batch]) + 2, self.max_length + 2) |
| names, seqs, masks, labels = [], [], [], [] |
| for sample in batch: |
| seq = sample['seq'][:1022] |
| seq_tokens = torch.tensor(self.tokenizer.encode(seq, add_special_tokens=False))[:self.max_length] |
| attention_mask = torch.zeros(max_len, dtype=torch.bool) |
| attention_mask[:len(seq_tokens)] = True |
| seq_tokens = self.pad_data(seq_tokens, dim=0, max_length=max_len) |
| names.append(sample['name']) |
| seqs.append(seq_tokens) |
| masks.append(attention_mask) |
| labels.append(torch.tensor(sample['label'])) |
|
|
| return { |
| 'name': names, |
| 'seq': torch.stack(seqs).to(self.device), |
| 'attention_mask': torch.stack(masks).to(self.device), |
| 'label': labels |
| } |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_logits=False, **kwargs): |
| embedding_repr = self.model( |
| input_ids=batch['seq'], |
| attention_mask=batch['attention_mask'], |
| ) |
| embeddings = embedding_repr.last_hidden_state |
| if return_logits: |
| return embeddings |
| ends = batch['attention_mask'].sum(dim=-1) - 1 |
| start = 1 |
| if post_process: |
| return self.post_process_cpu(batch, embeddings, batch['attention_mask'], start, ends, task_type) |
| return embeddings |
|
|
| class PGLMModel(BaseProteinModel, UtilsModel): |
| def __init__(self, device, max_length=1022, model_path="proteinglm-1b-mlm", **kwargs): |
| super().__init__(device) |
| from transformers import AutoTokenizer, AutoModelForMaskedLM |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| f"{MODEL_ZOOM_PATH}/{model_path}", |
| trust_remote_code=True, |
| local_files_only=True, |
| use_fast=True |
| ) |
| self.model = AutoModelForMaskedLM.from_pretrained( |
| f"{MODEL_ZOOM_PATH}/{model_path}", |
| trust_remote_code=True, |
| local_files_only=True |
| ).to(self.device) |
| self.max_length = max_length |
|
|
| def get_tokenizer(self): |
| return self.tokenizer |
|
|
| def construct_batch(self, batch): |
| MAXLEN = self.max_length |
| max_length_batch = min( |
| max([len(self.tokenizer.encode(sample['seq'], add_special_tokens=False)) for sample in batch]) + 2, |
| self.max_length + 2 |
| ) |
| result = { |
| 'name': [], |
| 'seq': [], |
| 'attention_mask': [], |
| 'label': [] |
| } |
| for sample in batch: |
| output = self.tokenizer(sample['seq'], add_special_tokens=True, return_tensors='pt') |
| seq_token = output['input_ids'][0] |
| attention_mask = torch.zeros(max_length_batch) |
| attention_mask[:len(seq_token)] = 1 |
| seq_token = self.pad_data(seq_token, dim=0, max_length=max_length_batch) |
| result['name'].append(sample['name']) |
| result['seq'].append(seq_token) |
| result['attention_mask'].append(attention_mask) |
| result['label'].append(sample['label']) |
|
|
| result['seq'] = torch.stack(result['seq'], dim=0).to(self.device) |
| result['attention_mask'] = torch.stack(result['attention_mask'], dim=0).to(self.device) |
| |
| return result |
|
|
| def forward(self, batch, post_process=True, task_type='binary_classification', return_prob=False, return_logits=False, **kwargs): |
| attention_mask = batch['attention_mask'] |
| outputs = self.model( |
| batch['seq'], |
| attention_mask=batch['attention_mask'], |
| output_hidden_states=True, return_last_hidden_state=True |
| ) |
| if return_prob or return_logits: |
| if return_prob and return_logits: return_logits = False |
| if return_logits: |
| return outputs.logits |
| probs = F.softmax(logits, dim=-1) |
| return probs |
| |
| embeddings = outputs.hidden_states.permute(1,0,2) |
| ends = attention_mask.sum(dim=-1)-1 |
| start = 1 |
| if post_process: |
| result = self.post_process_cpu(batch, embeddings, attention_mask, start, ends, task_type=task_type) |
| else: |
| result = embeddings |
| return result |
|
|
| if __name__ == "__main__": |
| import torch |
| import sys; sys.path.append("/nfs_beijing/kubeflow-user/wanghao/workspace/ai4sci/protein_benchmark_new/protein_benchmark") |
| sys.path.append('/nfs_beijing/kubeflow-user/wanghao/workspace/ai4sci/protein_benchmark_new/protein_benchmark/model_zoom') |
| from src.data.esm.sdk.api import ESMProtein |
| model_name = "saprot" |
| |
| pdb_path = "/nfs_beijing/kubeflow-user/wanghao/workspace/ai4sci/protein_benchmark_new/protein_benchmark/datasets/DMS_ProteinGym_substitutions/ProteinGym_AF2_structures/A0A1I9GEU1_NEIME.pdb" |
| structure = ESMProtein.from_pdb(pdb_path) |
| sequence = structure.sequence |
| coordinates = structure.coordinates |
| print(f"length of sequence: {len(sequence)}") |
| ori_batch = [ |
| { |
| "seq": sequence, |
| "X": coordinates, |
| "name": "unknown", |
| "label": 1.0 |
| } |
| ] |
| |
| |
| if model_name == "esm2_650m": |
| model = ESM2Model(device="cuda:0") |
| if model_name == "esmc_600m": |
| model = ESMC600MModel(device="cuda:0") |
| if model_name == "esm3_1.4b": |
| model = ESM3Model(device="cuda:0") |
| if model_name == "venusplm": |
| model = VenusPLMModel(device="cuda:0") |
| if model_name == "protst": |
| model = ProSTModel(device="cuda:0") |
| if model_name == "prostt5": |
| model = ProstT5Model(device="cuda:0") |
| if model_name == "protrek": |
| model = ProTrekModel(device="cuda:0") |
| if model_name == "saprot": |
| model = SaPortModel(device="cuda:0") |
| if model_name == "prott5": |
| model = ProtT5(device="cuda:0") |
| if model_name == "dplm": |
| model = DPLMModel(device="cuda:0") |
| if model_name == "dplm_150m": |
| model = DPLMModel(device="cuda:0", model_path="dplm_150m") |
| if model_name == "dplm_3b": |
| model = DPLMModel(device="cuda:0", model_path="dplm_3b") |
| if model_name == "dplm": |
| model = DPLMModel(device="cuda:0") |
| if model_name == "pglm": |
| model = PGLMModel(device="cuda:0") |
|
|
| |
| |
| |
| input_batch = model.construct_batch(ori_batch) |
| logits = model.forward(batch=input_batch, return_logits=True) |
| print(logits.shape) |
|
|