# segment_borzoi/segment_borzoi.py
# Uploaded by Yanisadel in commit be84f90 ("Upload SegmentBorzoi"), 8.79 kB.
from typing import Any, Dict, List
import borzoi_pytorch
import torch
import torch.nn as nn
from einops import rearrange
from torch import einsum
from transformers import PretrainedConfig, PreTrainedModel
from genomics_research.segmentnt.porting_to_pytorch.layers.segmentation_head import (
TorchUNetHead,
)
# Genomic annotation classes predicted by the segmentation head. This is the
# default value of ``SegmentBorzoiConfig.features`` and is handed to
# TorchUNetHead as ``features``. Presumably the output layout follows this
# order, so do not reorder casually — confirm against TorchUNetHead.
# (None of the names contain whitespace, so splitting on whitespace is safe.)
FEATURES = """
    protein_coding_gene lncRNA
    exon intron splice_donor splice_acceptor
    5UTR 3UTR
    CTCF-bound polyA_signal
    enhancer_Tissue_specific enhancer_Tissue_invariant
    promoter_Tissue_specific promoter_Tissue_invariant
""".split()
class SegmentBorzoiConfig(PretrainedConfig):
    """Configuration for the ``SegmentBorzoi`` segmentation model.

    Args:
        features: Names of the annotation classes predicted by the UNet head.
            Defaults to the module-level ``FEATURES``. The list is copied so
            that mutating ``config.features`` can never alter that shared
            default.
        embed_dim: Channel width given to the UNet head as
            ``embed_dimension``. It is also the model width of the
            replacement attention layers.
        dim_divisible_by: Passed to the UNet head as ``nucl_per_token``.
        attention_dim_key: Per-head query/key size of the attention layers.
        num_attention_heads: Number of attention heads. The per-head value
            size is ``embed_dim // num_attention_heads``.
        num_rel_pos_features: Length of the relative positional embedding
            vector. Must be even, because the embedding is built from two
            equal halves.
        **kwargs: Forwarded unchanged to ``transformers.PretrainedConfig``.
    """

    model_type = "segment_borzoi"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 32,
        attention_dim_key: int = 64,
        num_attention_heads: int = 8,
        num_rel_pos_features: int = 32,
        **kwargs: Any,
    ):
        # Store a private copy. The default argument is the shared
        # module-level FEATURES list, so keeping a reference would let
        # ``config.features.append(...)`` change the default for every
        # config created afterwards.
        self.features = list(features)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by
        self.attention_dim_key = attention_dim_key
        self.num_attention_heads = num_attention_heads
        self.num_rel_pos_features = num_rel_pos_features
        super().__init__(**kwargs)
