# Copyright (c) Alibaba Cloud. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from collections import OrderedDict import math import requests from io import BytesIO from functools import partial import pickle from typing import Callable, Optional, Sequence, Tuple, List import numpy as np import os import torch from torch import nn from torch.nn import functional as F from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode class GLU(nn.Module): def __init__(self,hidden_size): super().__init__() self.linear_proj = nn.Linear(hidden_size,hidden_size,bias=False) self.norm1 = nn.LayerNorm(hidden_size) self.act1 = nn.GELU() self.act2 = nn.functional.silu self.dense_h_to_4h = nn.Linear(hidden_size,hidden_size*4,bias=False) self.gate_proj = nn.Linear(hidden_size,hidden_size*4,bias=False) self.dense_4h_to_h = nn.Linear(hidden_size*4,hidden_size,bias=False) def forward(self,x): x = self.linear_proj(x) x = self.act1(self.norm1(x)) x = self.act2(self.gate_proj(x))*self.dense_h_to_4h(x) x = self.dense_4h_to_h(x) return x def swiglu(x): x = torch.chunk(x, 2, dim=-1) return nn.functional.silu(x[0]) * x[1] class GLU_new(nn.Module): def __init__(self,hidden_size, dropout=0.1): super().__init__() intermediate_size = int((4 * hidden_size * 2 / 3) / 64) * 64 intermediate_size = 1280 self.act = swiglu self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False) self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False) self.dropout = nn.Dropout(p=dropout) def forward(self,x): x = self.dense_h_to_4h(x) x = self.act(x) x = self.dense_4h_to_h(x) x = self.dropout(x) return x n_queries = 32 def get_abs_pos(abs_pos, tgt_size): # abs_pos: L, C # tgt_size: M # return: M, C src_size = int(math.sqrt(abs_pos.size(0))) tgt_size = int(math.sqrt(tgt_size)) dtype = abs_pos.dtype if src_size != tgt_size: return F.interpolate( abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), size=(tgt_size, tgt_size), mode="bicubic", align_corners=False, ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) else: return abs_pos from einops import rearrange, repeat def get_1d_sincos_pos_embed(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) omega /= embed_dim / 2. omega = 1. / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb class Resampler(nn.Module): def __init__( self, kv_dim, embed_dim, num_heads=8, n_queries=64, max_seqlen=1024, perceiver_resampler_positional_emb=True, use_GLU=False, bos_init=False, dropout=0.0 ): super().__init__() self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb if self.perceiver_resampler_positional_emb: assert n_queries <= max_seqlen self.stride = max_seqlen // n_queries # self.nan_emb = nn.Parameter(torch.randn(1, kv_dim)) # nn.init.trunc_normal_(self.nan_emb, std=.02) pos = np.arange(max_seqlen, dtype=np.float32) self.register_buffer( "pos_embed", torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float() ) self.latents = nn.Parameter(torch.randn(n_queries, embed_dim)) if bos_init: self.latents.load('') else: nn.init.trunc_normal_(self.latents, std=1e-3) self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout) self.ln_q = nn.LayerNorm(embed_dim) self.ln_kv = nn.LayerNorm(embed_dim) self.ln_post = nn.LayerNorm(embed_dim) if use_GLU: print('GLU *********************************') self.proj = GLU_new(embed_dim, dropout=dropout) else: self.proj = nn.Linear(embed_dim, embed_dim, bias=False) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=1e-3) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, struc_x): """ Args: x (torch.Tensor): protein structure features shape (B, L, C) Returns: shape (B, n, C) where n is self.num_latents """ x = struc_x["encoder_out"] mask = struc_x["encoder_padding_mask"] nan_mask = torch.isnan(x) if nan_mask.any(): x = x.masked_fill(nan_mask, 0.0) # nan_mask = nan_mask.sum(dim=-1).bool() # x[nan_mask] += self.nan_emb x = self.kv_proj(x) x = self.ln_kv(x) b, seqlen = x.shape[:2] latents = self.ln_q(self.latents) if self.perceiver_resampler_positional_emb: # TODO: interpolate latents = latents + self.pos_embed[::self.stride].contiguous() pos_emb = self.pos_embed[:seqlen].unsqueeze(0) x = x + pos_emb.contiguous() # blocks latents = repeat(latents, "n d -> b n d", b=b) out = self.attn(latents, x, x, key_padding_mask=~mask)[0] out = self.ln_post(out) out = self.proj(out) return out class StructureTransformer(nn.Module): def __init__( self, width: int = 640, n_queries: int = 32, output_dim: int = 4096, embedding_keys=set(["mpnn_emb"]), max_seqlen: int=1024, num_heads: int=8, structure_emb_path_prefix='structure_emb', **kwargs ): super().__init__() self.structure_emb_path_prefix = structure_emb_path_prefix # self.transformer = None # replace None with a pretrained strucure encoder self.embedding_keys = embedding_keys self.max_seqlen = max_seqlen self.width = width self.n_queries = n_queries self.attn_pool = Resampler( embed_dim=output_dim, kv_dim=width, n_queries=n_queries, max_seqlen=max_seqlen, num_heads=num_heads, **kwargs ) def prepare_structure(self, sample): emb_pad = torch.zeros((self.max_seqlen, self.width)) emb_mask = torch.zeros((self.max_seqlen), dtype=bool) if "pifold_emb" in self.embedding_keys and "pifold_mask" in sample: mask = sample["pifold_mask"] pifold_emb = sample["pifold_emb"] new_pifold_emb = pifold_emb.new_zeros(mask.shape[0], pifold_emb.shape[1]).fill_(float("nan")) new_pifold_emb[mask > 0] = pifold_emb sample["pifold_emb"] = new_pifold_emb ### domians ### emb = [] for ek in self.embedding_keys: if ek in sample: if isinstance( sample[ek], List): emb.append(torch.cat(sample[ek])) else: emb.append(sample[ek]) # emb = [sample[ek] for ek in self.embedding_keys if ek in sample] emb = torch.cat(emb, dim=-1) emb_pad[:len(emb)] = emb emb_mask[:len(emb)] = 1 return emb_pad, emb_mask def forward(self, x): # x = self.transformer(x) x = self.attn_pool(x) return x def encode(self, structure_paths: List[str]): structure_embs = [] structure_mask = [] for structure_path in structure_paths: structure_path = [chr(s) for s in structure_path[:self.n_queries].tolist() if s > 0] structure_path = os.path.join(self.structure_emb_path_prefix, ''.join(structure_path)) if not os.path.exists(structure_path): print('no structure found') return None with open(structure_path, 'rb') as f: structure, struc_mask = self.prepare_structure(pickle.load(f)) structure_embs.append(structure) structure_mask.append(struc_mask) structure_embs = torch.stack(structure_embs, dim=0).to( device=next(self.attn_pool.parameters()).device, dtype=next(self.attn_pool.parameters()).dtype) structure_mask = torch.stack(structure_mask, dim=0).to( device=next(self.attn_pool.parameters()).device) return self({ 'encoder_out': structure_embs, 'encoder_padding_mask': structure_mask })