# Some of the following code is borrowed from https://github.com/lucidrains/enformer-pytorch

import math

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def map_values(fn, d):
    return {key: fn(values) for key, values in d.items()}

def exponential_linspace_int(start, end, num, divisible_by = 1):
    def _round(x):
        return int(round(x / divisible_by) * divisible_by)

    base = math.exp(math.log(end / start) / (num - 1))
    return [_round(start * base**i) for i in range(num)]

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# losses and metrics

def pearson_corr_coef(x, y, eps = 1e-8):
    x2 = x * x
    y2 = y * y
    xy = x * y
    ex = x.mean(dim = 1)
    ey = y.mean(dim = 1)
    exy = xy.mean(dim = 1)
    ex2 = x2.mean(dim = 1)
    ey2 = y2.mean(dim = 1)

    r = (exy - ex * ey) / (torch.sqrt(ex2 - (ex * ex)) * torch.sqrt(ey2 - (ey * ey)) + eps)
    return r.mean(dim = -1)
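# The sketch below is illustrative only (not part of the original file): it
# documents the tensor layout pearson_corr_coef appears to expect, i.e.
# (batch, positions, tracks) -- statistics are reduced over dim 1 and the
# resulting per-track correlations are averaged over the last dim. The sizes
# used here are arbitrary examples.

def _example_pearson_corr_coef():
    pred = torch.randn(2, 100, 4)                  # (batch, positions, tracks)
    target = torch.randn(2, 100, 4)
    r = pearson_corr_coef(pred, target)            # -> (batch,) mean correlation per sample
    assert r.shape == (2,)
    # correlation of a signal with itself is ~1 (up to eps and float error)
    assert torch.allclose(pearson_corr_coef(target, target), torch.ones(2), atol = 1e-3)
    return r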
# relative positional encoding functions

def get_positional_features_exponential(positions, features, seq_len, min_half_life = 3.):
    max_range = math.log(seq_len) / math.log(2.)
    half_life = 2 ** torch.linspace(min_half_life, max_range, features, device = positions.device)
    half_life = half_life[None, ...]
    positions = positions.abs()[..., None]
    return torch.exp(-math.log(2.) / half_life * positions)

def get_positional_features_central_mask(positions, features, seq_len):
    center_widths = 2 ** torch.arange(1, features + 1, device = positions.device).float()
    center_widths = center_widths - 1
    return (center_widths[None, ...] > positions.abs()[..., None]).float()

def gamma_pdf(x, concentration, rate):
    log_unnormalized_prob = torch.xlogy(concentration - 1., x) - rate * x
    log_normalization = (torch.lgamma(concentration) - concentration * torch.log(rate))
    return torch.exp(log_unnormalized_prob - log_normalization)

def get_positional_features_gamma(positions, features, seq_len, stddev = None, start_mean = None, eps = 1e-8):
    if not exists(stddev):
        stddev = seq_len / (2 * features)

    if not exists(start_mean):
        start_mean = seq_len / features

    mean = torch.linspace(start_mean, seq_len, features, device = positions.device)
    mean = mean[None, ...]
    concentration = (mean / stddev) ** 2
    rate = mean / stddev ** 2
    probabilities = gamma_pdf(positions.float().abs()[..., None], concentration, rate)
    probabilities = probabilities + eps
    outputs = probabilities / torch.amax(probabilities)
    return outputs

def get_positional_embed(seq_len, feature_size, device):
    distances = torch.arange(-seq_len + 1, seq_len, device = device)

    feature_functions = [
        get_positional_features_exponential,
        get_positional_features_central_mask,
        get_positional_features_gamma
    ]

    num_components = len(feature_functions) * 2

    if (feature_size % num_components) != 0:
        raise ValueError(f'feature size is not divisible by number of components ({num_components})')

    num_basis_per_class = feature_size // num_components

    embeddings = []
    for fn in feature_functions:
        embeddings.append(fn(distances, num_basis_per_class, seq_len))

    embeddings = torch.cat(embeddings, dim = -1)
    embeddings = torch.cat((embeddings, torch.sign(distances)[..., None] * embeddings), dim = -1)
    return embeddings

def relative_shift(x):
    to_pad = torch.zeros_like(x[..., :1])
    x = torch.cat((to_pad, x), dim = -1)
    _, h, t1, t2 = x.shape
    x = x.reshape(-1, h, t2, t1)
    x = x[:, :, 1:, :]
    x = x.reshape(-1, h, t1, t2 - 1)
    return x[..., :((t2 + 1) // 2)]

# classes

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class GELU(nn.Module):
    def forward(self, x):
        return torch.sigmoid(1.702 * x) * x

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_rel_pos_features,
        heads = 8,
        dim_key = 64,
        dim_value = 64,
        dropout = 0.,
        pos_dropout = 0.
    ):
        super().__init__()
        self.scale = dim_key ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(dim, dim_key * heads, bias = False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias = False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias = False)

        self.to_out = nn.Linear(dim_value * heads, dim)
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

        # relative positional encoding

        self.num_rel_pos_features = num_rel_pos_features

        self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias = False)
        self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
        self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))

        self.pos_dropout = nn.Dropout(pos_dropout)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x):
        n, h, device = x.shape[-2], self.heads, x.device

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale

        content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)

        positions = get_positional_embed(n, self.num_rel_pos_features, device)
        positions = self.pos_dropout(positions)
        rel_k = self.to_rel_k(positions)

        rel_k = rearrange(rel_k, 'n (h d) -> h n d', h = h)
        rel_logits = einsum('b h i d, h j d -> b h i j', q + self.rel_pos_bias, rel_k)
        rel_logits = relative_shift(rel_logits)

        logits = content_logits + rel_logits
        attn = logits.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
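# The sketch below is illustrative only (not part of the original file): it runs
# a standalone Attention block on a toy input to document the expected shapes.
# Note that num_rel_pos_features must be divisible by 6, because
# get_positional_embed splits it across 3 basis functions x 2 (symmetric and
# sign-weighted halves). All sizes here are arbitrary examples.

def _example_attention_shapes():
    attn = Attention(dim = 64, num_rel_pos_features = 12, heads = 4, dim_key = 16, dim_value = 16)
    x = torch.randn(2, 20, 64)      # (batch, positions, dim)
    out = attn(x)                   # attention output keeps the input shape
    assert out.shape == (2, 20, 64)
    return out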
class Enformer(nn.Module):
    def __init__(
        self,
        dim = 512,
        depth = 4,
        heads = 6,
        attn_dim_key = 64,
        dropout_rate = 0.2,
        attn_dropout = 0.05,
        pos_dropout = 0.01,
    ):
        super().__init__()
        self.dim = dim

        # note: Attention requires num_rel_pos_features (= dim // heads) to be
        # divisible by 6, otherwise get_positional_embed raises a ValueError
        # (e.g. the defaults dim = 512, heads = 6 give 85, which is not).
        transformer = []
        for _ in range(depth):
            transformer.append(nn.Sequential(
                Residual(nn.Sequential(
                    nn.LayerNorm(dim),
                    Attention(
                        dim,
                        heads = heads,
                        dim_key = attn_dim_key,
                        dim_value = dim // heads,
                        dropout = attn_dropout,
                        pos_dropout = pos_dropout,
                        num_rel_pos_features = dim // heads
                    ),
                    nn.Dropout(dropout_rate)
                )),
                Residual(nn.Sequential(
                    nn.LayerNorm(dim),
                    nn.Linear(dim, dim * 2),
                    nn.Dropout(dropout_rate),
                    nn.ReLU(),
                    nn.Linear(dim * 2, dim),
                    nn.Dropout(dropout_rate)
                ))
            ))

        self.transformer = nn.Sequential(
            # Rearrange('b d n -> b n d'),
            *transformer
        )

    def forward(self, x):
        x = self.transformer(x)
        return x

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        conv_kernel_size1 = 10
        conv_kernel_size2 = 8
        pool_kernel_size1 = 5
        pool_kernel_size2 = 4

        self.conv_net = nn.Sequential(
            nn.Conv1d(5, 256, kernel_size=conv_kernel_size1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),

            nn.Conv1d(256, 256, kernel_size=conv_kernel_size1),
            # nn.GroupNorm(16, 256),
            nn.BatchNorm1d(256, track_running_stats=False),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_kernel_size1, stride=pool_kernel_size1),
            nn.Dropout(p=0.1),

            nn.Conv1d(256, 360, kernel_size=conv_kernel_size2),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),

            nn.Conv1d(360, 360, kernel_size=conv_kernel_size2),
            nn.BatchNorm1d(360, track_running_stats=False),
            # nn.GroupNorm(36, 360),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=pool_kernel_size2, stride=pool_kernel_size2),
            nn.Dropout(p=0.1),

            nn.Conv1d(360, 512, kernel_size=conv_kernel_size2),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),

            nn.Conv1d(512, 512, kernel_size=conv_kernel_size2),
            nn.BatchNorm1d(512, track_running_stats=False),
            # nn.GroupNorm(32, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2))

        self.num_channels = 512

    def forward(self, x):
        out = self.conv_net(x)
        return out

class AttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.pool_fn = Rearrange('b (n p) d -> b n p d', n = 1)
        self.to_attn_logits = nn.Parameter(torch.eye(dim))

    def forward(self, x):
        attn_logits = einsum('b n d, d e -> b n e', x, self.to_attn_logits)
        x = self.pool_fn(x)
        logits = self.pool_fn(attn_logits)
        attn = logits.softmax(dim = -2)
        return (x * attn).sum(dim = -2).squeeze()

# class MSE_loss(nn.Module):
#     def __init__(self):
#         super().__init__()
#
#     def forward(self, pred, target):
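if __name__ == '__main__':
    # Minimal end-to-end sketch, illustrative only: the composition below
    # (CNN features -> permute -> Enformer blocks -> AttentionPool) is an
    # assumption about how these modules fit together, not necessarily how the
    # repository wires them up. heads = 12 is chosen so that dim // heads = 42
    # is divisible by 6, as get_positional_embed requires; the input length
    # 1000 and batch size 2 are arbitrary.
    cnn = CNN()
    transformer = Enformer(dim = 512, depth = 2, heads = 12)
    pool = AttentionPool(512)

    x = torch.randn(2, 5, 1000)                # (batch, 5 input channels, sequence length)
    with torch.no_grad():
        feats = cnn(x)                         # -> (batch, 512, reduced length)
        feats = feats.permute(0, 2, 1)         # -> (batch, reduced length, 512) for the transformer
        feats = transformer(feats)             # -> (batch, reduced length, 512)
        embedding = pool(feats)                # -> (batch, 512)
    print(embedding.shape)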