# EPCOT: pretrain/track/layers.py
# some of the following code is adapted from https://github.com/lucidrains/enformer-pytorch
import math
import torch
from torch import nn, einsum
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
import torch.nn.functional as F
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def map_values(fn, d):
return {key: fn(values) for key, values in d.items()}
# `num` integers spaced evenly on a log scale from `start` to `end`, each rounded
# to a multiple of `divisible_by`
def exponential_linspace_int(start, end, num, divisible_by = 1):
def _round(x):
return int(round(x / divisible_by) * divisible_by)
base = math.exp(math.log(end / start) / (num - 1))
return [_round(start * base**i) for i in range(num)]
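# illustrative usage (a hand-checked sketch, kept as a comment so importing this
# module stays side-effect free):
#   exponential_linspace_int(512, 2048, 4)                       # -> [512, 813, 1290, 2048]
#   exponential_linspace_int(512, 2048, 4, divisible_by = 128)   # -> [512, 768, 1280, 2048]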
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
# losses and metrics
# Pearson correlation along dim 1 (positions) for each track of (batch, length, tracks)
# inputs, then averaged over tracks to give one value per batch element
def pearson_corr_coef(x, y, eps = 1e-8):
x2 = x * x
y2 = y * y
xy = x * y
ex = x.mean(dim = 1)
ey = y.mean(dim = 1)
exy = xy.mean(dim = 1)
ex2 = x2.mean(dim = 1)
ey2 = y2.mean(dim = 1)
r = (exy - ex * ey) / (torch.sqrt(ex2 - (ex * ex)) * torch.sqrt(ey2 - (ey * ey)) + eps)
return r.mean(dim = -1)
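# shape sketch (illustrative sizes, assuming the (batch, length, tracks) layout above):
#   x = torch.randn(2, 100, 3)
#   y = x + 0.1 * torch.randn(2, 100, 3)
#   pearson_corr_coef(x, y)   # -> shape (2,), entries close to 1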
# relative positional encoding functions
# exponentially decaying features, with half-lives log-spaced between
# 2 ** min_half_life and seq_len
def get_positional_features_exponential(positions, features, seq_len, min_half_life = 3.):
max_range = math.log(seq_len) / math.log(2.)
half_life = 2 ** torch.linspace(min_half_life, max_range, features, device = positions.device)
half_life = half_life[None, ...]
positions = positions.abs()[..., None]
return torch.exp(-math.log(2.) / half_life * positions)
# binary features marking whether |position| lies inside center windows of
# exponentially growing width
def get_positional_features_central_mask(positions, features, seq_len):
center_widths = 2 ** torch.arange(1, features + 1, device = positions.device).float()
center_widths = center_widths - 1
return (center_widths[None, ...] > positions.abs()[..., None]).float()
# gamma probability density function, parameterized by concentration and rate
def gamma_pdf(x, concentration, rate):
log_unnormalized_prob = torch.xlogy(concentration - 1., x) - rate * x
log_normalization = (torch.lgamma(concentration) - concentration * torch.log(rate))
return torch.exp(log_unnormalized_prob - log_normalization)
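# sanity sketch: this agrees with torch's own gamma density (arbitrary example values):
#   x = torch.tensor([0.5, 1.0, 2.0])
#   c, r = torch.tensor(2.0), torch.tensor(1.5)
#   torch.allclose(gamma_pdf(x, c, r),
#                  torch.distributions.Gamma(c, r).log_prob(x).exp())   # True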
# gamma-density features whose means are spaced linearly up to seq_len
def get_positional_features_gamma(positions, features, seq_len, stddev = None, start_mean = None, eps = 1e-8):
if not exists(stddev):
stddev = seq_len / (2 * features)
if not exists(start_mean):
start_mean = seq_len / features
mean = torch.linspace(start_mean, seq_len, features, device = positions.device)
mean = mean[None, ...]
concentration = (mean / stddev) ** 2
rate = mean / stddev ** 2
probabilities = gamma_pdf(positions.float().abs()[..., None], concentration, rate)
probabilities = probabilities + eps
outputs = probabilities / torch.amax(probabilities)
return outputs
# builds the relative positional embedding table for offsets -(seq_len - 1) .. (seq_len - 1);
# feature_size must be divisible by 6 (3 basis-function families, each contributing a
# symmetric and a sign-weighted copy)
def get_positional_embed(seq_len, feature_size, device):
distances = torch.arange(-seq_len + 1, seq_len, device = device)
feature_functions = [
get_positional_features_exponential,
get_positional_features_central_mask,
get_positional_features_gamma
]
num_components = len(feature_functions) * 2
if (feature_size % num_components) != 0:
raise ValueError(f'feature size is not divisible by number of components ({num_components})')
num_basis_per_class = feature_size // num_components
embeddings = []
for fn in feature_functions:
embeddings.append(fn(distances, num_basis_per_class, seq_len))
embeddings = torch.cat(embeddings, dim = -1)
embeddings = torch.cat((embeddings, torch.sign(distances)[..., None] * embeddings), dim = -1)
return embeddings
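# shape sketch (illustrative sizes; feature_size = 66 is divisible by 6):
#   emb = get_positional_embed(seq_len = 64, feature_size = 66, device = torch.device('cpu'))
#   emb.shape   # -> (127, 66), i.e. one row per relative offset (2 * 64 - 1 of them)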
# Transformer-XL style relative shift: turns logits indexed by relative offset into
# logits indexed by absolute key position, mapping (b, h, n, 2n - 1) -> (b, h, n, n)
def relative_shift(x):
to_pad = torch.zeros_like(x[..., :1])
x = torch.cat((to_pad, x), dim = -1)
_, h, t1, t2 = x.shape
x = x.reshape(-1, h, t2, t1)
x = x[:, :, 1:, :]
x = x.reshape(-1, h, t1, t2 - 1)
return x[..., :((t2 + 1) // 2)]
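# shape sketch: the 2n - 1 relative offsets collapse to n key positions per query
#   x = torch.randn(2, 4, 8, 15)    # (b, h, n, 2n - 1)
#   relative_shift(x).shape         # -> (2, 4, 8, 8)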
# classes
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class GELU(nn.Module):
    # sigmoid-based approximation of the GELU activation, as used in Enformer
    def forward(self, x):
        return torch.sigmoid(1.702 * x) * x
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_rel_pos_features,
        heads = 8,
        dim_key = 64,
        dim_value = 64,
        dropout = 0.,
        pos_dropout = 0.
    ):
super().__init__()
self.scale = dim_key ** -0.5
self.heads = heads
self.to_q = nn.Linear(dim, dim_key * heads, bias = False)
self.to_k = nn.Linear(dim, dim_key * heads, bias = False)
self.to_v = nn.Linear(dim, dim_value * heads, bias = False)
self.to_out = nn.Linear(dim_value * heads, dim)
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
# relative positional encoding
self.num_rel_pos_features = num_rel_pos_features
self.to_rel_k = nn.Linear(num_rel_pos_features, dim_key * heads, bias = False)
self.rel_content_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
self.rel_pos_bias = nn.Parameter(torch.randn(1, heads, 1, dim_key))
self.pos_dropout = nn.Dropout(pos_dropout)
self.attn_dropout = nn.Dropout(dropout)
def forward(self, x):
n, h, device = x.shape[-2], self.heads, x.device
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale
content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)
positions = get_positional_embed(n, self.num_rel_pos_features, device)
positions = self.pos_dropout(positions)
rel_k = self.to_rel_k(positions)
rel_k = rearrange(rel_k, 'n (h d) -> h n d', h = h)
rel_logits = einsum('b h i d, h j d -> b h i j', q + self.rel_pos_bias, rel_k)
rel_logits = relative_shift(rel_logits)
logits = content_logits + rel_logits
attn = logits.softmax(dim = -1)
attn = self.attn_dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
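# usage sketch (illustrative sizes; num_rel_pos_features must be divisible by 6):
#   attn = Attention(dim = 512, num_rel_pos_features = 96, heads = 8)
#   attn(torch.randn(1, 128, 512)).shape   # -> (1, 128, 512)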
# stack of pre-norm attention + feed-forward blocks with Enformer-style relative
# positional attention; note that each Attention gets num_rel_pos_features = dim // heads,
# which must be divisible by 6, so the defaults below (512 // 6 = 85) only work if
# the caller overrides dim and/or heads
class Enformer(nn.Module):
def __init__(
self,
dim = 512,
depth = 4,
heads = 6,
attn_dim_key = 64,
dropout_rate = 0.2,
attn_dropout = 0.05,
pos_dropout = 0.01,
):
super().__init__()
self.dim = dim
transformer = []
for _ in range(depth):
transformer.append(nn.Sequential(
Residual(nn.Sequential(
nn.LayerNorm(dim),
Attention(
dim,
heads = heads,
dim_key = attn_dim_key,
dim_value = dim // heads,
dropout = attn_dropout,
pos_dropout = pos_dropout,
num_rel_pos_features = dim // heads
),
nn.Dropout(dropout_rate)
)),
Residual(nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * 2),
nn.Dropout(dropout_rate),
nn.ReLU(),
nn.Linear(dim * 2, dim),
nn.Dropout(dropout_rate)
))
))
self.transformer = nn.Sequential(
# Rearrange('b d n -> b n d'),
*transformer
)
    def forward(self, x):
x = self.transformer(x)
return x
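# usage sketch (dim = 360 with heads = 6 gives num_rel_pos_features = 60, which
# satisfies the divisibility constraint noted above; all sizes are illustrative):
#   model = Enformer(dim = 360, depth = 2, heads = 6)
#   model(torch.randn(1, 61, 360)).shape   # -> (1, 61, 360)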
# convolutional feature extractor; the 5 input channels are the 4-row one-hot DNA
# encoding plus one chromatin accessibility signal track
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
conv_kernel_size1 = 10
conv_kernel_size2 = 8
pool_kernel_size1 = 5
pool_kernel_size2 = 4
self.conv_net = nn.Sequential(
nn.Conv1d(5, 256, kernel_size=conv_kernel_size1),
nn.ReLU(inplace=True),
nn.Dropout(p=0.1),
nn.Conv1d(256, 256, kernel_size=conv_kernel_size1),
# nn.GroupNorm(16, 256),
            nn.BatchNorm1d(256, track_running_stats=False),
nn.ReLU(inplace=True),
nn.MaxPool1d(kernel_size=pool_kernel_size1, stride=pool_kernel_size1),
nn.Dropout(p=0.1),
nn.Conv1d(256, 360, kernel_size=conv_kernel_size2),
nn.ReLU(inplace=True),
nn.Dropout(p=0.1),
nn.Conv1d(360, 360, kernel_size=conv_kernel_size2),
            nn.BatchNorm1d(360, track_running_stats=False),
# nn.GroupNorm(36, 360),
nn.ReLU(inplace=True),
nn.MaxPool1d(kernel_size=pool_kernel_size2, stride=pool_kernel_size2),
nn.Dropout(p=0.1),
nn.Conv1d(360, 512, kernel_size=conv_kernel_size2),
nn.ReLU(inplace=True),
nn.Dropout(p=0.2),
nn.Conv1d(512, 512, kernel_size=conv_kernel_size2),
            nn.BatchNorm1d(512, track_running_stats=False),
# nn.GroupNorm(32, 512),
nn.ReLU(inplace=True),
nn.Dropout(p=0.2))
self.num_channels = 512
def forward(self, x):
out = self.conv_net(x)
return out
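# shape sketch (assuming a 1600 bp input window; the convolutions and pools reduce
# the length to ((1600 - 18) // 5 - 14) // 4 - 14 = 61):
#   feats = CNN()(torch.randn(2, 5, 1600))
#   feats.shape   # -> (2, 512, 61)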
# attention pooling over the whole sequence: a learned linear map (initialized to the
# identity) scores every position, and the softmax-weighted positions are summed into
# a single vector per batch element
class AttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.pool_fn = Rearrange('b (n p) d -> b n p d', n = 1)
        self.to_attn_logits = nn.Parameter(torch.eye(dim))
def forward(self, x):
attn_logits = einsum('b n d, d e -> b n e', x, self.to_attn_logits)
x = self.pool_fn(x)
logits = self.pool_fn(attn_logits)
attn = logits.softmax(dim = -2)
        # squeeze only the pooled singleton dim so a batch of size 1 keeps its batch axis
        return (x * attn).sum(dim = -2).squeeze(1)
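# usage sketch (illustrative sizes): pools a (batch, length, dim) sequence to one
# vector per batch element
#   pool = AttentionPool(dim = 512)
#   pool(torch.randn(2, 61, 512)).shape   # -> (2, 512)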