徐俊德
init
c525dff
raw
history blame
9.59 kB
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
import pickle
from typing import Callable, Optional, Sequence, Tuple, List
import numpy as np
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
class GLU(nn.Module):
    """Gated projection head.

    Pipeline: ``linear_proj -> LayerNorm -> GELU`` to get a hidden state ``h``,
    then ``dense_4h_to_h(silu(gate_proj(h)) * dense_h_to_4h(h))``. The gated
    expansion is 4x ``hidden_size`` and every linear layer is bias-free.
    """

    def __init__(self, hidden_size):
        super().__init__()
        expanded_size = hidden_size * 4
        # Creation order is kept stable: it fixes both the parameter order and
        # the sequence in which the random initialisers consume RNG state.
        self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(hidden_size, expanded_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, expanded_size, bias=False)
        self.dense_4h_to_h = nn.Linear(expanded_size, hidden_size, bias=False)

    def forward(self, x):
        hidden = self.act1(self.norm1(self.linear_proj(x)))
        gated = self.act2(self.gate_proj(hidden))
        return self.dense_4h_to_h(gated * self.dense_h_to_4h(hidden))
def swiglu(x):
    """SwiGLU activation.

    Splits the last dimension into two chunks ``(gate, value)`` with
    ``torch.chunk`` and returns ``silu(gate) * value``.
    """
    halves = x.chunk(2, dim=-1)
    gate = halves[0]
    value = halves[1]
    return nn.functional.silu(gate) * value
class GLU_new(nn.Module):
    """SwiGLU feed-forward block: Linear -> swiglu -> Linear -> Dropout.

    Args:
        hidden_size: input and output width.
        dropout: dropout probability applied to the block output.
        intermediate_size: width after the swiglu gate. Defaults to 1280, the
            width this block has always used, so existing checkpoints keep
            their parameter shapes. Pass ``None`` to derive it from
            ``hidden_size`` as ``int(4 * hidden_size * 2 / 3 / 64) * 64``.

    Raises:
        ValueError: if the (given or derived) intermediate size is not positive.
    """

    def __init__(self, hidden_size, dropout=0.1, intermediate_size=1280):
        super().__init__()
        if intermediate_size is None:
            # About 2/3 of a 4x expansion, rounded down to a multiple of 64.
            intermediate_size = int((4 * hidden_size * 2 / 3) / 64) * 64
        if intermediate_size <= 0:
            raise ValueError(
                f"intermediate_size must be positive, got {intermediate_size}"
            )
        self.act = swiglu
        # One matmul produces both the gate and the value halves; swiglu
        # splits them along the last dimension.
        self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
        self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.dense_h_to_4h(x)
        x = self.act(x)
        x = self.dense_4h_to_h(x)
        return self.dropout(x)
# Module-level default number of query tokens.
# NOTE(review): Resampler and StructureTransformer take their own `n_queries`
# arguments, so this constant looks unused by them -- possibly vestigial;
# confirm no other module imports it before removing.
n_queries = 32
def get_abs_pos(abs_pos, tgt_size):
    """Resize a flattened square grid of absolute position embeddings.

    Args:
        abs_pos: tensor of shape (L, C), an s x s grid (L = s * s) flattened
            row-major.
        tgt_size: number of target positions M; the target grid side is
            ``int(sqrt(M))``.

    Returns:
        ``abs_pos`` itself when both grid sides already match. Otherwise a new
        (t * t, C) tensor, bicubically interpolated in float32 and cast back
        to the input dtype.
    """
    side_in = int(math.sqrt(abs_pos.size(0)))
    side_out = int(math.sqrt(tgt_size))
    if side_in == side_out:
        return abs_pos

    original_dtype = abs_pos.dtype
    # (L, C) -> (1, C, s, s): interpolate expects channels-first images.
    grid = abs_pos.float().reshape(1, side_in, side_in, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(side_out, side_out),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, t, t) -> (t * t, C)
    flat = resized.permute(0, 2, 3, 1).flatten(0, 2)
    return flat.to(dtype=original_dtype)
