# segment_enformer / segment_enformer.py
# Uploaded by Yanisadel ("Upload SegmentEnformer", commit f7aa1ae, verified).
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from enformer_pytorch import Enformer
from transformers import PretrainedConfig, PreTrainedModel
def get_activation_fn(activation_name: str) -> Callable:
    """
    Returns torch activation function

    Args:
        activation_name (str): Name of the activation function. Possible values are
            'swish', 'silu', 'relu', 'gelu', 'gelu-no-approx', 'sin'.
            'swish' and 'silu' are aliases. torch's ``gelu`` defaults to the
            exact (erf-based, ``approximate="none"``) form, so 'gelu' and
            'gelu-no-approx' return the same function here.

    Raises:
        ValueError: If activation_name is not supported

    Returns:
        Callable: Activation function
    """
    # Every layer docstring in this module advertises "silu" and
    # "gelu-no-approx"; accept them alongside the original names.
    activations: Dict[str, Callable] = {
        "swish": nn.functional.silu,
        "silu": nn.functional.silu,
        "relu": nn.functional.relu,
        "gelu": nn.functional.gelu,
        "gelu-no-approx": nn.functional.gelu,
        "sin": torch.sin,
    }
    try:
        return activations[activation_name]
    except KeyError:
        raise ValueError(
            f"Unsupported activation function: {activation_name}"
        ) from None
