# tx-model-standalone / modeling.py
# Last commit by Yuto2007: "Fix: Use single modeling.py file" (8517f78, verified)
# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
TXModel - Complete Standalone Implementation for HuggingFace
All code in one file - requires ONLY: transformers, torch, safetensors
"""
import copy
import math
from typing import Optional, Dict, Any, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
# =============================================================================
# CONFIGURATION
# =============================================================================
class TXConfig(PretrainedConfig):
    """Hyper-parameters for :class:`TXModel`.

    Scalar settings are stored as attributes under their argument names.
    ``attn_config``, ``norm_config``, ``gene_encoder_config``,
    ``expression_encoder_config`` and ``expression_decoder_config`` are
    normalised so that a missing (or otherwise falsy) value becomes ``{}``.
    ``mvc_config`` and ``chemical_encoder_config`` are stored exactly as given
    and may be ``None``. ``pad_token_id`` and any extra keyword arguments are
    forwarded to ``PretrainedConfig``.
    """

    model_type = "tx_model"

    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 512,
        n_layers: int = 12,
        n_heads: int = 8,
        expansion_ratio: int = 4,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        pad_token_id: int = 0,
        pad_value: float = 0.0,
        num_bins: int = 51,
        use_chem_token: bool = False,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        gene_encoder_config: Optional[Dict] = None,
        expression_encoder_config: Optional[Dict] = None,
        expression_decoder_config: Optional[Dict] = None,
        mvc_config: Optional[Dict] = None,
        chemical_encoder_config: Optional[Dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        standard_scale_outputs: bool = False,
        keep_first_n_tokens: int = 1,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Architecture / tokenisation settings, stored verbatim.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.cell_emb_style = cell_emb_style
        self.pad_value = pad_value
        self.num_bins = num_bins
        self.use_chem_token = use_chem_token
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings
        self.standard_scale_outputs = standard_scale_outputs
        self.use_glu = use_glu

        # Sub-component settings that the model reads with ``.get(...)``;
        # always store a dict so those lookups never hit ``None``.
        dict_sub_configs = (
            ("attn_config", attn_config),
            ("norm_config", norm_config),
            ("gene_encoder_config", gene_encoder_config),
            ("expression_encoder_config", expression_encoder_config),
            ("expression_decoder_config", expression_decoder_config),
        )
        for attr_name, sub_config in dict_sub_configs:
            setattr(self, attr_name, sub_config if sub_config else {})

        # Optional sub-configs are kept as given: the model only builds the
        # MVC decoder when ``mvc_config`` is not None.
        self.mvc_config = mvc_config
        self.chemical_encoder_config = chemical_encoder_config
# =============================================================================
# MODEL BLOCKS
# =============================================================================
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with optional grouped-query key/value heads.

    Queries use ``n_heads`` heads. Keys and values use ``kv_n_heads`` heads, and
    each one is repeated ``n_heads // kv_n_heads`` times so that every query
    head has a matching key/value head.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of query heads (must be positive).
        kv_n_heads: Number of key/value heads (defaults to ``n_heads``); must be
            positive and divide ``n_heads``.
        dropout: Dropout probability applied to the attention weights.
        device: Device on which the projection layers are created.

    Raises:
        ValueError: If the head counts are incompatible with each other or with
            ``d_model``. Without this check the mismatch would only surface later
            as an opaque shape error inside ``forward``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        kv_n_heads: Optional[int] = None,
        dropout: float = 0.0,
        device: Optional[str] = None,
    ):
        super().__init__()
        kv_heads = kv_n_heads if kv_n_heads is not None else n_heads
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        if kv_heads <= 0 or n_heads % kv_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be divisible by a positive kv_n_heads ({kv_heads})"
            )

        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_n_heads = kv_heads
        self.head_dim = d_model // n_heads
        self.dropout = dropout
        # How many query heads share each key/value head.
        self.n_rep = n_heads // kv_heads

        self.q_proj = nn.Linear(d_model, d_model, device=device)
        self.k_proj = nn.Linear(d_model, kv_heads * self.head_dim, device=device)
        self.v_proj = nn.Linear(d_model, kv_heads * self.head_dim, device=device)
        self.out_proj = nn.Linear(d_model, d_model, device=device)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs
    ) -> Tuple[Tensor, None, None]:
        """Attend over the sequence ``x`` of shape (batch, seq_len, d_model).

        Args:
            x: Input sequence.
            key_padding_mask: Optional (batch, seq_len) mask in which True (or any
                non-zero value) marks a key position that may be attended to.
                Masked keys receive ``-inf`` scores. A row whose keys are all
                masked therefore produces NaN after the softmax.
            **kwargs: Accepted and ignored.

        Returns:
            ``(output, None, None)``. The output has the same shape as ``x``; the two
            ``None`` slots are placeholders.
        """
        batch_size, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim).transpose(1, 2)
        if self.n_rep > 1:
            # Expand grouped key/value heads to match the query heads.
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        if key_padding_mask is not None:
            # ``masked_fill`` needs a bool mask, and ``~`` on an int tensor is a
            # bitwise NOT (not a logical one), so normalise 0/1 masks first.
            keep = key_padding_mask.to(torch.bool).unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(~keep, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)
        return output, None, None
class TXBlock(nn.Module):
    """One transformer encoder layer: self-attention, then a feed-forward network.

    Each sub-layer is wrapped in a residual connection. With
    ``norm_scheme == "pre"`` the LayerNorm is applied to each sub-layer's input;
    any other value applies it after each residual addition (post-norm).
    When ``use_glu`` is set the feed-forward network is gated:
    ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    Args:
        d_model: Model width.
        n_heads: Number of attention (query) heads.
        expansion_ratio: Feed-forward hidden width as a multiple of ``d_model``.
        attn_config: Optional dict; reads ``kv_n_heads`` and ``attn_pdrop``.
        norm_config: Optional dict; reads ``eps`` for both LayerNorms.
        dropout: Dropout applied after attention and after the feed-forward network.
        activation: One of ``gelu``/``relu``/``silu``/``leaky_relu``; anything
            else falls back to GELU.
        device: Device on which the parameters are created.
        norm_scheme: ``"pre"`` for pre-norm, anything else for post-norm.
        use_glu: Whether to use the gated feed-forward variant.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        dropout: float = 0.0,
        activation: str = "gelu",
        device: Optional[str] = None,
        norm_scheme: str = "pre",
        use_glu: bool = False,
        **kwargs
    ):
        super().__init__()
        attn_cfg = attn_config if attn_config else {}
        norm_cfg = norm_config if norm_config else {}

        self.d_model = d_model
        self.n_heads = n_heads
        self.norm_scheme = norm_scheme
        self.use_glu = use_glu

        self.self_attn = MultiheadAttention(
            d_model=d_model,
            n_heads=n_heads,
            kv_n_heads=attn_cfg.get("kv_n_heads", n_heads),
            dropout=attn_cfg.get("attn_pdrop", 0.0),
            device=device,
        )

        # Feed-forward projections (the gate only exists for the GLU variant).
        hidden_width = d_model * expansion_ratio
        self.up_proj = nn.Linear(d_model, hidden_width, device=device)
        self.down_proj = nn.Linear(hidden_width, d_model, device=device)
        if use_glu:
            self.gate_proj = nn.Linear(d_model, hidden_width, device=device)

        layer_norm_eps = norm_cfg.get("eps", 1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, device=device)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, device=device)
        self.post_sa_dropout = nn.Dropout(dropout)
        self.post_ffn_dropout = nn.Dropout(dropout)

        activation_classes = {
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "silu": nn.SiLU,
            "leaky_relu": nn.LeakyReLU,
        }
        self.activation = activation_classes.get(activation, nn.GELU)()

    def forward(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs
    ) -> Tensor:
        """Apply attention and feed-forward sub-layers with residual connections."""
        if self.norm_scheme != "pre":
            # Post-norm: normalise after each residual addition.
            x = self.norm1(x + self._sa_block(x, key_padding_mask))
            return self.norm2(x + self._ff_block(x))
        # Pre-norm: normalise each sub-layer's input.
        x = x + self._sa_block(self.norm1(x), key_padding_mask)
        return x + self._ff_block(self.norm2(x))

    def _sa_block(self, x: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Self-attention sub-layer followed by dropout."""
        attended = self.self_attn(x, key_padding_mask=key_padding_mask)[0]
        return self.post_sa_dropout(attended)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward sub-layer (optionally gated) followed by dropout."""
        if self.use_glu:
            hidden = self.activation(self.gate_proj(x)) * self.up_proj(x)
        else:
            hidden = self.activation(self.up_proj(x))
        return self.post_ffn_dropout(self.down_proj(hidden))
class TXEncoder(nn.Module):
    """Stack of ``num_layers`` independent copies of an encoder layer.

    Each layer is a deep copy of ``encoder_layer`` whose parameters are then
    re-initialised. Every layer therefore inherits the template's full
    configuration (activation, attention settings such as ``kv_n_heads`` /
    attention dropout, LayerNorm eps, residual dropout, GLU, device) while
    still starting from its own random initialisation. Rebuilding layers from
    a handful of template attributes would silently drop most of that
    configuration.

    Args:
        encoder_layer: Template layer; it is copied, never used directly. Must
            expose ``d_model`` when ``use_norm`` is True.
        num_layers: Number of layers in the stack.
        use_norm: Whether to apply a final LayerNorm after the last layer.
        norm_config: Optional dict; reads ``eps`` for the final LayerNorm.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        encoder_layer: "TXBlock",
        num_layers: int,
        use_norm: bool = False,
        norm_config: Optional[Dict] = None,
        **kwargs
    ):
        super().__init__()
        norm_config = norm_config or {}
        self.layers = nn.ModuleList(
            [self._fresh_copy(encoder_layer) for _ in range(num_layers)]
        )
        self.use_norm = use_norm
        if use_norm:
            eps = norm_config.get("eps", 1e-5)
            self.norm = nn.LayerNorm(encoder_layer.d_model, eps=eps)

    @staticmethod
    def _fresh_copy(layer: nn.Module) -> nn.Module:
        """Return a deep copy of ``layer`` with every resettable submodule re-initialised."""
        clone = copy.deepcopy(layer)
        # ``modules()`` walks submodules in registration order, which is the
        # same order they would be initialised in when building a fresh layer.
        for module in clone.modules():
            reset = getattr(module, "reset_parameters", None)
            if callable(reset):
                reset()
        return clone

    def forward(
        self,
        total_embs: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
        """Run the layers in order.

        Returns:
            ``(x, hidden_states)``. ``x`` is the last layer's output, after the
            final LayerNorm when ``use_norm`` is set. ``hidden_states`` is the list
            of each layer's output before that final norm, or ``None`` when
            ``output_hidden_states`` is False.
        """
        x = total_embs
        hidden_states = [] if output_hidden_states else None
        for layer in self.layers:
            x = layer(x, key_padding_mask=key_padding_mask)
            if hidden_states is not None:
                hidden_states.append(x)
        if self.use_norm:
            x = self.norm(x)
        return x, hidden_states
class GeneEncoder(nn.Module):
    """Map gene token ids to learned embeddings, optionally LayerNorm-ed.

    Args:
        num_embeddings: Vocabulary size.
        embedding_dim: Embedding width.
        padding_idx: Token id whose embedding row is the padding row.
        use_norm: Whether to apply a LayerNorm to the embeddings.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = 0,
        use_norm: bool = False,
        **kwargs
    ):
        super().__init__()
        self.use_norm = use_norm
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
        # Identity stage: leaves the embeddings unchanged.
        self.project = nn.Identity()
        if use_norm:
            self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Embed the token ids ``x``; output has a trailing ``embedding_dim`` axis."""
        embedded = self.project(self.embedding(x))
        return self.enc_norm(embedded) if self.use_norm else embedded
class ContinuousValueEncoder(nn.Module):
    """Embed scalar expression values with a two-layer MLP.

    Each value is clamped from above at ``max_value``, lifted to ``d_model`` by
    ``linear1`` + activation, projected by ``linear2``, optionally LayerNorm-ed,
    and finally passed through dropout.

    Args:
        d_model: Output embedding width.
        dropout: Dropout probability on the output.
        max_value: Upper clamp applied to the raw values.
        activation: ``relu``/``gelu``/``leaky_relu``; anything else falls back to ReLU.
        use_norm: Whether to LayerNorm the embeddings.
    """

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_value: int = 512,
        activation: str = "relu",
        use_norm: bool = False,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(1, d_model)
        activation_classes = {"relu": nn.ReLU, "gelu": nn.GELU, "leaky_relu": nn.LeakyReLU}
        self.activation = activation_classes.get(activation, nn.ReLU)()
        self.linear2 = nn.Linear(d_model, d_model)
        self.use_norm = use_norm
        if use_norm:
            self.norm = nn.LayerNorm(d_model)
        self.max_value = max_value

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x``; the result gains a trailing axis of size ``d_model``."""
        # Treat each value as a 1-feature input and cap it at ``max_value``.
        hidden = torch.clamp(x.unsqueeze(-1), max=self.max_value)
        hidden = self.linear2(self.activation(self.linear1(hidden)))
        if self.use_norm:
            hidden = self.norm(hidden)
        return self.dropout(hidden)