class SegmentBorzoi(PreTrainedModel):
    """Pretrained Borzoi trunk followed by a UNet segmentation head.

    The trunk submodules are taken directly from the ``borzoi_pytorch``
    checkpoint ``"johahi/borzoi-replicate-0"``. The attention sub-layer of
    every transformer block is swapped for ``BorzoiAttentionLayer``, and a
    ``TorchUNetHead`` predicts ``config.features``.
    """

    config_class = SegmentBorzoiConfig

    def __init__(self, config: SegmentBorzoiConfig):
        """Assemble the model from a freshly loaded Borzoi and a new UNet head.

        NOTE(review): ``Borzoi.from_pretrained`` runs on every construction,
        including when this model is itself restored via ``from_pretrained``.
        It presumably hits the Hugging Face Hub or its local cache, so
        construction needs that checkpoint to be reachable. The attention
        layers and the head conv replaced below start freshly initialised.
        They presumably get their weights from this model's own checkpoint —
        confirm before training from scratch.
        """
        super().__init__(config=config)
        borzoi = borzoi_pytorch.Borzoi.from_pretrained("johahi/borzoi-replicate-0")
        # The attributes below alias the loaded Borzoi submodules (same
        # objects), so they carry the pretrained weights.
        # Stem
        self.stem = borzoi.conv_dna
        # Conv tower
        self.res_tower = borzoi.res_tower
        self.unet1 = borzoi.unet1
        self._max_pool = borzoi._max_pool
        # Transformer tower
        self.transformer = borzoi.transformer
        # UNet convolution layers
        self.horizontal_conv1 = borzoi.horizontal_conv1
        self.horizontal_conv0 = borzoi.horizontal_conv0
        self.upsampling_unet1 = borzoi.upsampling_unet1
        self.upsampling_unet0 = borzoi.upsampling_unet0
        self.separable1 = borzoi.separable1
        self.separable0 = borzoi.separable0
        # Target length crop
        self.crop = borzoi.crop
        # Final convolution block
        self.final_joined_convs = borzoi.final_joined_convs
        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )
        # Correct transformer: replace the attention sub-layer of each block
        # (``layer[0].fn[1]``) with the local BorzoiAttentionLayer. Its
        # relative positional embedding differs from the borzoi_pytorch one.
        # Dropout rates are hard-coded here rather than read from config.
        for layer in self.transformer:
            layer[0].fn[1] = BorzoiAttentionLayer(  # type: ignore
                config.embed_dim,
                heads=config.num_attention_heads,
                dim_key=config.attention_dim_key,
                dim_value=config.embed_dim // config.num_attention_heads,
                dropout=0.05,
                pos_dropout=0.01,
                num_rel_pos_features=config.num_rel_pos_features,
            )
        # Correct conv layer in downsample block: the head's first downsample
        # conv is rebuilt with 1920 input channels.
        # NOTE(review): 1920/1536 are hard-coded. 1920 presumably matches the
        # output width of ``final_joined_convs`` and 1536 the default
        # ``config.embed_dim``. Confirm, and consider deriving them instead.
        self.unet_head.unet.downsample_blocks[0].conv_layers[0] = nn.Conv1d(
            in_channels=1920, out_channels=1536, kernel_size=3, stride=1, padding=1
        )
        # Correct bias in separable layers: drop the bias of ``conv_layer[1]``
        # in both separable blocks. This is presumably needed to match the
        # parameter layout of the reference weights — confirm.
        self.separable1.conv_layer[1].bias = None
        self.separable0.conv_layer[1].bias = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the Borzoi trunk, then the UNet segmentation head.

        Args:
            x: Input batch, transposed to channels-first before the stem.
                Presumably ``(batch, seq_len, 4)`` one-hot DNA — TODO confirm.

        Returns:
            The output of ``self.unet_head`` applied to the trunk features.
        """
        # Stem
        x = x.transpose(1, 2)
        x = self.stem(x)
        # Conv tower; keep the two intermediate resolutions for the skips.
        x_unet0 = self.res_tower(x)
        x_unet1 = self.unet1(x_unet0)
        x = self._max_pool(x_unet1)
        # Transformer tower: swap to channels-last for the tower and back after.
        x = x.permute(0, 2, 1)
        x = self.transformer(x)
        x = x.permute(0, 2, 1)
        # UNet conv: project the saved skip activations.
        x_unet1 = self.horizontal_conv1(x_unet1)
        x_unet0 = self.horizontal_conv0(x_unet0)
        # UNet upsampling and separable convolutions: upsample, add the skip
        # (in place), refine — once per saved resolution.
        x = self.upsampling_unet1(x)
        x += x_unet1
        x = self.separable1(x)
        x = self.upsampling_unet0(x)
        x += x_unet0
        x = self.separable0(x)
        # Target length crop, applied in channels-last layout.
        x = self.crop(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        # Final convolution block
        x = self.final_joined_convs(x)
        x = self.unet_head(x)
        return x
# Custom attention layer for the PyTorch port. The attention layer of the
# imported borzoi_pytorch model cannot be reused as-is because its relative
# positional embeddings differ from the ones computed below.
def _prepend_dims(tensor: torch.Tensor, num_dims: int) -> torch.Tensor:
"""Prepends dimensions to match the required shape."""
for _ in range(num_dims - tensor.dim()):
tensor = tensor.unsqueeze(0)
return tensor
def get_positional_features_central_mask_borzoi(
    positions: torch.Tensor, feature_size: int, seq_length: int
) -> torch.Tensor:
    """Positional features using a central mask (allow only central features).

    Feature ``k`` (1-based) is ``1.0`` where ``|position| < pow_rate**k - 1``
    and ``0.0`` elsewhere, with
    ``pow_rate = (seq_length + 1) ** (1 / feature_size)``. The window widths
    therefore grow geometrically, up to about ``seq_length``.

    Args:
        positions: Relative positions, of any shape and on any device.
        feature_size: Number of mask features per position.
        seq_length: Sequence length that sets the largest window width.

    Returns:
        A float32 tensor of shape ``positions.shape + (feature_size,)`` on
        ``positions.device``.
    """
    # The widths are computed exactly as before, on the default device, so
    # their values do not depend on the target device's exp/log/pow kernels.
    # They are then moved next to ``positions``: comparing e.g. a CPU tensor
    # with a CUDA tensor would raise a device-mismatch error.
    pow_rate = torch.exp(torch.log(torch.tensor(seq_length + 1.0)) / feature_size)
    center_widths = torch.pow(pow_rate, torch.arange(1, feature_size + 1).float()) - 1
    center_widths = center_widths.to(positions.device)
    center_widths = _prepend_dims(center_widths, positions.ndim)
    return (center_widths > torch.abs(positions).unsqueeze(-1)).float()
def get_positional_embed_borzoi(seq_len: int, feature_size: int) -> torch.Tensor:
    """Relative positional embedding for Borzoi (differs from Enformer's).

    The result has one row per relative distance ``d`` in
    ``[-(seq_len - 1), seq_len - 1]``, giving shape
    ``(2 * seq_len - 1, feature_size)`` and dtype float32. The first half of
    each row holds the central-mask features of ``d``. The second half holds
    the same features multiplied by ``sign(d)``.

    Raises:
        ValueError: If ``feature_size`` is not divisible by 2.
    """
    rel_distances = torch.arange(-seq_len + 1, seq_len)
    num_components = 2
    if feature_size % num_components:
        raise ValueError(
            f"feature size is not divisible by number of components ({num_components})"
        )
    symmetric = get_positional_features_central_mask_borzoi(
        rel_distances, feature_size // num_components, seq_len
    )
    signed = torch.sign(rel_distances).unsqueeze(-1) * symmetric
    return torch.cat((symmetric, signed), dim=-1)
def relative_shift(x: torch.Tensor) -> torch.Tensor:
    """Realign relative-position logits to absolute key positions.

    For input of shape ``(b, h, T, 2T - 1)``, where the last axis indexes
    relative distances ``-(T - 1) .. T - 1``, the output has shape
    ``(b, h, T, T)`` with ``out[..., i, k] == x[..., i, k - i + T - 1]``.
    The input must be 4-D.
    """
    zero_column = x.new_zeros(x.shape[:-1] + (1,))
    padded = torch.cat((zero_column, x), dim=-1)
    _, num_heads, num_rows, num_cols = padded.shape
    # Reading the padded buffer with the trailing dims swapped, then dropping
    # the first row, shifts each successive query row by one more position.
    skewed = padded.reshape(-1, num_heads, num_cols, num_rows)  # noqa: FKA100
    skewed = skewed[:, :, 1:, :]
    skewed = skewed.reshape(-1, num_heads, num_rows, num_cols - 1)  # noqa: FKA100
    return skewed[..., : (num_cols + 1) // 2]
class BorzoiAttentionLayer(nn.Module):
    """Multi-head self-attention with Borzoi-style relative positional logits.

    The attention logits are the sum of two terms:

    * a content term ``(q + rel_content_bias) . k``;
    * a relative-position term ``(q + rel_pos_bias) . to_rel_k(pos_embed)``,
      realigned to key positions with ``relative_shift``.

    The output projection is zero-initialised, so a freshly constructed layer
    outputs zeros until weights are loaded.

    Args:
        dim: Model width of the input and output.
        num_rel_pos_features: Length of the relative positional embedding fed
            to ``to_rel_k``. Must be even.
        heads: Number of attention heads.
        dim_key: Per-head query/key size.
        dim_value: Per-head value size.
        dropout: Dropout rate applied to the attention weights.
        pos_dropout: Dropout rate applied to the positional embedding.
    """

    def __init__(
        self,
        dim: int,
        *,
        num_rel_pos_features: int,
        heads: int = 8,
        dim_key: int = 64,
        dim_value: int = 64,
        dropout: float = 0.0,
        pos_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.scale = dim_key**-0.5
        self.heads = heads

        # Keep the construction order below unchanged: it fixes both the
        # state_dict key order and the RNG draw order of the initialisation.
        self.to_q = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias=False)

        self.to_out = nn.Linear(dim_value * heads, dim)
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

        self.num_rel_pos_features = num_rel_pos_features
        self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias=False)
        self.rel_content_bias = nn.Parameter(
            torch.randn(1, heads, 1, dim_key)  # noqa: FKA100
        )
        self.rel_pos_bias = nn.Parameter(
            torch.randn(1, heads, 1, dim_key)  # noqa: FKA100
        )

        # dropouts
        self.pos_dropout = nn.Dropout(pos_dropout)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``(batch, seq_len, dim)``.

        Returns:
            A tensor of shape ``(batch, seq_len, dim)``.
        """
        n, h = x.shape[-2], self.heads

        q, k, v = (
            rearrange(proj(x), "b n (h d) -> b h n d", h=h)
            for proj in (self.to_q, self.to_k, self.to_v)
        )
        q = q * self.scale

        content_logits = einsum(
            "b h i d, b h j d -> b h i j", q + self.rel_content_bias, k
        )

        # The embedding is created on the default device (normally the CPU)
        # in float32. Move it to the input's device and dtype; otherwise
        # ``to_rel_k`` fails on CUDA inputs or on half/bfloat16 models.
        positions = get_positional_embed_borzoi(n, self.num_rel_pos_features)
        positions = positions.to(device=x.device, dtype=x.dtype)
        positions = self.pos_dropout(positions)
        rel_k = rearrange(self.to_rel_k(positions), "n (h d) -> h n d", h=h)
        rel_logits = einsum("b h i d, h j d -> b h i j", q + self.rel_pos_bias, rel_k)
        rel_logits = relative_shift(rel_logits)

        attn = (content_logits + rel_logits).softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)