class TorchDownSample1D(nn.Module):
    """
    Torch adaptation of DownSample1D in trix.layers.heads.unet_segmentation_head.py

    A stack of length-preserving 1D convolutions (kernel 3, stride 1,
    padding 1), each followed by the activation, then a stride-2 average
    pooling that halves (floor) the sequence length.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: channels of the tensor fed to the first convolution.
            output_channels: channels produced by every convolution.
            activation_fn: activation name, resolved with ``get_activation_fn``.
            num_layers: number of convolution layers.
        """
        super().__init__()
        convs = []
        in_ch = input_channels
        for _ in range(num_layers):
            convs.append(
                nn.Conv1d(
                    in_channels=in_ch,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            # Every layer after the first consumes the previous layer's output.
            in_ch = output_channels
        self.conv_layers = nn.ModuleList(convs)
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2, padding=0)
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns ``(pooled, pre_pool)``. ``pre_pool`` is the activated conv
        output before pooling, which the U-Net keeps as a skip connection.
        """
        features = x
        for conv in self.conv_layers:
            features = self.activation_fn(conv(features))
        return self.avg_pool(features), features
class TorchUpSample1D(nn.Module):
    """
    Torch adaptation of UpSample1D in trix.layers.heads.unet_segmentation_head.py

    Applies ``num_layers`` length-preserving transposed convolutions (kernel 3,
    stride 1, padding 1), each followed by the activation, then doubles the
    sequence length with ``torch.nn.functional.interpolate``.
    """

    # Interpolation modes that accept ``align_corners``. torch raises a
    # ValueError when ``align_corners`` is passed (even as False) for any other
    # mode, e.g. "nearest", "nearest-exact" or "area".
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
        interpolation_method: str = "nearest",
    ):
        """
        Args:
            input_channels: number of input channels.
            output_channels: number of output channels.
            activation_fn: name of the activation function, resolved with
                ``get_activation_fn``.
            num_layers: number of transposed-convolution layers.
            interpolation_method: ``mode`` passed to
                ``torch.nn.functional.interpolate``. For the 3D
                (batch, channels, length) inputs used here, torch accepts
                "nearest", "nearest-exact", "linear" and "area". The JAX names
                "cubic", "lanczos3" and "lanczos5" are not supported.
        """
        super().__init__()
        self.conv_transpose_layers = nn.ModuleList(
            [
                nn.ConvTranspose1d(
                    in_channels=input_channels if i == 0 else output_channels,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
                for i in range(num_layers)
            ]
        )
        self.interpolation_mode = interpolation_method
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv_layer in self.conv_transpose_layers:
            x = self.activation_fn(conv_layer(x))
        # Only forward align_corners to modes that accept it. Previously any
        # mode other than "nearest" received False, which crashes for "area"
        # and "nearest-exact".
        align_corners = (
            False if self.interpolation_mode in self._ALIGN_CORNERS_MODES else None
        )
        return nn.functional.interpolate(
            x,
            scale_factor=2,
            mode=self.interpolation_mode,
            align_corners=align_corners,
        )
class TorchFinalConv1D(nn.Module):
    """
    Torch adaptation of FinalConv1D in trix.layers.heads.unet_segmentation_head.py

    A stack of length-preserving 1D convolutions (kernel 3, stride 1,
    padding 1). The activation is applied between consecutive layers but not
    after the last one, so the block's output is left un-activated.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        activation_fn: str = "swish",
        num_layers: int = 2,
    ):
        """
        Args:
            input_channels: channels of the tensor fed to the first convolution.
            output_channels: channels produced by every convolution.
            activation_fn: activation name, resolved with ``get_activation_fn``.
            num_layers: number of convolution layers.
        """
        super().__init__()
        layers = []
        for idx in range(num_layers):
            layers.append(
                nn.Conv1d(
                    in_channels=output_channels if idx else input_channels,
                    out_channels=output_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
        self.conv_layers = nn.ModuleList(layers)
        self.activation_fn: Callable = get_activation_fn(activation_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # All layers except the last are followed by the activation.
        for conv in self.conv_layers[:-1]:
            x = self.activation_fn(conv(x))
        # The [-1:] slice is empty when there are no layers, so x passes through.
        for conv in self.conv_layers[-1:]:
            x = conv(x)
        return x
class TorchUNET1DSegmentationHead(nn.Module):
    """
    Torch adaptation of UNET1DSegmentationHead in
    trix.layers.heads.unet_segmentation_head.py

    Encoder of ``TorchDownSample1D`` blocks (each halves the length), decoder
    of ``TorchUpSample1D`` blocks (each doubles it) with additive skip
    connections, and a final ``TorchFinalConv1D``. Input is
    (batch, input_embed_dim, length); output is (batch, 2 * num_classes, length).
    """

    def __init__(
        self,
        num_classes: int,
        input_embed_dim: int,
        output_channels_list: Tuple[int, ...] = (64, 128, 256),
        activation_fn: str = "swish",
        num_conv_layers_per_block: int = 2,
        upsampling_interpolation_method: str = "nearest",
    ):
        """
        Args:
            num_classes: number of classes to segment. The final block emits
                ``num_classes * 2`` channels.
            input_embed_dim: number of channels of the input tensor.
            output_channels_list: number of output channels at each level of
                the UNET (any non-empty sequence; a list is accepted too).
            activation_fn: name of the activation function, resolved with
                ``get_activation_fn``.
            num_conv_layers_per_block: number of convolution layers per block.
            upsampling_interpolation_method: ``mode`` used by
                ``torch.nn.functional.interpolate`` in the upsampling blocks,
                e.g. "nearest" or "linear".

        Raises:
            ValueError: if ``output_channels_list`` is empty.
        """
        super().__init__()
        # Normalise to a tuple: the tuple concatenations below raise TypeError
        # on a list (e.g. one read from a JSON config).
        output_channels_list = tuple(output_channels_list)
        if not output_channels_list:
            raise ValueError("output_channels_list must contain at least one level.")
        input_channels_list = (input_embed_dim,) + output_channels_list[:-1]
        self.num_pooling_layers = len(output_channels_list)
        self.downsample_blocks = nn.ModuleList(
            [
                TorchDownSample1D(
                    input_channels=input_channels,
                    output_channels=output_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                )
                for input_channels, output_channels in zip(
                    input_channels_list, output_channels_list
                )
            ]
        )
        reversed_channels = output_channels_list[::-1]
        # Decoder: the first block keeps the bottleneck width; each later block
        # consumes the previous block's output width.
        upsample_input_channels = (reversed_channels[0],) + reversed_channels[:-1]
        self.upsample_blocks = nn.ModuleList(
            [
                TorchUpSample1D(
                    input_channels=input_channels,
                    output_channels=output_channels,
                    activation_fn=activation_fn,
                    num_layers=num_conv_layers_per_block,
                    interpolation_method=upsampling_interpolation_method,
                )
                for input_channels, output_channels in zip(
                    upsample_input_channels, reversed_channels
                )
            ]
        )
        self.final_block = TorchFinalConv1D(
            activation_fn=activation_fn,
            input_channels=output_channels_list[0],
            output_channels=num_classes * 2,
            num_layers=num_conv_layers_per_block,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each pooling halves the length, so it must divide evenly at every
        # level for the skip connections to line up.
        if x.shape[-1] % 2**self.num_pooling_layers:
            raise ValueError(
                "Input length must be divisible by 2 to the power of "
                "the number of pooling layers."
            )
        hiddens = []
        for downsample_block in self.downsample_blocks:
            x, hidden = downsample_block(x)
            hiddens.append(hidden)
        # Decode deepest level first, adding the matching encoder activation.
        for upsample_block, hidden in zip(self.upsample_blocks, reversed(hiddens)):
            x = upsample_block(x) + hidden
        x = self.final_block(x)
        return x
class TorchUNetHead(nn.Module):
    """
    Torch adaptation of UNetHead in
    genomics_research/segmentnt/layers/segmentation_head.py

    Runs a 1D U-Net over (batch, embed_dimension, tokens) embeddings, applies
    SiLU, and projects each token to ``nucl_per_token * num_classes *
    num_features`` logits. The logits are reshaped to
    (batch, tokens * nucl_per_token, num_features, num_classes).
    """

    def __init__(
        self,
        features: List[str],
        num_classes: int = 2,
        embed_dimension: int = 1024,
        nucl_per_token: int = 6,
        num_layers: int = 2,
        remove_cls_token: bool = True,
    ):
        """
        Args:
            features (List[str]): List of features names (only its length is used).
            num_classes (int): Number of classes per feature.
            embed_dimension (int): Embedding (channel) dimension of the input.
            nucl_per_token (int): Number of nucleotides predicted per token.
            num_layers (int): Number of U-Net levels; level i has
                ``embed_dimension * 2**i`` channels.
            remove_cls_token (bool): Whether to drop the first token of the
                sequence before the U-Net.
        """
        super().__init__()
        self._num_features = len(features)
        self._num_classes = num_classes
        self.nucl_per_token = nucl_per_token
        self.remove_cls_token = remove_cls_token
        # The U-Net's final block emits 2 * (embed_dimension // 2) channels,
        # which matches the fc input size when embed_dimension is even.
        self.unet = TorchUNET1DSegmentationHead(
            num_classes=embed_dimension // 2,
            output_channels_list=tuple(
                embed_dimension * (2**i) for i in range(num_layers)
            ),
            input_embed_dim=embed_dimension,
        )
        self.fc = nn.Linear(
            embed_dimension,
            self.nucl_per_token * self._num_classes * self._num_features,
        )

    def forward(
        self, x: torch.Tensor, sequence_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: embeddings laid out (batch, embed_dimension, tokens), since the
                U-Net is Conv1d-based (channels on axis 1).
            sequence_mask: accepted for interface compatibility; currently unused.

        Returns:
            ``{"logits": tensor}`` of shape
            (batch, tokens * nucl_per_token, num_features, num_classes).
        """
        if self.remove_cls_token:
            # Drop the first token along the sequence (last) axis. The previous
            # ``x[:, 1:]`` removed an embedding channel instead, which made the
            # U-Net's first Conv1d fail on a channel-count mismatch.
            x = x[:, :, 1:]
        x = self.unet(x)
        x = nn.functional.silu(x)
        # (batch, channels, tokens) -> (batch, tokens, channels) for the Linear.
        x = x.transpose(2, 1)
        logits = self.fc(x)
        batch_size, seq_len, _ = x.shape
        logits = logits.view(  # noqa
            batch_size,
            seq_len * self.nucl_per_token,
            self._num_features,
            self._num_classes,
        )
        return {"logits": logits}
# Names of the genomic element classes the default SegmentEnformer config
# segments (14 entries). Presumably their order matches the feature axis of the
# head's logits as trained. TODO confirm against the training setup.
FEATURES = [
    "protein_coding_gene", "lncRNA",
    "exon", "intron", "splice_donor", "splice_acceptor", "5UTR", "3UTR",
    "CTCF-bound", "polyA_signal",
    "enhancer_Tissue_specific", "enhancer_Tissue_invariant",
    "promoter_Tissue_specific", "promoter_Tissue_invariant",
]
class SegmentEnformerConfig(PretrainedConfig):
    """
    Configuration for ``SegmentEnformer``.

    Args:
        features: names of the segmented element classes. ``len(features)``
            sets the size of the head's feature axis. Defaults to ``FEATURES``.
        embed_dim: embedding dimension given to the U-Net head. It must match
            the channel count of the trunk's transformer output.
        dim_divisible_by: passed to the head as ``nucl_per_token``, i.e. the
            number of nucleotide positions predicted per output token.
        **kwargs: forwarded to ``PretrainedConfig``.
    """

    model_type = "segment_enformer"

    def __init__(
        self,
        features: List[str] = FEATURES,
        embed_dim: int = 1536,
        dim_divisible_by: int = 128,
        **kwargs: Any,
    ) -> None:
        # Copy the list: the default is the shared module-level FEATURES, so
        # storing it by reference would let an in-place edit of one config's
        # features silently change the default for every later config.
        self.features = list(features)
        self.embed_dim = embed_dim
        self.dim_divisible_by = dim_divisible_by
        super().__init__(**kwargs)
class SegmentEnformer(PreTrainedModel):
    """
    Enformer trunk (stem, conv tower, transformer) followed by a U-Net
    segmentation head (``TorchUNetHead``).

    The trunk modules are taken from ``Enformer.from_pretrained(
    "EleutherAI/enformer-official-rough")`` inside ``__init__``.
    NOTE(review): this means every construction loads that checkpoint, which
    may require network access.
    """

    config_class = SegmentEnformerConfig

    def __init__(self, config: SegmentEnformerConfig) -> None:
        super().__init__(config=config)
        enformer = Enformer.from_pretrained("EleutherAI/enformer-official-rough")
        self.stem = enformer.stem
        self.conv_tower = enformer.conv_tower
        self.transformer = enformer.transformer
        # The trunk output already has no CLS token, hence remove_cls_token=False.
        self.unet_head = TorchUNetHead(
            features=config.features,
            embed_dimension=config.embed_dim,
            nucl_per_token=config.dim_divisible_by,
            remove_cls_token=False,
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Implemented as ``forward`` rather than overriding ``__call__``, so
        ``nn.Module.__call__`` still runs registered hooks. ``model(x)``
        behaves as before.

        Args:
            x: 3D tensor laid out (batch, length, channels) per the einops
                pattern below. Presumably one-hot DNA; TODO confirm.

        Returns:
            The head's output ``{"logits": tensor}`` of shape
            (batch, tokens * dim_divisible_by, len(features), 2).
        """
        # Conv layers expect channels on axis 1.
        x = rearrange(x, "b n d -> b d n")
        x = self.stem(x)
        x = self.conv_tower(x)
        # The transformer works on (batch, tokens, dim).
        x = rearrange(x, "b d n -> b n d")
        x = self.transformer(x)
        # Back to channels-first for the Conv1d-based U-Net head.
        x = rearrange(x, "b n d -> b d n")
        return self.unet_head(x)