import math
from functools import partial
import os
import json

import torch
import torch.nn as nn
import numpy as np

from mamba_ssm.modules.mamba_simple import Mamba, Block
from huggingface_hub import PyTorchModelHubMixin


def seq_to_oh(seq):
    """One-hot encode a DNA sequence.

    Args:
        seq: Iterable of bases drawn from {A, C, G, T}.

    Returns:
        Integer array of shape (len(seq), 4); any other character
        (e.g. "N") is encoded as an all-zero row.
    """
    oh = np.zeros((len(seq), 4), dtype=int)
    for i, base in enumerate(seq):
        if base == 'A':
            oh[i, 0] = 1
        elif base == 'C':
            oh[i, 1] = 1
        elif base == 'G':
            oh[i, 2] = 1
        elif base == 'T':
            oh[i, 3] = 1
    return oh
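
# Example usage (illustrative):
#
#   oh = seq_to_oh("ACGTN")                           # (5, 4); the "N" row is all zeros
#   x = torch.from_numpy(oh).T.unsqueeze(0).float()   # (1, 4, L) input for MixerModel
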
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mix_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(nn.LayerNorm, eps=norm_epsilon, **factory_kwargs)
    block = Block(
        d_model,
        mix_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
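
# Illustrative: build a single residual Mamba block for layer 0 with a
# 128-dimensional hidden state (the sizes here are arbitrary examples;
# ssm_cfg, if given, is forwarded to Mamba).
#
#   block = create_block(d_model=128, layer_idx=0)
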
class MixerModel(
    nn.Module,
    PyTorchModelHubMixin,
):
    """Stack of Mamba blocks applied to linearly embedded multi-track inputs."""

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        input_dim: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Linear(input_dim, d_model, **factory_kwargs)

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = nn.LayerNorm(d_model, eps=norm_epsilon, **factory_kwargs)

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def forward(self, x, inference_params=None, channel_last=False):
        # Expect (B x C x L) by default; convert to (B x L x C) for the embedding.
        if not channel_last:
            x = x.transpose(1, 2)

        hidden_states = self.embedding(x)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )

        # Final residual add followed by the output norm.
        residual = (hidden_states + residual) if residual is not None else hidden_states
        hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))

        return hidden_states

    def representation(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor,
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Get a global representation of the input data.

        Args:
            x: Data to embed. Has shape (B x C x L) if not channel_last.
            lengths: Unpadded length of each data input. Has shape (B,).
            channel_last: Expects input of shape (B x L x C).

        Returns:
            Global representation vector of shape (B x H).
        """
        out = self.forward(x, channel_last=channel_last)
        return mean_unpadded(out, lengths)
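
# Example usage (illustrative; the sizes are arbitrary, and the Mamba kernels
# may require moving the model and inputs to a CUDA device):
#
#   model = MixerModel(d_model=64, n_layer=2, input_dim=4)
#   x = torch.randn(2, 4, 100)               # (B, C, L) track/one-hot input
#   lengths = torch.tensor([100, 80])        # unpadded length of each sequence
#   emb = model.representation(x, lengths)   # (B, d_model) == (2, 64)
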
def mean_unpadded(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Take the mean over the second (length) dimension, ignoring padding.

    Args:
        x: Tensor to average. Has shape (B x L x H).
        lengths: Tensor of unpadded lengths. Has shape (B,).

    Returns:
        Mean tensor of shape (B x H).
    """
    # Mask out padded positions before summing, then divide by the true lengths.
    mask = torch.arange(x.size(1), device=x.device)[None, :] < lengths[:, None]
    masked_tensor = x * mask.unsqueeze(-1)
    sum_tensor = masked_tensor.sum(dim=1)
    mean_tensor = sum_tensor / lengths.unsqueeze(-1).float()

    return mean_tensor
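
# Illustrative check of the masked mean:
#
#   x = torch.ones(2, 3, 4)
#   x[1, 2:] = 100.0                  # padding position of sample 1
#   lengths = torch.tensor([3, 2])
#   mean_unpadded(x, lengths)         # -> all ones; the padded position is ignored
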
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Rescale the output projections of residual branches by 1/sqrt(N),
        # following the GPT-2 initialization scheme used in the Mamba reference code.
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


def load_model(run_path: str, checkpoint_name: str) -> nn.Module:
    """Load trained model located at specified path.

    Args:
        run_path: Path where run data is located.
        checkpoint_name: Name of model checkpoint to load.

    Returns:
        Model with loaded weights.
    """
    model_config_path = os.path.join(run_path, "model_config.json")
    data_config_path = os.path.join(run_path, "data_config.json")

    with open(model_config_path, "r") as f:
        model_params = json.load(f)

    # Fall back to the data config if the model config does not record n_tracks.
    if "n_tracks" not in model_params:
        with open(data_config_path, "r") as f:
            data_params = json.load(f)
        n_tracks = data_params["n_tracks"]
    else:
        n_tracks = model_params["n_tracks"]

    model_path = os.path.join(run_path, checkpoint_name)

    model = MixerModel(
        d_model=model_params["ssm_model_dim"],
        n_layer=model_params["ssm_n_layers"],
        input_dim=n_tracks,
    )
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Strip the leading "model." prefix from parameter names in the checkpoint.
    state_dict = {}
    for k, v in checkpoint["state_dict"].items():
        if k.startswith("model."):
            state_dict[k[len("model."):]] = v

    model.load_state_dict(state_dict)
    return model
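
# Illustrative usage (the file names below are hypothetical; the loader only
# assumes the run directory contains model_config.json and the named checkpoint):
#
#   model = load_model("runs/example_run", "best_model.ckpt")
#   model.eval()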