from einops import rearrange, repeat
def get_1d_sincos_pos_embed(embed_dim, pos):
    """Fixed 1-D sine/cosine positional embeddings.

    Args:
        embed_dim: output width D per position; must be even.
        pos: numpy array of positions (any shape; it is flattened to M values).

    Returns:
        numpy array of shape (M, D): the first D/2 columns are
        ``sin(pos * w_i)`` and the last D/2 are ``cos(pos * w_i)``, with
        frequencies ``w_i = 1 / 10000 ** (i / (D / 2))``.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    exponents = np.arange(half_dim, dtype=np.float32)
    exponents /= embed_dim / 2.
    freqs = 1. / 10000 ** exponents  # (D/2,)

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
class Resampler(nn.Module):
    """Perceiver-style resampler.

    A fixed set of ``n_queries`` learned latent queries cross-attends, through
    one ``nn.MultiheadAttention`` layer, to a variable-length sequence of
    input features. The result is ``n_queries`` output tokens of width
    ``embed_dim``.

    Args:
        kv_dim: channel size of the input features.
        embed_dim: width of the latent queries, the attention and the output.
        num_heads: number of attention heads.
        n_queries: number of latent queries (output sequence length).
        max_seqlen: number of rows in the sin-cos positional table. When
            positional embeddings are enabled, the input sequence length must
            not exceed it.
        perceiver_resampler_positional_emb: if True, add sin-cos positional
            embeddings to the latents (every ``stride``-th table row) and to
            the inputs (the first ``seqlen`` rows).
        use_GLU: project the attention output with ``GLU_new`` instead of a
            bias-free linear layer.
        bos_init: initialising the latents from a checkpoint is not
            implemented; passing True raises ``NotImplementedError``.
        dropout: dropout on the attention weights (and in the GLU head if used).
    """

    def __init__(
        self,
        kv_dim,
        embed_dim,
        num_heads=8,
        n_queries=64,
        max_seqlen=1024,
        perceiver_resampler_positional_emb=True,
        use_GLU=False,
        bos_init=False,
        dropout=0.0
    ):
        super().__init__()
        self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb
        if self.perceiver_resampler_positional_emb:
            assert n_queries <= max_seqlen
            # Latents sample every `stride`-th table row, spreading them over
            # the whole [0, max_seqlen) range.
            self.stride = max_seqlen // n_queries
            pos = np.arange(max_seqlen, dtype=np.float32)
            self.register_buffer(
                "pos_embed",
                torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float()
            )
        self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
        if bos_init:
            raise NotImplementedError(
                "bos_init=True is not supported: loading initial latents is not implemented"
            )
        nn.init.trunc_normal_(self.latents, std=1e-3)
        self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)
        self.ln_post = nn.LayerNorm(embed_dim)
        if use_GLU:
            print('GLU *********************************')
            self.proj = GLU_new(embed_dim, dropout=dropout)
        else:
            self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Small truncated-normal weights and zero biases for every Linear
        (including the attention output projection); identity init for
        LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, struc_x):
        """Resample a padded batch of features to ``n_queries`` tokens.

        Args:
            struc_x (dict): must contain
                ``"encoder_out"``: features of shape (B, L, kv_dim). NaN
                entries are replaced by 0 before projection.
                ``"encoder_padding_mask"``: bool tensor of shape (B, L), True
                for positions to attend to. It is inverted to form
                ``key_padding_mask``.

        Returns:
            torch.Tensor of shape (B, n_queries, embed_dim).
        """
        x = struc_x["encoder_out"]
        mask = struc_x["encoder_padding_mask"]
        nan_mask = torch.isnan(x)
        if nan_mask.any():
            x = x.masked_fill(nan_mask, 0.0)
        x = self.ln_kv(self.kv_proj(x))
        b, seqlen = x.shape[:2]

        latents = self.ln_q(self.latents)
        if self.perceiver_resampler_positional_emb:
            # TODO: interpolate instead of failing when seqlen > max_seqlen.
            n = latents.shape[0]
            # Keep exactly `n` strided rows. If max_seqlen is not a multiple
            # of n_queries, `pos_embed[::stride]` alone has more than n rows.
            latents = latents + self.pos_embed[::self.stride][:n]
            x = x + self.pos_embed[:seqlen].unsqueeze(0)

        # Share the same latents across the batch.
        latents = latents.unsqueeze(0).expand(b, -1, -1)
        out = self.attn(latents, x, x, key_padding_mask=~mask)[0]
        return self.proj(self.ln_post(out))
class StructureTransformer(nn.Module):
    """Turn precomputed per-residue structure embeddings into a fixed number
    of query tokens with a :class:`Resampler`.

    Args:
        width: channel size of the concatenated embeddings (the resampler's
            ``kv_dim``).
        n_queries: number of output tokens. ``encode`` also reads at most this
            many characters of each encoded path.
        output_dim: width of the output tokens.
        embedding_keys: sample keys whose tensors are concatenated along the
            last axis. Defaults to ``{"mpnn_emb"}``. Sets are iterated in
            sorted order; lists and tuples keep their given order.
        max_seqlen: padded sequence length.
        num_heads: attention heads of the resampler.
        structure_emb_path_prefix: directory holding the pickled samples read
            by ``encode``.
        **kwargs: forwarded to :class:`Resampler`.
    """

    def __init__(
        self,
        width: int = 640,
        n_queries: int = 32,
        output_dim: int = 4096,
        embedding_keys=None,
        max_seqlen: int = 1024,
        num_heads: int = 8,
        structure_emb_path_prefix='structure_emb',
        **kwargs
    ):
        super().__init__()
        if embedding_keys is None:
            # Fresh set per instance instead of one shared mutable default.
            embedding_keys = {"mpnn_emb"}
        self.structure_emb_path_prefix = structure_emb_path_prefix
        # self.transformer = None  # replace None with a pretrained structure encoder
        self.embedding_keys = embedding_keys
        self.max_seqlen = max_seqlen
        self.width = width
        self.n_queries = n_queries
        self.attn_pool = Resampler(
            embed_dim=output_dim,
            kv_dim=width,
            n_queries=n_queries,
            max_seqlen=max_seqlen,
            num_heads=num_heads,
            **kwargs
        )

    def _ordered_embedding_keys(self):
        """Return ``embedding_keys`` in a deterministic iteration order."""
        if isinstance(self.embedding_keys, (set, frozenset)):
            # str hashes are randomised per process, so raw set order -- and
            # with it the channel layout of the concatenated embedding --
            # could differ between runs.
            return sorted(self.embedding_keys)
        return list(self.embedding_keys)

    def prepare_structure(self, sample):
        """Build one padded embedding and its validity mask from a sample dict.

        Args:
            sample: dict of per-residue tensors. Each key of
                ``embedding_keys`` present in it is used. A value may be a
                tensor or a list of tensors, which are concatenated along
                dim 0 first. The concatenated width must equal ``width``.
                If ``"pifold_emb"`` is requested and ``"pifold_mask"`` is
                present, ``sample["pifold_emb"]`` is replaced **in place** by
                a full-length tensor whose rows with mask == 0 are NaN.

        Returns:
            ``(emb_pad, emb_mask)``:
            - ``emb_pad``: (max_seqlen, width) tensor in torch's default float
              dtype, holding the embeddings in the leading rows and zeros after.
            - ``emb_mask``: (max_seqlen,) bool tensor, True for the filled rows.
            A sample with more than ``max_seqlen`` rows fails with a size
            mismatch.
        """
        emb_pad = torch.zeros((self.max_seqlen, self.width))
        emb_mask = torch.zeros((self.max_seqlen), dtype=bool)
        if "pifold_emb" in self.embedding_keys and "pifold_mask" in sample:
            mask = sample["pifold_mask"]
            pifold_emb = sample["pifold_emb"]
            # Scatter the embeddings of residues with mask > 0 into a
            # full-length tensor; all other rows stay NaN.
            new_pifold_emb = pifold_emb.new_zeros(mask.shape[0], pifold_emb.shape[1]).fill_(float("nan"))
            new_pifold_emb[mask > 0] = pifold_emb
            sample["pifold_emb"] = new_pifold_emb
        ### domains ###
        emb = []
        for ek in self._ordered_embedding_keys():
            if ek not in sample:
                continue
            value = sample[ek]
            emb.append(torch.cat(value) if isinstance(value, list) else value)
        emb = torch.cat(emb, dim=-1)
        emb_pad[:len(emb)] = emb
        emb_mask[:len(emb)] = 1
        return emb_pad, emb_mask

    def forward(self, x):
        """Run the resampler on ``{"encoder_out", "encoder_padding_mask"}``."""
        # x = self.transformer(x)
        return self.attn_pool(x)

    def encode(self, structure_paths: List[str]):
        """Load pickled structure samples and resample them into query tokens.

        Args:
            structure_paths: one entry per batch element. Each entry is sliced
                to its first ``n_queries`` items and ``.tolist()``-ed, so it is
                presumably a 1-D integer tensor of character codes rather than
                a ``str`` -- TODO confirm. Positive codes are decoded with
                ``chr`` and the resulting name is joined onto
                ``structure_emb_path_prefix``.

        Returns:
            The output of :meth:`forward` for the batch, or ``None`` (after
            printing a message) if any file is missing.
        """
        structure_embs = []
        structure_mask = []
        for structure_path in structure_paths:
            file_name = ''.join(chr(s) for s in structure_path[:self.n_queries].tolist() if s > 0)
            structure_path = os.path.join(self.structure_emb_path_prefix, file_name)
            if not os.path.exists(structure_path):
                print('no structure found')
                return None
            # SECURITY: pickle.load can execute arbitrary code; only point
            # structure_emb_path_prefix at trusted files.
            with open(structure_path, 'rb') as f:
                sample = pickle.load(f)
            structure, struc_mask = self.prepare_structure(sample)
            structure_embs.append(structure)
            structure_mask.append(struc_mask)
        # Move inputs to the device/dtype of the resampler weights.
        ref_param = next(self.attn_pool.parameters())
        structure_embs = torch.stack(structure_embs, dim=0).to(
            device=ref_param.device,
            dtype=ref_param.dtype)
        structure_mask = torch.stack(structure_mask, dim=0).to(
            device=ref_param.device)
        return self({
            'encoder_out': structure_embs,
            'encoder_padding_mask': structure_mask
        })