import math

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def map_values(fn, d):
    return {key: fn(values) for key, values in d.items()}

def exponential_linspace_int(start, end, num, divisible_by = 1):
    def _round(x):
        return int(round(x / divisible_by) * divisible_by)

    base = math.exp(math.log(end / start) / (num - 1))
    return [_round(start * base**i) for i in range(num)]
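
# Usage sketch (illustrative, not part of the original module). `exponential_linspace_int`
# returns `num` integers spaced geometrically between `start` and `end`, each rounded to a
# multiple of `divisible_by`:
#
#   exponential_linspace_int(512, 1536, num = 4, divisible_by = 128)  # [512, 768, 1024, 1536]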

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def pearson_corr_coef(x, y, eps = 1e-8):
    x2 = x * x
    y2 = y * y
    xy = x * y
    ex = x.mean(dim = 1)
    ey = y.mean(dim = 1)
    exy = xy.mean(dim = 1)
    ex2 = x2.mean(dim = 1)
    ey2 = y2.mean(dim = 1)
    r = (exy - ex * ey) / (torch.sqrt(ex2 - (ex * ex)) * torch.sqrt(ey2 - (ey * ey)) + eps)
    return r.mean(dim = -1)
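
# Usage sketch (illustrative, not part of the original module). `pearson_corr_coef`
# correlates x and y along dim 1 and then averages the per-track correlations across the
# last dim, giving one value per batch element; interpreting the dims as
# (batch, positions, tracks) is an assumption here:
#
#   x = torch.randn(2, 100, 5)
#   y = torch.randn(2, 100, 5)
#   pearson_corr_coef(x, y).shape  # torch.Size([2])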

def get_positional_features_exponential(positions, features, seq_len, min_half_life = 3.):
    max_range = math.log(seq_len) / math.log(2.)
    half_life = 2 ** torch.linspace(min_half_life, max_range, features, device = positions.device)
    half_life = half_life[None, ...]
    positions = positions.abs()[..., None]
    return torch.exp(-math.log(2.) / half_life * positions)

def get_positional_features_central_mask(positions, features, seq_len):
    center_widths = 2 ** torch.arange(1, features + 1, device = positions.device).float()
    center_widths = center_widths - 1
    return (center_widths[None, ...] > positions.abs()[..., None]).float()

def gamma_pdf(x, concentration, rate):
    log_unnormalized_prob = torch.xlogy(concentration - 1., x) - rate * x
    log_normalization = (torch.lgamma(concentration) - concentration * torch.log(rate))
    return torch.exp(log_unnormalized_prob - log_normalization)

def get_positional_features_gamma(positions, features, seq_len, stddev = None, start_mean = None, eps = 1e-8):
    if not exists(stddev):
        stddev = seq_len / (2 * features)

    if not exists(start_mean):
        start_mean = seq_len / features

    mean = torch.linspace(start_mean, seq_len, features, device = positions.device)
    mean = mean[None, ...]
    concentration = (mean / stddev) ** 2
    rate = mean / stddev ** 2
    probabilities = gamma_pdf(positions.float().abs()[..., None], concentration, rate)
    probabilities = probabilities + eps
    outputs = probabilities / torch.amax(probabilities)
    return outputs

def get_positional_embed(seq_len, feature_size, device):
    distances = torch.arange(-seq_len + 1, seq_len, device = device)

    feature_functions = [
        get_positional_features_exponential,
        get_positional_features_central_mask,
        get_positional_features_gamma
    ]

    num_components = len(feature_functions) * 2

    if (feature_size % num_components) != 0:
        raise ValueError(f'feature size is not divisible by number of components ({num_components})')

    num_basis_per_class = feature_size // num_components

    embeddings = []
    for fn in feature_functions:
        embeddings.append(fn(distances, num_basis_per_class, seq_len))

    embeddings = torch.cat(embeddings, dim = -1)
    embeddings = torch.cat((embeddings, torch.sign(distances)[..., None] * embeddings), dim = -1)
    return embeddings
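
# Usage sketch (illustrative, not part of the original module). `get_positional_embed`
# builds a basis over all relative offsets -(seq_len - 1) .. (seq_len - 1), so the result
# has 2 * seq_len - 1 rows; `feature_size` must be divisible by 6 (three basis classes,
# each with a signed copy):
#
#   pos = get_positional_embed(seq_len = 128, feature_size = 66, device = torch.device('cpu'))
#   pos.shape  # torch.Size([255, 66])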

def relative_shift(x):
    to_pad = torch.zeros_like(x[..., :1])
    x = torch.cat((to_pad, x), dim = -1)
    _, h, t1, t2 = x.shape
    x = x.reshape(-1, h, t2, t1)
    x = x[:, :, 1:, :]
    x = x.reshape(-1, h, t1, t2 - 1)
    return x[..., :((t2 + 1) // 2)]
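
# Usage sketch (illustrative, not part of the original module). `relative_shift` is the
# usual trick for converting logits indexed by relative offset (last dim of size 2n - 1)
# into logits indexed by absolute key position (last dim of size n):
#
#   rel_logits = torch.randn(1, 8, 16, 31)   # (batch, heads, n, 2n - 1)
#   relative_shift(rel_logits).shape         # torch.Size([1, 8, 16, 16])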

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class GELU(nn.Module):
    def forward(self, x):
        return torch.sigmoid(1.702 * x) * x

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_rel_pos_features,
        heads = 8,
        dim_key = 64,
        dim_value = 64,
        dropout = 0.,
        pos_dropout = 0.
    ):
        super().__init__()
        self.scale = dim_key ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(dim, dim_key * heads, bias = False)
        self.to_k = nn.Linear(dim, dim_key * heads, bias = False)
        self.to_v = nn.Linear(dim, dim_value * heads, bias = False)

        self.to_out = nn.Linear(dim_value * heads, dim)
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

        # relative positional encoding
        self.num_rel_pos_features = num_rel_pos_features

        self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias = False)
        self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
        self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))

        # dropouts
        self.pos_dropout = nn.Dropout(pos_dropout)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x):
        n, h, device = x.shape[-2], self.heads, x.device

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q = q * self.scale

        # content-based attention logits
        content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)

        # relative positional logits, shifted into absolute key coordinates
        positions = get_positional_embed(n, self.num_rel_pos_features, device)
        positions = self.pos_dropout(positions)
        rel_k = self.to_rel_k(positions)

        rel_k = rearrange(rel_k, 'n (h d) -> h n d', h = h)
        rel_logits = einsum('b h i d, h j d -> b h i j', q + self.rel_pos_bias, rel_k)
        rel_logits = relative_shift(rel_logits)

        logits = content_logits + rel_logits
        attn = logits.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
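
# Usage sketch (illustrative, not part of the original module). `num_rel_pos_features`
# feeds `get_positional_embed`, so it must be divisible by 6:
#
#   attn = Attention(dim = 768, num_rel_pos_features = 96, heads = 8)
#   x = torch.randn(1, 256, 768)
#   attn(x).shape  # torch.Size([1, 256, 768])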

class Enformer(nn.Module):
    def __init__(
        self,
        dim = 512,
        depth = 4,
        heads = 6,
        attn_dim_key = 64,
        dropout_rate = 0.2,
        attn_dropout = 0.05,
        pos_dropout = 0.01,
    ):
        super().__init__()
        self.dim = dim
        transformer = []
        for _ in range(depth):
            transformer.append(nn.Sequential(
                Residual(nn.Sequential(
                    nn.LayerNorm(dim),
                    Attention(
                        dim,
                        heads = heads,
                        dim_key = attn_dim_key,
                        dim_value = dim // heads,
                        dropout = attn_dropout,
                        pos_dropout = pos_dropout,
                        num_rel_pos_features = dim // heads
                    ),
                    nn.Dropout(dropout_rate)
                )),
                Residual(nn.Sequential(
                    nn.LayerNorm(dim),
                    nn.Linear(dim, dim * 2),
                    nn.Dropout(dropout_rate),
                    nn.ReLU(),
                    nn.Linear(dim * 2, dim),
                    nn.Dropout(dropout_rate)
                ))
            ))

        self.transformer = nn.Sequential(*transformer)

    def forward(self, x):
        x = self.transformer(x)
        return x
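
# Usage sketch (illustrative, not part of the original module). Each Attention block uses
# `num_rel_pos_features = dim // heads`, which `get_positional_embed` requires to be
# divisible by 6, so pick `dim` and `heads` accordingly (the defaults dim = 512, heads = 6
# give 85 and would raise a ValueError at forward time):
#
#   model = Enformer(dim = 768, depth = 4, heads = 8)
#   x = torch.randn(1, 256, 768)   # (batch, sequence positions, embedding dim)
#   model(x).shape                 # torch.Size([1, 256, 768])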

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        conv_kernel_size1 = 10
        conv_kernel_size2 = 8
        pool_kernel_size1 = 5
        pool_kernel_size2 = 4
        self.conv_net = nn.Sequential(
            nn.Conv1d(5, 256, kernel_size = conv_kernel_size1),
            nn.ReLU(inplace = True),
            nn.Dropout(p = 0.1),
            nn.Conv1d(256, 256, kernel_size = conv_kernel_size1),
            nn.BatchNorm1d(256, track_running_stats = False),
            nn.ReLU(inplace = True),
            nn.MaxPool1d(kernel_size = pool_kernel_size1, stride = pool_kernel_size1),
            nn.Dropout(p = 0.1),
            nn.Conv1d(256, 360, kernel_size = conv_kernel_size2),
            nn.ReLU(inplace = True),
            nn.Dropout(p = 0.1),
            nn.Conv1d(360, 360, kernel_size = conv_kernel_size2),
            nn.BatchNorm1d(360, track_running_stats = False),
            nn.ReLU(inplace = True),
            nn.MaxPool1d(kernel_size = pool_kernel_size2, stride = pool_kernel_size2),
            nn.Dropout(p = 0.1),
            nn.Conv1d(360, 512, kernel_size = conv_kernel_size2),
            nn.ReLU(inplace = True),
            nn.Dropout(p = 0.2),
            nn.Conv1d(512, 512, kernel_size = conv_kernel_size2),
            nn.BatchNorm1d(512, track_running_stats = False),
            nn.ReLU(inplace = True),
            nn.Dropout(p = 0.2)
        )
        self.num_channels = 512

    def forward(self, x):
        out = self.conv_net(x)
        return out
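
# Usage sketch (illustrative, not part of the original module). The convolutional stack
# expects 5 input channels (interpreting them as a one-hot encoded sequence plus one extra
# track is an assumption here) and returns 512 channels; the output length follows from
# the convolutions and poolings, e.g. for an input of length 1000:
#
#   cnn = CNN()
#   seq = torch.randn(2, 5, 1000)   # (batch, channels, length)
#   cnn(seq).shape                  # torch.Size([2, 512, 31])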

class AttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.pool_fn = Rearrange('b (n p) d -> b n p d', n = 1)
        self.to_attn_logits = nn.Parameter(torch.eye(dim))

    def forward(self, x):
        attn_logits = einsum('b n d, d e -> b n e', x, self.to_attn_logits)
        x = self.pool_fn(x)
        logits = self.pool_fn(attn_logits)

        attn = logits.softmax(dim = -2)
        # squeeze only the singleton pooling dimension, so the batch dimension survives
        # even when the batch size is 1
        return (x * attn).sum(dim = -2).squeeze(dim = 1)
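
# Usage sketch (illustrative, not part of the original module). `AttentionPool` learns an
# attention-weighted average over all positions of a (batch, length, dim) tensor:
#
#   pool = AttentionPool(dim = 512)
#   feats = torch.randn(2, 31, 512)   # e.g. CNN output rearranged to (batch, length, channels)
#   pool(feats).shape                 # torch.Size([2, 512])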