"""
Standalone implementation of TXModel without external dependencies.
Only requires: torch, transformers, safetensors
"""

from typing import Optional, Union, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from blocks_standalone import (
    ChemEncoder,
    ContinuousValueEncoder,
    ExprDecoder,
    GeneEncoder,
    MVCDecoder,
    TXBlock,
    TXEncoder,
)


class TXModel(nn.Module):
    """Standalone Transformer model for genomic data"""

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_layers: int,
        n_heads: int,
        expansion_ratio: int,
        pad_token_id: int,
        pad_value: float,
        num_bins: int,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        use_chem_token: bool = False,
        attn_config: Optional[dict] = None,
        norm_config: Optional[dict] = None,
        gene_encoder_config: Optional[dict] = None,
        expression_encoder_config: Optional[dict] = None,
        expression_decoder_config: Optional[dict] = None,
        mvc_config: Optional[dict] = None,
        chemical_encoder_config: Optional[dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        keep_first_n_tokens: int = 1,
        device: Optional[str] = None,
    ):
        super().__init__()

        self.model_type = "Transformer"
        self.device = device
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_model = d_model
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.use_chem_token = use_chem_token
        self.cell_emb_style = cell_emb_style
        self.pad_token_id = pad_token_id
        self.pad_value = pad_value
        self.n_input_bins = num_bins
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings

        if attn_config is None:
            attn_config = {}
        if norm_config is None:
            norm_config = {}
        if gene_encoder_config is None:
            gene_encoder_config = {"use_norm": False}
        if expression_encoder_config is None:
            expression_encoder_config = {}
        if expression_decoder_config is None:
            expression_decoder_config = {}

        # Gene (token) embedding
        self.gene_encoder = GeneEncoder(
            self.vocab_size,
            self.d_model,
            padding_idx=self.pad_token_id,
            use_norm=gene_encoder_config.get("use_norm", False),
            gene_encoder_cfg=gene_encoder_config,
        )

        # Binary flag embedding marking positions whose value is masked for generation
        self.flag_encoder = nn.Embedding(2, self.d_model)

        # Continuous expression-value embedding
        self.expression_encoder = ContinuousValueEncoder(
            d_model=self.d_model,
            dropout=expression_encoder_config.get("dropout", 0.1),
            max_value=expression_encoder_config.get("max_value", 512),
            activation=expression_encoder_config.get("activation", "relu"),
            use_norm=expression_encoder_config.get("use_norm", False),
        )

        # Optional chemical (drug) encoder for the dedicated chem token
        if self.use_chem_token:
            if chemical_encoder_config is None:
                chemical_encoder_config = {}
            self.chem_encoder = ChemEncoder(
                d_out=self.d_model,
                padding_idx=chemical_encoder_config.get("padding_idx", 0),
                activation=chemical_encoder_config.get("activation", "leaky_relu"),
                freeze=chemical_encoder_config.get("freeze", False),
                num_drugs=chemical_encoder_config.get("num_drugs", 1000),
                fp_dim=chemical_encoder_config.get("fp_dim", 2048),
            )

        # Transformer encoder stack
        encoder_layer = TXBlock(
            d_model=self.d_model,
            n_heads=self.n_heads,
            expansion_ratio=self.expansion_ratio,
            attn_config=attn_config,
            norm_config=norm_config,
            activation=self.transformer_activation,
            device=self.device,
            norm_scheme=self.norm_scheme,
            use_glu=use_glu,
        )

        self.transformer_encoder = TXEncoder(
            encoder_layer,
            self.n_layers,
            use_norm=self.norm_scheme == "pre",
            norm_config=norm_config,
            attn_config=attn_config,
        )

        # Per-token expression decoder
        self.expression_decoder = ExprDecoder(
            d_model=self.d_model,
            n_outputs=expression_decoder_config.get("n_outputs", 1),
            n_layers=expression_decoder_config.get("n_layers", 2),
            activation=expression_decoder_config.get("activation", "leaky_relu"),
        )

        # Optional MVC decoder that predicts expression from the cell embedding
        # and the gene token embeddings
        if mvc_config is not None:
            self.mvc_decoder = MVCDecoder(
                d_model=self.d_model,
                arch_style=mvc_config.get("arch_style", "inner product"),
                query_activation=mvc_config.get("query_activation", "sigmoid"),
                scaled_dot_product=mvc_config.get("scaled_dot_product", False),
            )
        else:
            self.mvc_decoder = None

    def transformer_generate(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        drug_ids: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, list]]:
        """Embed the inputs and run them through the transformer encoder."""
        # Gene token embeddings
        token_embs = self.gene_encoder(genes)

        # Expression value embeddings; zero them at positions to be generated
        token_values = self.expression_encoder(values)
        token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)

        # Learned flag embedding added at masked (to-be-generated) positions
        flag = self.flag_encoder(
            torch.tensor(1, device=token_embs.device)
        ).reshape(1, 1, -1)
        flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag

        # Combine gene, value, and flag embeddings
        total_embs = token_embs + token_values + flag_embs

        # Insert the drug embedding at token position 1 when the chem token is used
        if self.use_chem_token and drug_ids is not None:
            drug_embs = self.chem_encoder(drug_ids)
            total_embs[:, 1, :] = drug_embs

        # Cache gene token embeddings for the MVC decoder
        self.cur_gene_token_embs = token_embs

        output, hidden_states = self.transformer_encoder(
            total_embs=total_embs,
            key_padding_mask=key_padding_mask,
            output_hidden_states=output_hidden_states,
        )

        return output, hidden_states

    def forward(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        drug_ids: Optional[Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
    ) -> dict:
        """Run the encoder, pool a cell embedding, and apply the decoders."""
        transformer_output, hidden_states = self.transformer_generate(
            genes, values, gen_masks, key_padding_mask,
            drug_ids, output_hidden_states
        )

        output = {
            "transformer_output": transformer_output,
        }

        if output_hidden_states:
            output["hidden_states"] = hidden_states

        # Pool a per-cell embedding from the token-level outputs
        if self.cell_emb_style == "cls":
            cell_emb = transformer_output[:, 0, :]
        elif self.cell_emb_style == "avg-pool":
            # Mean-pool over the positions selected by key_padding_mask
            mask = key_padding_mask.unsqueeze(-1).float()
            cell_emb = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        elif self.cell_emb_style == "w-pool":
            # Weighted pooling; currently implemented identically to avg-pool
            mask = key_padding_mask.unsqueeze(-1).float()
            cell_emb = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            cell_emb = transformer_output[:, 0, :]

        output["cell_emb"] = cell_emb

        if self.return_gene_embeddings:
            output["gene_embeddings"] = transformer_output

        if skip_decoders:
            return output

        # Per-token expression predictions
        expr_output = self.expression_decoder(transformer_output)
        output["expr_preds"] = expr_output["pred"]

        # MVC predictions from the cell embedding and cached gene embeddings
        if self.mvc_decoder is not None:
            mvc_output = self.mvc_decoder(
                cell_emb,
                self.cur_gene_token_embs,
            )
            output["mvc_output"] = mvc_output["pred"]

        return output

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from pretrained weights"""
        from safetensors.torch import load_file
        import json
        from pathlib import Path

        model_path = Path(model_path)

        # Read the stored hyperparameters
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        # Build the model from the stored config
        model = cls(
            vocab_size=config["vocab_size"],
            d_model=config["d_model"],
            n_layers=config["n_layers"],
            n_heads=config["n_heads"],
            expansion_ratio=config["expansion_ratio"],
            pad_token_id=config["pad_token_id"],
            pad_value=config["pad_value"],
            num_bins=config["num_bins"],
            norm_scheme=config.get("norm_scheme", "pre"),
            transformer_activation=config.get("transformer_activation", "gelu"),
            cell_emb_style=config.get("cell_emb_style", "cls"),
            use_chem_token=config.get("use_chem_token", False),
            attn_config=config.get("attn_config"),
            norm_config=config.get("norm_config"),
            gene_encoder_config=config.get("gene_encoder_config"),
            expression_encoder_config=config.get("expression_encoder_config"),
            expression_decoder_config=config.get("expression_decoder_config"),
            mvc_config=config.get("mvc_config"),
            chemical_encoder_config=config.get("chemical_encoder_config"),
            use_glu=config.get("use_glu", False),
            return_gene_embeddings=config.get("return_gene_embeddings", False),
            keep_first_n_tokens=config.get("keep_first_n_tokens", 1),
        )

        state_dict = load_file(model_path / "model.safetensors")

        # Strip wrapper prefixes so checkpoint keys match this module's parameter names
        new_state_dict = {}
        for k, v in state_dict.items():
            new_key = k
            if k.startswith('model.tx_model.'):
                new_key = k[len('model.tx_model.'):]
            elif k.startswith('tx_model.'):
                new_key = k[len('tx_model.'):]
            elif k.startswith('model.'):
                new_key = k[len('model.'):]
            new_state_dict[new_key] = v

        model.load_state_dict(new_state_dict, strict=False)

        return model
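

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes `blocks_standalone` is importable, that the hyperparameter values
# below are arbitrary placeholders, and that masks are boolean tensors of
# shape (batch, seq_len). The key_padding_mask convention (True = real token
# vs. True = padding) depends on TXEncoder, so treat the values as assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, seq_len = 2, 16

    model = TXModel(
        vocab_size=100,
        d_model=64,
        n_layers=2,
        n_heads=4,
        expansion_ratio=4,
        pad_token_id=0,
        pad_value=-2.0,
        num_bins=51,
    )

    genes = torch.randint(1, 100, (batch, seq_len))             # gene token ids
    values = torch.rand(batch, seq_len)                         # expression values
    gen_masks = torch.zeros(batch, seq_len, dtype=torch.bool)   # nothing masked for generation
    key_padding_mask = torch.ones(batch, seq_len, dtype=torch.bool)  # assumed: True = real token

    out = model(genes, values, gen_masks, key_padding_mask)
    print(out["cell_emb"].shape)    # expected: (batch, d_model) with the default "cls" style
    print(out["expr_preds"].shape)  # layout depends on ExprDecoder

    # Loading a checkpoint directory containing config.json and model.safetensors:
    # model = TXModel.from_pretrained("/path/to/checkpoint")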
|