class ExprDecoder(nn.Module):
    """Predict expression values from token representations with an MLP head.

    ``n_layers`` square linear layers, each followed by the activation, feed a
    final projection to ``n_outputs`` values. A trailing output axis of size 1
    is squeezed away.

    Args:
        d_model: Input (and hidden) width.
        n_outputs: Number of predicted values per token.
        n_layers: Number of hidden linear layers.
        activation: ``leaky_relu``/``relu``/``gelu``; anything else falls back to LeakyReLU.
    """

    def __init__(
        self,
        d_model: int,
        n_outputs: int = 1,
        n_layers: int = 2,
        activation: str = "leaky_relu",
    ):
        super().__init__()
        activation_classes = {"leaky_relu": nn.LeakyReLU, "relu": nn.ReLU, "gelu": nn.GELU}
        self.activation = activation_classes.get(activation, nn.LeakyReLU)()
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(n_layers)]
        )
        self.out_proj = nn.Linear(d_model, n_outputs)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Return ``{"pred": predictions}`` for the representations ``x``."""
        hidden = x
        for linear in self.linear_layers:
            hidden = self.activation(linear(hidden))
        prediction = self.out_proj(hidden)
        if prediction.shape[-1] == 1:
            prediction = prediction.squeeze(-1)
        return {"pred": prediction}
class MVCDecoder(nn.Module):
    """Score each gene against the cell embedding with a bilinear inner product.

    Gene embeddings become queries through ``gene2query`` plus an activation.
    Each query is transformed by ``W`` and dotted with the cell embedding,
    optionally divided by ``sqrt(d_model)``. ``arch_style`` is stored, but only
    this inner-product computation is implemented here.

    Args:
        d_model: Embedding width.
        arch_style: Stored on the module; not read by ``forward``.
        query_activation: ``sigmoid``/``relu``/``tanh``; anything else falls back to Sigmoid.
        scaled_dot_product: Whether to divide scores by ``sqrt(d_model)``.
    """

    def __init__(
        self,
        d_model: int,
        arch_style: str = "inner product",
        query_activation: str = "sigmoid",
        scaled_dot_product: bool = False,
    ):
        super().__init__()
        self.scaled_dot_product = scaled_dot_product
        self.gene2query = nn.Linear(d_model, d_model)
        activation_classes = {"sigmoid": nn.Sigmoid, "relu": nn.ReLU, "tanh": nn.Tanh}
        self.query_activation = activation_classes.get(query_activation, nn.Sigmoid)()
        self.W = nn.Linear(d_model, d_model, bias=False)
        self.arch_style = arch_style

    def forward(self, cell_emb: Tensor, gene_embs: Tensor) -> Dict[str, Tensor]:
        """Return ``{"pred": scores}`` of shape (batch, n_genes).

        ``cell_emb`` is (batch, d_model) and ``gene_embs`` is
        (batch, n_genes, d_model), as required by the batched matmul below.
        """
        queries = self.query_activation(self.gene2query(gene_embs))
        projected = self.W(queries)
        scores = torch.bmm(projected, cell_emb.unsqueeze(2)).squeeze(2)
        if self.scaled_dot_product:
            query_dim = queries.shape[-1]
            scores = scores / torch.sqrt(torch.tensor(query_dim, dtype=scores.dtype))
        return {"pred": scores}
# =============================================================================
# MAIN MODEL
# =============================================================================
class TXModel(nn.Module):
    """Transformer over (gene, expression value) tokens, built from a ``TXConfig``.

    Each token embedding is the sum of three parts:

    * a gene embedding;
    * an embedding of its continuous expression value, zeroed where ``gen_masks``
      is set;
    * the flag embedding at index 1, added only where ``gen_masks`` is set.

    The transformer output at position 0 is returned as the cell embedding. The
    expression decoder, plus the MVC decoder when ``config.mvc_config`` is not
    None, run on top unless ``skip_decoders`` is True.
    NOTE(review): ``config.cell_emb_style`` is not read; position 0 is always used.
    """

    def __init__(self, config: TXConfig):
        super().__init__()
        self.config = config
        # Submodules are created in a fixed order; it determines state-dict
        # layout and random-initialisation order.
        self.gene_encoder = GeneEncoder(
            config.vocab_size,
            config.d_model,
            padding_idx=config.pad_token_id,
            use_norm=config.gene_encoder_config.get("use_norm", False),
        )
        self.flag_encoder = nn.Embedding(2, config.d_model)
        self.expression_encoder = ContinuousValueEncoder(
            d_model=config.d_model,
            dropout=config.expression_encoder_config.get("dropout", 0.1),
            max_value=config.expression_encoder_config.get("max_value", 512),
            activation=config.expression_encoder_config.get("activation", "relu"),
            use_norm=config.expression_encoder_config.get("use_norm", False),
        )
        # Template layer for the encoder stack (not registered on this module).
        encoder_layer = TXBlock(
            d_model=config.d_model,
            n_heads=config.n_heads,
            expansion_ratio=config.expansion_ratio,
            attn_config=config.attn_config,
            norm_config=config.norm_config,
            activation=config.transformer_activation,
            norm_scheme=config.norm_scheme,
            use_glu=config.use_glu,
        )
        self.transformer_encoder = TXEncoder(
            encoder_layer,
            config.n_layers,
            # Pre-norm stacks need a final LayerNorm on the last layer's output.
            use_norm=config.norm_scheme == "pre",
            norm_config=config.norm_config,
        )
        self.expression_decoder = ExprDecoder(
            d_model=config.d_model,
            n_outputs=config.expression_decoder_config.get("n_outputs", 1),
            n_layers=config.expression_decoder_config.get("n_layers", 2),
            activation=config.expression_decoder_config.get("activation", "leaky_relu"),
        )
        if config.mvc_config is not None:
            self.mvc_decoder = MVCDecoder(
                d_model=config.d_model,
                arch_style=config.mvc_config.get("arch_style", "inner product"),
                query_activation=config.mvc_config.get("query_activation", "sigmoid"),
                scaled_dot_product=config.mvc_config.get("scaled_dot_product", False),
            )
        else:
            self.mvc_decoder = None

    def forward(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
    ) -> dict:
        """Encode a batch of cells and optionally decode expression predictions.

        Args:
            genes: (batch, seq_len) gene token ids.
            values: (batch, seq_len) expression values aligned with ``genes``.
            gen_masks: (batch, seq_len) mask; True (or non-zero) marks positions
                whose value is hidden from the model.
            key_padding_mask: (batch, seq_len) mask; True (or non-zero) marks
                positions that may be attended to. ``None`` disables masking.
            skip_decoders: Return right after computing the cell embedding.
            output_hidden_states: Also return the per-layer outputs.

        Returns:
            Dict with ``transformer_output`` and ``cell_emb``. It also contains
            ``hidden_states`` if requested. Unless ``skip_decoders`` is set, it
            adds ``expr_preds`` and, when an MVC decoder exists, ``mvc_output``.
        """
        # ``masked_fill`` only accepts bool masks (and ``~`` on an int tensor is
        # bitwise), so accept 0/1 integer or float masks by normalising them here.
        gen_masks = gen_masks.to(torch.bool)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.to(torch.bool)

        # --- Encode -----------------------------------------------------------
        token_embs = self.gene_encoder(genes)
        token_values = self.expression_encoder(values)
        # Hide the value signal at masked positions...
        token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)
        # ...and mark those positions with the flag embedding at index 1 instead.
        flag = self.flag_encoder(torch.tensor(1, device=token_embs.device)).reshape(1, 1, -1)
        flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag
        total_embs = token_embs + token_values + flag_embs
        # Kept for callers that read this attribute. NOTE: it holds a reference to
        # the latest batch's embeddings (and their autograd graph) until the next call.
        self.cur_gene_token_embs = token_embs

        # --- Transform --------------------------------------------------------
        transformer_output, hidden_states = self.transformer_encoder(
            total_embs=total_embs,
            key_padding_mask=key_padding_mask,
            output_hidden_states=output_hidden_states,
        )

        # Cell embedding = output at the first token.
        cell_emb = transformer_output[:, 0, :]
        output = {
            "transformer_output": transformer_output,
            "cell_emb": cell_emb,
        }
        if output_hidden_states:
            output["hidden_states"] = hidden_states
        if skip_decoders:
            return output

        # --- Decode -----------------------------------------------------------
        expr_output = self.expression_decoder(transformer_output)
        output["expr_preds"] = expr_output["pred"]
        if self.mvc_decoder is not None:
            mvc_output = self.mvc_decoder(cell_emb, token_embs)
            output["mvc_output"] = mvc_output["pred"]
        return output
# =============================================================================
# HUGGINGFACE WRAPPER
# =============================================================================
class TXPreTrainedModel(PreTrainedModel):
    """Shared HuggingFace plumbing for TX models.

    Provides the config class, the weight-name prefix and the default weight
    initialisation.
    """

    config_class = TXConfig
    base_model_prefix = "tx_model"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialise ``module``'s parameters in place.

        * Linear: weight ~ N(0, 0.02), bias (if any) zeroed.
        * Embedding: weight ~ N(0, 0.02), padding row (if any) zeroed.
        * LayerNorm: weight 1, bias 0.

        Parameters that a module does not have are skipped, for example a
        LayerNorm built with ``elementwise_affine=False`` or ``bias=False``.
        Such a LayerNorm has ``None`` for them.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
class TXModelForHF(TXPreTrainedModel):
    """HuggingFace-compatible wrapper around :class:`TXModel`.

    Requires ONLY: transformers, torch, safetensors.

    With ``return_dict`` (the default), ``last_hidden_state`` of the returned
    ``BaseModelOutput`` holds the *cell embedding* (the transformer output at
    position 0), not the per-token states. Decoder outputs (``expr_preds``,
    ``mvc_output``) are only included in the tuple returned when
    ``return_dict`` is falsy.
    """

    def __init__(self, config: TXConfig):
        super().__init__(config)
        self.tx_model = TXModel(config)
        self.post_init()

    def forward(
        self,
        genes: torch.Tensor,
        values: torch.Tensor,
        gen_masks: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        **kwargs
    ) -> Union[Tuple, BaseModelOutput]:
        """Run the model.

        When ``key_padding_mask`` is omitted, every position whose gene id differs
        from ``config.pad_token_id`` may be attended to. Extra keyword arguments
        are accepted and ignored.

        Returns:
            A ``BaseModelOutput`` (cell embedding and optional hidden states) when
            ``return_dict`` is truthy. Otherwise, a tuple of every value produced
            by ``TXModel.forward`` in its insertion order.
        """
        attend_mask = (
            key_padding_mask
            if key_padding_mask is not None
            else genes.ne(self.config.pad_token_id)
        )
        results = self.tx_model(
            genes=genes,
            values=values,
            gen_masks=gen_masks,
            key_padding_mask=attend_mask,
            skip_decoders=skip_decoders,
            output_hidden_states=output_hidden_states,
        )
        if return_dict:
            return BaseModelOutput(
                last_hidden_state=results.get("cell_emb"),
                hidden_states=results.get("hidden_states") if output_hidden_states else None,
            )
        return tuple(results.values())

    def get_input_embeddings(self):
        """Return the gene-token embedding table."""
        return self.tx_model.gene_encoder.embedding

    def set_input_embeddings(self, value):
        """Replace the gene-token embedding table with ``value``."""
        self.tx_model.gene_encoder.embedding = value
# Aliases: expose the HF wrapper under alternative class names.
# NOTE(review): presumably these names are referenced from the hub config's
# ``auto_map``; confirm before renaming. The model itself is not causal.
TXForCausalLM = AutoModelForCausalLM = TXModelForHF