import torch
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv
from performer_pytorch import Performer


model_params = {
    "dim": 320,           # model / embedding dimension
    "bins": 10,           # number of importance bins per GBFormer block
    "gb_repeat": 1,       # number of stacked GBFormer blocks
    "p_repeat": 2,        # global Performer layers per GBFormer block
    "bin_head": 8,        # attention heads in each per-bin Performer
    "full_head": 4,       # attention heads in the global Performer
    "gene_length": 19357  # number of genes per expression profile
}
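
# Presumed usage (an assumption; the training script is not part of this file):
# these hyperparameters are unpacked into the BulkFormer constructor defined
# below, with the gene-gene graph and the ESM2 gene embedding matrix supplied
# separately, e.g. BulkFormer(**model_params, graph=graph, gene_emb=gene_emb).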


class PositionalExprEmbedding(nn.Module):
    """
    Rotary Expression Embedding (REE):
    Converts continuous gene expression values into a sinusoidal
    embedding usable by Performer/Transformer blocks. Deterministic,
    not learned. Masked positions (expression == -10) map to the zero vector.
    """

    def __init__(self, dim, mask_token=-10):
        super().__init__()
        self.mask_token = mask_token
        # Frozen inverse frequencies; kept as a Parameter with
        # requires_grad=False so they move with the module but are not trained.
        self.inv_freq = nn.Parameter(
            1.0 / (100 ** (torch.arange(0, dim, 2).float() / dim)),
            requires_grad=False
        )

    def forward(self, x):
        # Remember which (sample, gene) positions carry the mask token.
        mask = (x == self.mask_token).nonzero(as_tuple=False)
        # (B, G) expression values -> (B, G, dim/2) phases -> (B, G, dim).
        x = torch.einsum("bi,j->bij", x, self.inv_freq)
        x = torch.cat([x.sin(), x.cos()], dim=-1)
        # Zero out the embeddings of masked genes.
        x[mask[:, 0], mask[:, 1]] = 0
        return x
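

# Hedged sketch (not part of the original module): a minimal shape check for
# PositionalExprEmbedding. The toy sizes (dim=8, three genes) and the helper
# name `_demo_positional_expr_embedding` are illustrative assumptions.
def _demo_positional_expr_embedding():
    emb = PositionalExprEmbedding(dim=8)
    # One sample, three genes; the last value is the mask token (-10).
    expr = torch.tensor([[0.5, 2.3, -10.0]])
    out = emb(expr)
    assert out.shape == (1, 3, 8)        # sin/cos halves concatenated to dim
    assert torch.all(out[0, 2] == 0)     # masked gene -> zero vector
    return out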


class GBFormer(nn.Module):
    """
    A single GBFormer block:
    - LayerNorm
    - GCNConv (gene-gene propagation)
    - Binning by a learned importance score
    - Local Performer per bin
    - Global Performer
    """

    def __init__(self, dim, gene_length, bin_head, full_head, bins, p_repeat):
        super().__init__()

        self.dim = dim
        self.bins = bins
        self.bin_head = bin_head
        self.full_head = full_head
        self.p_repeat = p_repeat

        self.layernorm = nn.LayerNorm(dim)
        self.gcn = GCNConv(dim, dim, cached=True, add_self_loops=False)

        # Scores each gene token; genes are sorted by this score before binning.
        self.which_bin = nn.Linear(dim, 1)

        # One local Performer per bin.
        self.bin_layers = nn.ModuleList([
            Performer(
                dim=dim,
                heads=bin_head,
                depth=1,
                dim_head=dim // bin_head,
                attn_dropout=0.2,
                ff_dropout=0.2
            )
            for _ in range(bins)
        ])

        # Global Performer layers applied over the full gene sequence.
        self.global_layers = nn.Sequential(*[
            Performer(
                dim=dim,
                heads=full_head,
                depth=1,
                dim_head=dim // full_head
            )
            for _ in range(p_repeat)
        ])

    def forward(self, x, graph):
        B, G, D = x.shape

        x = self.layernorm(x)
        # Residual graph convolution over the gene-gene graph.
        x = x + self.gcn(x, graph)

        if self.bins > 0:
            # Sort genes by learned importance score.
            scores = self.which_bin(x).squeeze(-1)
            order = torch.argsort(scores, dim=1, descending=True)
            order_full = order.unsqueeze(-1).expand(-1, -1, D)

            x_sorted = x.gather(1, order_full)
            bin_size = (G - 1) // self.bins + 1  # ceil(G / bins)
            chunks = torch.split(x_sorted, bin_size, dim=1)

            # Local attention within each bin.
            processed = [
                layer(chunk)
                for chunk, layer in zip(chunks, self.bin_layers)
            ]

            # Restore the original gene order.
            x_cat = torch.cat(processed, dim=1)
            x = torch.empty_like(x_cat).scatter_(1, order_full, x_cat)

        x = self.global_layers(x)
        return x
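

# Hedged sketch (not part of the original module): the sort -> split-into-bins
# -> restore-order pattern used in GBFormer.forward, demonstrated on toy
# tensors with an identity stand-in for the per-bin Performers. Sizes and the
# helper name `_demo_bin_unsort` are illustrative assumptions.
def _demo_bin_unsort(bins=2):
    B, G, D = 1, 6, 4
    x = torch.randn(B, G, D)
    scores = torch.randn(B, G)                       # stand-in for which_bin(x)
    order = torch.argsort(scores, dim=1, descending=True)
    order_full = order.unsqueeze(-1).expand(-1, -1, D)
    x_sorted = x.gather(1, order_full)               # genes sorted by score
    bin_size = (G - 1) // bins + 1                   # ceil(G / bins)
    chunks = torch.split(x_sorted, bin_size, dim=1)  # one chunk per bin
    processed = torch.cat(chunks, dim=1)             # identity "processing"
    restored = torch.empty_like(processed).scatter_(1, order_full, processed)
    assert torch.allclose(restored, x)               # original gene order back
    return restored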


class BulkFormer(nn.Module):
    """
    CancerTranscriptome-Mini-48M:
    A compact BulkFormer-style masked-expression model.
    Combines:
    - ESM2 gene identity embeddings
    - Rotary Expression Embeddings (REE)
    - Graph convolution (GCNConv)
    - Local/global Performer attention
    - Optional intermediate repr_layers for feature extraction
    """

    def __init__(
        self,
        dim,
        graph,
        gene_emb,
        gene_length,
        bin_head=4,
        full_head=4,
        bins=10,
        gb_repeat=1,
        p_repeat=1
    ):
        super().__init__()

        self.dim = dim
        self.graph = graph
        self.gene_length = gene_length

        # Gene identity embeddings (e.g. ESM2), projected to the model dimension.
        self.gene_emb = nn.Parameter(gene_emb)
        self.gene_proj = nn.Sequential(
            nn.Linear(gene_emb.shape[1], 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        # Deterministic sinusoidal embedding of expression values.
        self.expr_emb = PositionalExprEmbedding(dim)

        # Mixes the summed identity + expression embeddings.
        self.mix = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        self.gb_blocks = nn.ModuleList([
            GBFormer(dim, gene_length, bin_head, full_head, bins, p_repeat)
            for _ in range(gb_repeat)
        ])

        self.final_norm = nn.LayerNorm(dim)

        # Per-gene regression head; the final ReLU keeps predictions non-negative.
        self.head = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, 1),
            nn.ReLU()
        )

    def forward(self, x, repr_layers=None):
        B, G = x.shape
        hidden = {}

        # Sum expression embedding and (broadcast) gene identity embedding;
        # the zero tensor only makes the broadcast to (B, G, dim) explicit.
        x = (
            self.expr_emb(x) +
            self.gene_proj(self.gene_emb) +
            torch.zeros(B, 1, self.dim, device=x.device)
        )

        x = self.mix(x)

        for i, block in enumerate(self.gb_blocks):
            x = block(x, self.graph)
            if repr_layers and i in repr_layers:
                hidden[i] = x

        x = self.final_norm(x)
        out = self.head(x).squeeze(-1)

        if repr_layers:
            return out, hidden
        return out
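

# Hedged smoke test (not part of the original module): builds a tiny BulkFormer
# with random stand-ins for the ESM2 gene embeddings and the gene-gene graph,
# neither of which is constructed in this file. All sizes below are
# illustrative assumptions, not the production values in `model_params`.
if __name__ == "__main__":
    genes, dim, esm_dim = 32, 16, 24
    toy_gene_emb = torch.randn(genes, esm_dim)          # stand-in ESM2 matrix
    toy_graph = torch.randint(0, genes, (2, 64))        # random edge_index
    model = BulkFormer(
        dim=dim,
        graph=toy_graph,
        gene_emb=toy_gene_emb,
        gene_length=genes,
        bin_head=4,
        full_head=4,
        bins=4,
        gb_repeat=1,
        p_repeat=1,
    )
    model.eval()
    expr = torch.randn(2, genes)                        # two expression profiles
    expr[:, :5] = -10                                   # mask a few genes
    with torch.no_grad():
        pred, hidden = model(expr, repr_layers=[0])
    print(pred.shape, hidden[0].shape)                  # (2, 32), (2, 32, 16)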