import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from typing import Callable

from timm.models.layers import Mlp
from fairseq_signals_backbone.models.wav2vec2.wav2vec2_cmsc import Wav2Vec2CMSCModel, Wav2Vec2CMSCConfig
from lightning import LightningModule
from transformers import PreTrainedModel

from .configuration_MELP_Encoder import MELPEncoderConfig


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm (with cast back to input dtype)."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.to(orig_type)


class AttentionalPooler(nn.Module):
    def __init__(
        self,
        d_model: int,
        context_dim: int,
        n_head: int = 8,
        n_queries: int = 256,
        norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
        self.ln_q = norm_layer(d_model)
        self.ln_k = norm_layer(context_dim)

    def forward(self, x: torch.Tensor):
        N = x.shape[0]
        x = self.ln_k(x)
        q = self.ln_q(self.query)
        out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]
        return out


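# Shape sketch for AttentionalPooler (illustrative only, not used by the model itself):
# with learned queries of shape (n_queries, d_model) and batch_first attention, the pooler
# maps context features of shape (batch, seq_len, context_dim) to (batch, n_queries, d_model).
# For example, assuming a 768-dim encoder output and 10 contrastive queries:
#   pooler = AttentionalPooler(d_model=256, context_dim=768, n_queries=10)
#   pooler(torch.randn(4, 312, 768)).shape  # -> torch.Size([4, 10, 256])

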
def off_diagonal(x):
    """Return a flattened view of the off-diagonal elements of a square matrix."""
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


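# Typical usage sketch (an assumption, not invoked in this file): off_diagonal, together with
# the affine-free BatchNorm (self.bn) and the lambd weight defined in ECGFMModel below, matches
# the Barlow Twins-style redundancy-reduction loss, which for a DxD cross-correlation matrix c
# is usually computed as:
#   on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
#   off_diag = off_diagonal(c).pow_(2).sum()
#   loss = on_diag + lambd * off_diag

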
class ECGFMModel(LightningModule):
    def __init__(self,
                 model_size: str = "small",
                 shared_emb_dim: int = 256,
                 embed_dim_caption: int = 768,
                 use_attentional_pool_contrast: bool = False,
                 use_attentional_pool_caption: bool = False,
                 n_queries_contrast: int = 10,
                 n_queries_caption: int = 128,
                 attn_pooler_heads: int = 8,
                 norm_layer: nn.Module = nn.LayerNorm,
                 proj: str = "linear",
                 drop: float = 0.,
                 proj_bias: bool = False,
                 num_leads: int = 12,
                 softmax_temperature: float = 0.1,
                 lambd: float = 0.0051,
                 *args,
                 **kwargs):
        """Implementation of the ECG-FM model.

        Uses the Wav2Vec2 model as the ECG encoder: CNN feature extractor + Transformer.
        """
        super().__init__()
        self.save_hyperparameters()
        self.shared_emb_dim = shared_emb_dim
        self.num_leads = num_leads
        self.temperature = softmax_temperature

        if model_size == "small":
            self.encoder_embed_dim = 768
            self.encoder_attention_heads = 12
            self.encoder_layers = 8
            self.encoder_ffn_embed_dim = 3072
        elif model_size == "base":
            self.encoder_embed_dim = 768
            self.encoder_attention_heads = 12
            self.encoder_layers = 12
            self.encoder_ffn_embed_dim = 3072
        elif model_size == "large":
            self.encoder_embed_dim = 1024
            self.encoder_attention_heads = 16
            self.encoder_layers = 24
            self.encoder_ffn_embed_dim = 4096
        else:
            raise ValueError(f"Unknown model size: {model_size}")

        print("Using ECG encoder with the following configuration:")
        print(f"encoder_embed_dim: {self.encoder_embed_dim}")
        print(f"encoder_attention_heads: {self.encoder_attention_heads}")
        print(f"encoder_layers: {self.encoder_layers}")
        print(f"encoder_ffn_embed_dim: {self.encoder_ffn_embed_dim}")

        self.init_ecg_encoder()

        self.embed_dim_caption = embed_dim_caption
        self.use_attentional_pool_contrast = use_attentional_pool_contrast
        self.use_attentional_pool_caption = use_attentional_pool_caption

        head_layers = OrderedDict()
        prev_chs = self.ecg_encoder.cfg.encoder_embed_dim
        if use_attentional_pool_contrast:
            scale = prev_chs ** -0.5
            self.attn_pool_contrast = AttentionalPooler(
                d_model=shared_emb_dim,
                context_dim=prev_chs,
                n_head=attn_pooler_heads,
                n_queries=n_queries_contrast)
            self.ln_contrast = norm_layer(shared_emb_dim)
            self.proj_contrast = nn.Parameter(scale * torch.randn(shared_emb_dim, shared_emb_dim))
        else:
            assert proj, 'projection layer needed if not using attentional pooling.'

        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, shared_emb_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * shared_emb_dim, shared_emb_dim, drop=(drop, 0), bias=(True, proj_bias))

        self.head = nn.Sequential(head_layers)

        if use_attentional_pool_caption:
            self.attn_pool_caption = AttentionalPooler(
                d_model=embed_dim_caption, context_dim=prev_chs, n_head=attn_pooler_heads, n_queries=n_queries_caption)
            self.ln_caption = norm_layer(embed_dim_caption)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.bn = nn.BatchNorm1d(768, affine=False)
        self.lambd = lambd

    def init_ecg_encoder(self):
        cfg = Wav2Vec2CMSCConfig(
            apply_mask=True,
            mask_prob=0.65,
            quantize_targets=True,
            final_dim=256,
            dropout_input=0.1,
            dropout_features=0.1,
            feature_grad_mult=0.1,
            encoder_embed_dim=self.encoder_embed_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            in_d=12,
            encoder_layers=self.encoder_layers,
            encoder_ffn_embed_dim=self.encoder_ffn_embed_dim,
        )
        self.ecg_encoder = Wav2Vec2CMSCModel(cfg)

    def _global_pool(self, x):
        return torch.mean(x, dim=1)

    @torch.no_grad()
    def ext_ecg_emb(self, ecg, normalize=False):
        assert ecg.dim() == 3, "Input tensor must be 3D"

        ecg_out = self.ecg_encoder(source=ecg, mask=False, features_only=True)
        features = ecg_out["x"]

        if self.use_attentional_pool_contrast:
            pooled = self.attn_pool_contrast(features)
            pooled = self.ln_contrast(pooled)
            pooled = torch.mean(pooled, dim=1)
        else:
            pooled = self._global_pool(features)

        if normalize:
            pooled = F.normalize(pooled, p=2, dim=-1)

        return pooled

    def _encode_ecg(self, ecg):
        assert ecg.dim() == 3, "Input tensor must be 3D"
        ecg_out = self.ecg_encoder(source=ecg, mask=False, features_only=True)
        features = ecg_out["x"]

        if self.use_attentional_pool_contrast:
            pooled = self.attn_pool_contrast(features)
            pooled = self.ln_contrast(pooled)
            pooled = pooled @ self.proj_contrast.unsqueeze(0)
            # keep the per-query ("beat"-level) embeddings before averaging them into one vector
            pooled_beat = pooled.clone()
            pooled = torch.mean(pooled, dim=1)
        else:
            # project the pooled vector (not the full token sequence) to the shared dim;
            # there are no query-level embeddings on this path
            pooled = self._global_pool(features)
            pooled = self.head(pooled)
            pooled_beat = None

        tokens = None
        if self.use_attentional_pool_caption:
            tokens = self.attn_pool_caption(features)
            tokens = self.ln_caption(tokens)

        return pooled, pooled_beat, tokens

    def encode_ecg(self, ecg):
        ecg_latent, _, _ = self._encode_ecg(ecg)
        return ecg_latent


class MELPEncoderModel(PreTrainedModel):
    config_class = MELPEncoderConfig

    def __init__(self, config: MELPEncoderConfig):
        super().__init__(config)

        self.ecg_encoder = ECGFMModel(
            model_size=config.model_size,
            shared_emb_dim=config.shared_emb_dim,
            embed_dim_caption=config.embed_dim_caption,
            use_attentional_pool_contrast=config.use_attentional_pool_contrast,
            use_attentional_pool_caption=config.use_attentional_pool_caption,
            n_queries_contrast=config.n_queries_contrast,
            n_queries_caption=config.n_queries_caption,
            attn_pooler_heads=config.attn_pooler_heads,
            proj=config.proj,
            drop=config.drop,
            proj_bias=config.proj_bias,
            num_leads=config.num_leads,
        )

    def forward(self, tensor: torch.Tensor) -> dict:
        proj_ecg_emb, ecg_beat_emb, ecg_token_emb = self.ecg_encoder._encode_ecg(tensor)

        return {
            "proj_ecg_emb": proj_ecg_emb,
            "ecg_beat_emb": ecg_beat_emb,
            "ecg_token_emb": ecg_token_emb,
        }
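

# Minimal smoke-test sketch. Assumptions: fairseq_signals_backbone is installed, the input is a
# 12-lead ECG shaped (batch, leads, samples), and the 5000-sample length (e.g. 10 s at 500 Hz)
# is only an example. This exercises ECGFMModel directly with random weights rather than
# loading a pretrained checkpoint.
if __name__ == "__main__":
    model = ECGFMModel(
        model_size="small",
        use_attentional_pool_contrast=True,
        use_attentional_pool_caption=True,
    )
    model.eval()

    dummy_ecg = torch.randn(2, 12, 5000)  # (batch, leads, samples)
    with torch.no_grad():
        proj_emb, beat_emb, token_emb = model._encode_ecg(dummy_ecg)

    # Expected shapes with the defaults above:
    # proj_emb: (2, 256), beat_emb: (2, 10, 256), token_emb: (2, 128, 768)
    print(proj_emb.shape, beat_emb.shape, token_emb.shape)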