aletlvl commited on
Commit
a24db0c
·
verified ·
1 Parent(s): 68bf69d

Upload Nicheformer model

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Nicheformer Contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Nicheformer
2
+
3
+ Nicheformer is a transformer-based model designed for understanding and predicting cellular niches and their interactions. The model uses masked language modeling to learn representations of cellular contexts and their relationships.
4
+
5
+ ## Model Description
6
+
7
+ Nicheformer is built on a transformer architecture with the following key features:
8
+
9
+ - **Architecture**: Transformer encoder with customizable number of layers and attention heads
10
+ - **Pre-training**: Masked Language Modeling (MLM) objective with dynamic masking
11
+ - **Input Processing**: Handles cell type, assay, and modality information
12
+ - **Positional Encoding**: Supports both learnable and fixed positional embeddings
13
+ - **Masking Strategy**:
14
+ - 80% of selected tokens are replaced with [MASK]
15
+ - 10% are replaced with random tokens
16
+ - 10% remain unchanged
17
+
18
+ ### Model Architecture
19
+
20
+ - Transformer encoder layers: 12
21
+ - Hidden dimension: 512
22
+ - Attention heads: 16
23
+ - Feedforward dimension: 1024
24
+ - Maximum sequence length: 1500
25
+ - Vocabulary size: 25000
26
+ - Masking probability: 15%
27
+
28
+ ## Usage
29
+
30
+ ```python
31
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
32
+
33
+ # Load model and tokenizer
34
+ model = AutoModelForMaskedLM.from_pretrained("your-username/nicheformer")
35
+ tokenizer = AutoTokenizer.from_pretrained("your-username/nicheformer")
36
+
37
+ # Example 1: Manual masking
38
+ masked_text = "The [MASK] cell is an important immune cell type."
39
+ inputs = tokenizer(masked_text, return_tensors="pt")
40
+ outputs = model(**inputs)
41
+
42
+ # Example 2: Automatic masking (typically used during training)
43
+ text = "The T cell is an important immune cell type."
44
+ inputs = tokenizer(text, return_tensors="pt")
45
+ outputs = model(**inputs, apply_masking=True) # This will automatically mask tokens
46
+ ```
47
+
48
+ ## Training Data
49
+
50
+ [Describe the training data used for the model]
51
+
52
+ ## Evaluation Results
53
+
54
+ [Include evaluation metrics and results]
55
+
56
+ ## Limitations
57
+
58
+ [Describe any known limitations or biases of the model]
59
+
60
+ ## Citation
61
+
62
+ If you use this model in your research, please cite:
63
+
64
+ ```bibtex
65
+ [Add citation information]
66
+ ```
67
+
68
+ ## License
69
+
70
+ This model is released under [specify license]
71
+
72
+ ## Contact
73
+
74
+ [Add contact information for questions and issues]
__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .configuration_nicheformer import NicheformerConfig
2
+ from .modeling_nicheformer import (
3
+ NicheformerPreTrainedModel,
4
+ NicheformerModel,
5
+ NicheformerForMaskedLM
6
+ )
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "NicheformerForMaskedLM"
4
+ ],
5
+ "assay": true,
6
+ "batch_first": true,
7
+ "cls_classes": 164,
8
+ "context_length": 1500,
9
+ "dim_feedforward": 1024,
10
+ "dim_model": 512,
11
+ "dropout": 0.0,
12
+ "learnable_pe": true,
13
+ "masking_p": 0.15,
14
+ "modality": true,
15
+ "model_type": "nicheformer",
16
+ "n_tokens": 20340,
17
+ "nheads": 16,
18
+ "nlayers": 12,
19
+ "specie": true,
20
+ "supervised_task": null,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.49.0"
23
+ }
configuration_nicheformer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class NicheformerConfig(PretrainedConfig):
4
+ model_type = "nicheformer"
5
+
6
+ def __init__(
7
+ self,
8
+ dim_model=512,
9
+ nheads=16,
10
+ dim_feedforward=1024,
11
+ nlayers=12,
12
+ dropout=0.0,
13
+ batch_first=True,
14
+ masking_p=0.15,
15
+ n_tokens=20340,
16
+ context_length=1500,
17
+ cls_classes=164,
18
+ supervised_task=None,
19
+ learnable_pe=True,
20
+ specie=True,
21
+ assay=True,
22
+ modality=True,
23
+ **kwargs
24
+ ):
25
+ """Initialize NicheformerConfig.
26
+
27
+ Args:
28
+ dim_model: Dimensionality of the model
29
+ nheads: Number of attention heads
30
+ dim_feedforward: Dimensionality of MLPs in attention blocks
31
+ nlayers: Number of transformer layers
32
+ dropout: Dropout probability
33
+ batch_first: Whether batch dimension is first
34
+ masking_p: Probability of masking tokens
35
+ n_tokens: Total number of tokens (excluding auxiliary)
36
+ context_length: Length of the context window
37
+ cls_classes: Number of classification classes
38
+ supervised_task: Type of supervised task
39
+ learnable_pe: Whether to use learnable positional embeddings
40
+ specie: Whether to add specie token
41
+ assay: Whether to add assay token
42
+ modality: Whether to add modality token
43
+ """
44
+ super().__init__(**kwargs)
45
+
46
+ self.dim_model = dim_model
47
+ self.nheads = nheads
48
+ self.dim_feedforward = dim_feedforward
49
+ self.nlayers = nlayers
50
+ self.dropout = dropout
51
+ self.batch_first = batch_first
52
+ self.masking_p = masking_p
53
+ self.n_tokens = n_tokens
54
+ self.context_length = context_length
55
+ self.cls_classes = cls_classes
56
+ self.supervised_task = supervised_task
57
+ self.learnable_pe = learnable_pe
58
+ self.specie = specie
59
+ self.assay = assay
60
+ self.modality = modality
masking.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+
4
+ MASK_TOKEN = 0
5
+ PAD_TOKEN = 1
6
+ CLS_TOKEN = 2
7
+
8
+ def complete_masking(batch, masking_p, n_tokens):
9
+ """Apply masking to input batch for masked language modeling.
10
+
11
+ Args:
12
+ batch (dict): Input batch containing 'input_ids' and 'attention_mask'
13
+ masking_p (float): Probability of masking a token
14
+ n_tokens (int): Total number of tokens in vocabulary
15
+
16
+ Returns:
17
+ dict: Batch with masked indices and masking information
18
+ """
19
+ device = batch['input_ids'].device
20
+ input_ids = batch['input_ids']
21
+ attention_mask = batch['attention_mask']
22
+
23
+ # Create mask tensor (1 for tokens to be masked, 0 otherwise)
24
+ prob = torch.rand(input_ids.shape, device=device)
25
+ mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN)
26
+
27
+ # For masked tokens:
28
+ # - 80% replace with MASK token
29
+ # - 10% replace with random token
30
+ # - 10% keep unchanged
31
+ masked_indices = input_ids.clone()
32
+
33
+ # Calculate number of tokens to be masked
34
+ num_tokens_to_mask = mask.sum().item()
35
+
36
+ # Determine which tokens get which type of masking
37
+ mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8
38
+ random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask
39
+
40
+ # Apply MASK token (80% of masked tokens)
41
+ masked_indices[mask] = torch.where(
42
+ mask_mask,
43
+ torch.tensor(MASK_TOKEN, device=device, dtype=torch.long),
44
+ masked_indices[mask]
45
+ )
46
+
47
+ # Apply random tokens (10% of masked tokens)
48
+ random_tokens = torch.randint(
49
+ 3, n_tokens, # Start from 3 to avoid special tokens
50
+ (random_mask.sum(),),
51
+ device=device,
52
+ dtype=torch.long
53
+ )
54
+ masked_indices[mask][random_mask] = random_tokens
55
+
56
+ # 10% remain unchanged
57
+
58
+ return {
59
+ 'masked_indices': masked_indices,
60
+ 'attention_mask': attention_mask,
61
+ 'mask': mask,
62
+ 'input_ids': input_ids
63
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65fff02ac9bdd978cbb06b158a60d94ec59a96f30e706aed2c56d213d582e603
3
+ size 195851648
modeling_nicheformer.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from transformers.modeling_outputs import MaskedLMOutput
5
+ from configuration_nicheformer import NicheformerConfig
6
+ from masking import complete_masking, MASK_TOKEN, PAD_TOKEN, CLS_TOKEN
7
+ import math
8
+
9
+ class PositionalEncoding(nn.Module):
10
+ """Positional encoding using sine and cosine functions."""
11
+
12
+ def __init__(self, d_model: int, max_seq_len: int):
13
+ super().__init__()
14
+ encoding = torch.zeros(max_seq_len, d_model)
15
+ position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
16
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
17
+
18
+ encoding[:, 0::2] = torch.sin(position * div_term)
19
+ encoding[:, 1::2] = torch.cos(position * div_term)
20
+ encoding = encoding.unsqueeze(0)
21
+
22
+ self.register_buffer('encoding', encoding, persistent=False)
23
+
24
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
25
+ """Add positional encoding to input tensor."""
26
+ return x + self.encoding[:, :x.size(1)]
27
+
28
+ class NicheformerPreTrainedModel(PreTrainedModel):
29
+ """Base class for Nicheformer models."""
30
+
31
+ config_class = NicheformerConfig
32
+ base_model_prefix = "nicheformer"
33
+ supports_gradient_checkpointing = True
34
+
35
+ def _init_weights(self, module):
36
+ if isinstance(module, nn.Linear):
37
+ nn.init.xavier_normal_(module.weight)
38
+ if module.bias is not None:
39
+ nn.init.zeros_(module.bias)
40
+
41
+ class NicheformerModel(NicheformerPreTrainedModel):
42
+ def __init__(self, config: NicheformerConfig):
43
+ super().__init__(config)
44
+
45
+ # Core transformer components
46
+ self.encoder_layer = nn.TransformerEncoderLayer(
47
+ d_model=config.dim_model,
48
+ nhead=config.nheads,
49
+ dim_feedforward=config.dim_feedforward,
50
+ batch_first=config.batch_first,
51
+ dropout=config.dropout,
52
+ layer_norm_eps=1e-12
53
+ )
54
+ self.encoder = nn.TransformerEncoder(
55
+ encoder_layer=self.encoder_layer,
56
+ num_layers=config.nlayers,
57
+ enable_nested_tensor=False
58
+ )
59
+
60
+ # Embedding layers
61
+ self.embeddings = nn.Embedding(
62
+ num_embeddings=config.n_tokens+5,
63
+ embedding_dim=config.dim_model,
64
+ padding_idx=1
65
+ )
66
+
67
+ if config.learnable_pe:
68
+ self.positional_embedding = nn.Embedding(
69
+ num_embeddings=config.context_length,
70
+ embedding_dim=config.dim_model
71
+ )
72
+ self.dropout = nn.Dropout(p=config.dropout)
73
+ self.register_buffer('pos', torch.arange(0, config.context_length, dtype=torch.long))
74
+ else:
75
+ self.positional_embedding = PositionalEncoding(
76
+ d_model=config.dim_model,
77
+ max_seq_len=config.context_length
78
+ )
79
+
80
+ # Initialize weights
81
+ self.post_init()
82
+
83
+ def forward(self, input_ids, attention_mask=None):
84
+ token_embedding = self.embeddings(input_ids)
85
+
86
+ if self.config.learnable_pe:
87
+ pos_embedding = self.positional_embedding(self.pos.to(token_embedding.device))
88
+ embeddings = self.dropout(token_embedding + pos_embedding)
89
+ else:
90
+ embeddings = self.positional_embedding(token_embedding)
91
+
92
+ # Convert attention_mask to boolean and invert it for transformer's src_key_padding_mask
93
+ # True indicates positions that will be masked
94
+ if attention_mask is not None:
95
+ attention_mask = ~attention_mask.bool()
96
+
97
+ transformer_output = self.encoder(
98
+ embeddings,
99
+ src_key_padding_mask=attention_mask if attention_mask is not None else None,
100
+ is_causal=False
101
+ )
102
+
103
+ return transformer_output
104
+
105
+ class NicheformerForMaskedLM(NicheformerPreTrainedModel):
106
+ def __init__(self, config: NicheformerConfig):
107
+ super().__init__(config)
108
+
109
+ self.nicheformer = NicheformerModel(config)
110
+ self.classifier_head = nn.Linear(config.dim_model, config.n_tokens, bias=False)
111
+ self.classifier_head.bias = nn.Parameter(torch.zeros(config.n_tokens))
112
+
113
+ # Initialize weights
114
+ self.post_init()
115
+
116
+ def forward(
117
+ self,
118
+ input_ids=None,
119
+ attention_mask=None,
120
+ labels=None,
121
+ return_dict=None,
122
+ apply_masking=False,
123
+ ):
124
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
125
+
126
+ # Apply masking if requested (typically during training)
127
+ if apply_masking:
128
+ batch = {
129
+ 'input_ids': input_ids,
130
+ 'attention_mask': attention_mask
131
+ }
132
+ masked_batch = complete_masking(batch, self.config.masking_p, self.config.n_tokens)
133
+ input_ids = masked_batch['masked_indices']
134
+ labels = masked_batch['input_ids'] # Original tokens become labels
135
+ mask = masked_batch['mask']
136
+ # Only compute loss on masked tokens and ensure labels are long
137
+ labels = torch.where(mask, labels, torch.tensor(-100, device=labels.device)).long()
138
+
139
+ transformer_output = self.nicheformer(
140
+ input_ids=input_ids,
141
+ attention_mask=attention_mask,
142
+ )
143
+
144
+ prediction_scores = self.classifier_head(transformer_output)
145
+
146
+ masked_lm_loss = None
147
+ if labels is not None:
148
+ loss_fct = nn.CrossEntropyLoss()
149
+ masked_lm_loss = loss_fct(
150
+ prediction_scores.view(-1, self.config.n_tokens),
151
+ labels.view(-1)
152
+ )
153
+
154
+ if not return_dict:
155
+ output = (prediction_scores,) + (transformer_output,)
156
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
157
+
158
+ return MaskedLMOutput(
159
+ loss=masked_lm_loss,
160
+ logits=prediction_scores,
161
+ hidden_states=transformer_output,
162
+ )
tokenization_nicheformer.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional, Union, Tuple
2
+ import numpy as np
3
+ from transformers import PreTrainedTokenizer
4
+ from dataclasses import dataclass
5
+ import torch
6
+ import anndata as ad
7
+ from scipy.sparse import issparse
8
+ import numba
9
+ import os
10
+ import json
11
+
12
+ # Token IDs must match exactly with the original implementation
13
+ PAD_TOKEN = 0
14
+ MASK_TOKEN = 1
15
+ CLS_TOKEN = 2
16
+
17
+ # These mappings preserve the exact token IDs from the original implementation
18
+ MODALITY_DICT = {
19
+ 'dissociated': 3,
20
+ 'spatial': 4,
21
+ }
22
+
23
+ SPECIES_DICT = {
24
+ 'human': 5,
25
+ 'Homo sapiens': 5,
26
+ 'Mus musculus': 6,
27
+ 'mouse': 6,
28
+ }
29
+
30
+ TECHNOLOGY_DICT = {
31
+ "merfish": 7,
32
+ "MERFISH": 7,
33
+ "cosmx": 8,
34
+ "NanoString digital spatial profiling": 8,
35
+ "Xenium": 9,
36
+ "10x 5' v2": 10,
37
+ "10x 3' v3": 11,
38
+ "10x 3' v2": 12,
39
+ "10x 5' v1": 13,
40
+ "10x 3' v1": 14,
41
+ "10x 3' transcription profiling": 15,
42
+ "10x transcription profiling": 15,
43
+ "10x 5' transcription profiling": 16,
44
+ "CITE-seq": 17,
45
+ "Smart-seq v4": 18,
46
+ }
47
+
48
+ def sf_normalize(X: np.ndarray) -> np.ndarray:
49
+ """Size factor normalize to 10k counts."""
50
+ X = X.copy()
51
+ counts = np.array(X.sum(axis=1))
52
+ # avoid zero division error
53
+ counts += counts == 0.
54
+ # normalize to 10000 counts
55
+ scaling_factor = 10000. / counts
56
+
57
+ if issparse(X):
58
+ from scipy.sparse import sparsefuncs
59
+ sparsefuncs.inplace_row_scale(X, scaling_factor)
60
+ else:
61
+ np.multiply(X, scaling_factor.reshape((-1, 1)), out=X)
62
+
63
+ return X
64
+
65
+ @numba.jit(nopython=True, nogil=True)
66
+ def _sub_tokenize_data(x: np.ndarray, max_seq_len: int = -1, aux_tokens: int = 30) -> np.ndarray:
67
+ """Tokenize the input gene vector."""
68
+ scores_final = np.empty((x.shape[0], max_seq_len if max_seq_len > 0 else x.shape[1]))
69
+ for i, cell in enumerate(x):
70
+ nonzero_mask = np.nonzero(cell)[0]
71
+ sorted_indices = nonzero_mask[np.argsort(-cell[nonzero_mask])][:max_seq_len]
72
+ sorted_indices = sorted_indices + aux_tokens
73
+ if max_seq_len:
74
+ scores = np.zeros(max_seq_len, dtype=np.int32)
75
+ else:
76
+ scores = np.zeros_like(cell, dtype=np.int32)
77
+ scores[:len(sorted_indices)] = sorted_indices.astype(np.int32)
78
+ scores_final[i, :] = scores
79
+ return scores_final
80
+
81
+ class NicheformerTokenizer(PreTrainedTokenizer):
82
+ """Tokenizer for Nicheformer that handles single-cell data."""
83
+
84
+ model_input_names = ["input_ids", "attention_mask"]
85
+ vocab_files_names = {"vocab_file": "vocab.json"}
86
+
87
+ modality_dict = MODALITY_DICT
88
+ species_dict = SPECIES_DICT
89
+ technology_dict = TECHNOLOGY_DICT
90
+
91
+ def __init__(
92
+ self,
93
+ vocab_file=None,
94
+ max_length: int = 1500,
95
+ aux_tokens: int = 30,
96
+ median_counts_per_gene: Optional[np.ndarray] = None,
97
+ gene_names: Optional[List[str]] = None,
98
+ **kwargs
99
+ ):
100
+ # Initialize base vocabulary
101
+ self._vocabulary = {
102
+ "[PAD]": PAD_TOKEN,
103
+ "[MASK]": MASK_TOKEN,
104
+ "[CLS]": CLS_TOKEN,
105
+ }
106
+
107
+ if vocab_file is not None:
108
+ with open(vocab_file, 'r') as f:
109
+ self._vocabulary.update(json.load(f))
110
+ else:
111
+ # Add modality tokens
112
+ for name, idx in self.modality_dict.items():
113
+ self._vocabulary[f"[MODALITY_{name}]"] = idx
114
+ # Add species tokens
115
+ for name, idx in self.species_dict.items():
116
+ if name in ["Homo sapiens", "Mus musculus"]:
117
+ continue # Skip redundant names
118
+ self._vocabulary[f"[SPECIES_{name}]"] = idx
119
+ # Add technology tokens
120
+ for name, idx in self.technology_dict.items():
121
+ if name in ["MERFISH", "10x transcription profiling"]:
122
+ continue # Skip redundant names
123
+ clean_name = name.lower().replace(" ", "_").replace("'", "_")
124
+ self._vocabulary[f"[TECH_{clean_name}]"] = idx
125
+
126
+ # Add gene tokens if provided
127
+ if gene_names is not None:
128
+ for i, gene in enumerate(gene_names):
129
+ self._vocabulary[gene] = i + aux_tokens
130
+ # Save vocabulary
131
+ os.makedirs('to_hf', exist_ok=True)
132
+ with open('to_hf/vocab.json', 'w') as f:
133
+ json.dump(self._vocabulary, f, indent=4)
134
+
135
+ super().__init__(**kwargs)
136
+
137
+ self.max_length = max_length
138
+ self.aux_tokens = aux_tokens
139
+ self.median_counts_per_gene = median_counts_per_gene
140
+ self.gene_names = gene_names
141
+
142
+ # Set up special token mappings
143
+ self._pad_token = "[PAD]"
144
+ self._mask_token = "[MASK]"
145
+ self._cls_token = "[CLS]"
146
+
147
+ def get_vocab(self) -> Dict[str, int]:
148
+ """Returns the vocabulary mapping."""
149
+ return self._vocabulary.copy()
150
+
151
+ def _tokenize(self, text: str) -> List[str]:
152
+ """Tokenize text input."""
153
+ # This tokenizer doesn't handle text input directly
154
+ raise NotImplementedError("This tokenizer only works with gene expression data")
155
+
156
+ def _convert_token_to_id(self, token: str) -> int:
157
+ """Convert token to ID."""
158
+ # First check special token mappings
159
+ if token in self.modality_dict:
160
+ return self.modality_dict[token]
161
+ if token in self.species_dict:
162
+ return self.species_dict[token]
163
+ if token in self.technology_dict:
164
+ return self.technology_dict[token]
165
+ # Then check vocabulary
166
+ return self._vocabulary.get(token, self._vocabulary["[PAD]"])
167
+
168
+ def _convert_id_to_token(self, index: int) -> str:
169
+ """Convert ID to token."""
170
+ # First check special token mappings
171
+ for token, idx in self.modality_dict.items():
172
+ if idx == index:
173
+ return token
174
+ for token, idx in self.species_dict.items():
175
+ if idx == index:
176
+ return token
177
+ for token, idx in self.technology_dict.items():
178
+ if idx == index:
179
+ return token
180
+ # Then check vocabulary
181
+ for token, idx in self._vocabulary.items():
182
+ if idx == index:
183
+ return token
184
+ return "[PAD]"
185
+
186
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
187
+ """Save the vocabulary to a file."""
188
+ vocab_file = os.path.join(
189
+ save_directory,
190
+ (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
191
+ )
192
+
193
+ with open(vocab_file, "w", encoding="utf-8") as f:
194
+ json.dump(self._vocabulary, f, ensure_ascii=False)
195
+
196
+ return (vocab_file,)
197
+
198
+ def _tokenize_gene_expression(self, x: np.ndarray) -> np.ndarray:
199
+ """Tokenize gene expression matrix.
200
+
201
+ Args:
202
+ x: Gene expression matrix (cells x genes)
203
+
204
+ Returns:
205
+ Tokenized matrix
206
+ """
207
+ # Handle sparse input
208
+ if issparse(x):
209
+ x = x.toarray()
210
+
211
+ # Normalize and scale
212
+ x = np.nan_to_num(x)
213
+ x = sf_normalize(x)
214
+ if self.median_counts_per_gene is not None:
215
+ median_counts = self.median_counts_per_gene.copy()
216
+ median_counts += median_counts == 0
217
+ x = x / median_counts.reshape((1, -1))
218
+
219
+ # Convert to tokens
220
+ tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)
221
+
222
+ return tokens.astype(np.int32)
223
+
224
+ def __call__(
225
+ self,
226
+ adata: Optional[ad.AnnData] = None,
227
+ gene_expression: Optional[Union[np.ndarray, List[float]]] = None,
228
+ modality: Optional[str] = None,
229
+ species: Optional[str] = None,
230
+ technology: Optional[str] = None,
231
+ **kwargs
232
+ ) -> Dict[str, torch.Tensor]:
233
+ """Convert inputs to model inputs.
234
+
235
+ Args:
236
+ adata: AnnData object
237
+ gene_expression: Gene expression matrix if not using AnnData
238
+ modality: Modality type
239
+ species: Species type
240
+ technology: Technology/assay type
241
+
242
+ Returns:
243
+ Dictionary with model inputs
244
+ """
245
+ if adata is not None:
246
+ # Get expression matrix
247
+ if issparse(adata.X):
248
+ x = adata.X.toarray()
249
+ else:
250
+ x = adata.X
251
+
252
+ # Get metadata for each cell if not provided
253
+ if modality is None and 'modality' in adata.obs:
254
+ modality = adata.obs['modality'].values
255
+ if species is None and 'specie' in adata.obs:
256
+ species = adata.obs['specie'].values
257
+ if technology is None and 'assay' in adata.obs:
258
+ technology = adata.obs['assay'].values
259
+
260
+ elif gene_expression is not None:
261
+ x = np.array(gene_expression)
262
+ if len(x.shape) == 1:
263
+ x = x.reshape(1, -1)
264
+ # For single gene expression input, convert scalar metadata to arrays
265
+ if modality is not None:
266
+ modality = np.array([modality])
267
+ if species is not None:
268
+ species = np.array([species])
269
+ if technology is not None:
270
+ technology = np.array([technology])
271
+ else:
272
+ raise ValueError("Either adata or gene_expression must be provided")
273
+
274
+ # Tokenize gene expression
275
+ token_ids = self._tokenize_gene_expression(x)
276
+ n_cells = token_ids.shape[0]
277
+
278
+ # Add special tokens for each cell
279
+ special_tokens = np.zeros((n_cells, 3), dtype=np.int32) # 3 for modality, species, technology
280
+ special_token_mask = np.zeros((n_cells, 3), dtype=bool) # Track which tokens are actually present
281
+
282
+ if modality is not None:
283
+ special_tokens[:, 0] = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality]
284
+ special_token_mask[:, 0] = True
285
+
286
+ if species is not None:
287
+ special_tokens[:, 1] = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species]
288
+ special_token_mask[:, 1] = True
289
+
290
+ if technology is not None:
291
+ special_tokens[:, 2] = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology]
292
+ special_token_mask[:, 2] = True
293
+
294
+ # Only keep the special tokens that are present (have True in mask)
295
+ special_tokens = special_tokens[:, special_token_mask[0]]
296
+
297
+ if special_tokens.size > 0:
298
+ token_ids = np.concatenate([special_tokens, token_ids[:, :(self.max_length - special_tokens.shape[1])]], axis=1)
299
+
300
+ # Create attention mask
301
+ attention_mask = (token_ids != self._vocabulary["[PAD]"])
302
+
303
+ return {
304
+ "input_ids": torch.tensor(token_ids, dtype=torch.long),
305
+ "attention_mask": torch.tensor(attention_mask)
306
+ }
307
+
308
+ def get_vocab_size(self) -> int:
309
+ """Get vocabulary size."""
310
+ if self.gene_names is not None:
311
+ return len(self.gene_names) + self.aux_tokens
312
+ return max(
313
+ max(self.modality_dict.values()),
314
+ max(self.species_dict.values()),
315
+ max(self.technology_dict.values())
316
+ ) + 1
317
+
318
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
319
+ """Convert a sequence of tokens to a string. Not used for gene expression."""
320
+ raise NotImplementedError("This tokenizer only works with gene expression data")
321
+
322
+ def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
323
+ """Build model inputs from a sequence by adding special tokens."""
324
+ # For gene expression data, special tokens are handled in __call__
325
+ return token_ids_0
326
+
327
+ def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False) -> List[int]:
328
+ """Get list where entries are [1] if a token is [special] else [0]."""
329
+ # Consider tokens < aux_tokens as special
330
+ return [1 if token_id < self.aux_tokens else 0 for token_id in token_ids_0]
vocab.json ADDED
The diff for this file is too large to render. See raw diff