Upload Nicheformer model
Browse files- LICENSE +21 -0
- README.md +74 -0
- __init__.py +6 -0
- config.json +23 -0
- configuration_nicheformer.py +60 -0
- masking.py +63 -0
- model.safetensors +3 -0
- modeling_nicheformer.py +162 -0
- tokenization_nicheformer.py +330 -0
- vocab.json +0 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Nicheformer Contributors
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Nicheformer
|
| 2 |
+
|
| 3 |
+
Nicheformer is a transformer-based model designed for understanding and predicting cellular niches and their interactions. The model uses masked language modeling to learn representations of cellular contexts and their relationships.
|
| 4 |
+
|
| 5 |
+
## Model Description
|
| 6 |
+
|
| 7 |
+
Nicheformer is built on a transformer architecture with the following key features:
|
| 8 |
+
|
| 9 |
+
- **Architecture**: Transformer encoder with customizable number of layers and attention heads
|
| 10 |
+
- **Pre-training**: Masked Language Modeling (MLM) objective with dynamic masking
|
| 11 |
+
- **Input Processing**: Handles species, assay (technology), and modality information
|
| 12 |
+
- **Positional Encoding**: Supports both learnable and fixed positional embeddings
|
| 13 |
+
- **Masking Strategy**:
|
| 14 |
+
- 80% of selected tokens are replaced with [MASK]
|
| 15 |
+
- 10% are replaced with random tokens
|
| 16 |
+
- 10% remain unchanged
|
| 17 |
+
|
| 18 |
+
### Model Architecture
|
| 19 |
+
|
| 20 |
+
- Transformer encoder layers: 12
|
| 21 |
+
- Hidden dimension: 512
|
| 22 |
+
- Attention heads: 16
|
| 23 |
+
- Feedforward dimension: 1024
|
| 24 |
+
- Maximum sequence length: 1500
|
| 25 |
+
- Vocabulary size: 20340 tokens (`n_tokens` in `config.json`)
|
| 26 |
+
- Masking probability: 15%
|
| 27 |
+
|
| 28 |
+
## Usage
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
| 32 |
+
|
| 33 |
+
# Load model and tokenizer
|
| 34 |
+
model = AutoModelForMaskedLM.from_pretrained("your-username/nicheformer")
|
| 35 |
+
tokenizer = AutoTokenizer.from_pretrained("your-username/nicheformer")
|
| 36 |
+
|
| 37 |
+
# The tokenizer works on gene expression counts, not on free text.
# `cell_counts` is a 1-D vector of raw counts for one cell (one entry per gene);
# an AnnData object can be passed instead via `adata=...`.
inputs = tokenizer(
    gene_expression=cell_counts,
    modality="spatial",
    species="human",
    technology="MERFISH",
)

# Example 1: Forward pass without masking
outputs = model(**inputs)

# Example 2: Automatic masking (typically used during training)
outputs = model(**inputs, apply_masking=True) # This will automatically mask tokens
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Training Data
|
| 49 |
+
|
| 50 |
+
[Describe the training data used for the model]
|
| 51 |
+
|
| 52 |
+
## Evaluation Results
|
| 53 |
+
|
| 54 |
+
[Include evaluation metrics and results]
|
| 55 |
+
|
| 56 |
+
## Limitations
|
| 57 |
+
|
| 58 |
+
[Describe any known limitations or biases of the model]
|
| 59 |
+
|
| 60 |
+
## Citation
|
| 61 |
+
|
| 62 |
+
If you use this model in your research, please cite:
|
| 63 |
+
|
| 64 |
+
```bibtex
|
| 65 |
+
[Add citation information]
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
## License
|
| 69 |
+
|
| 70 |
+
This model is released under the MIT License (see the `LICENSE` file).
|
| 71 |
+
|
| 72 |
+
## Contact
|
| 73 |
+
|
| 74 |
+
[Add contact information for questions and issues]
|
__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_nicheformer import NicheformerConfig
|
| 2 |
+
from .modeling_nicheformer import (
|
| 3 |
+
NicheformerPreTrainedModel,
|
| 4 |
+
NicheformerModel,
|
| 5 |
+
NicheformerForMaskedLM
|
| 6 |
+
)
|
config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"NicheformerForMaskedLM"
|
| 4 |
+
],
|
| 5 |
+
"assay": true,
|
| 6 |
+
"batch_first": true,
|
| 7 |
+
"cls_classes": 164,
|
| 8 |
+
"context_length": 1500,
|
| 9 |
+
"dim_feedforward": 1024,
|
| 10 |
+
"dim_model": 512,
|
| 11 |
+
"dropout": 0.0,
|
| 12 |
+
"learnable_pe": true,
|
| 13 |
+
"masking_p": 0.15,
|
| 14 |
+
"modality": true,
|
| 15 |
+
"model_type": "nicheformer",
|
| 16 |
+
"n_tokens": 20340,
|
| 17 |
+
"nheads": 16,
|
| 18 |
+
"nlayers": 12,
|
| 19 |
+
"specie": true,
|
| 20 |
+
"supervised_task": null,
|
| 21 |
+
"torch_dtype": "float32",
|
| 22 |
+
"transformers_version": "4.49.0"
|
| 23 |
+
}
|
configuration_nicheformer.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class NicheformerConfig(PretrainedConfig):
    """Hyper-parameter container for Nicheformer models.

    Each constructor argument is stored on the instance under an attribute of
    the same name; any additional keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        dim_model: Dimensionality of the model.
        nheads: Number of attention heads.
        dim_feedforward: Dimensionality of the MLPs in the attention blocks.
        nlayers: Number of transformer layers.
        dropout: Dropout probability.
        batch_first: Whether the batch dimension comes first.
        masking_p: Probability of masking a token.
        n_tokens: Total number of tokens (excluding auxiliary ones).
        context_length: Length of the context window.
        cls_classes: Number of classification classes.
        supervised_task: Type of supervised task.
        learnable_pe: Whether to use learnable positional embeddings.
        specie: Whether to add a specie token.
        assay: Whether to add an assay token.
        modality: Whether to add a modality token.
    """

    model_type = "nicheformer"

    def __init__(
        self,
        dim_model=512,
        nheads=16,
        dim_feedforward=1024,
        nlayers=12,
        dropout=0.0,
        batch_first=True,
        masking_p=0.15,
        n_tokens=20340,
        context_length=1500,
        cls_classes=164,
        supervised_task=None,
        learnable_pe=True,
        specie=True,
        assay=True,
        modality=True,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Store every hyper-parameter verbatim, in declaration order.
        settings = {
            "dim_model": dim_model,
            "nheads": nheads,
            "dim_feedforward": dim_feedforward,
            "nlayers": nlayers,
            "dropout": dropout,
            "batch_first": batch_first,
            "masking_p": masking_p,
            "n_tokens": n_tokens,
            "context_length": context_length,
            "cls_classes": cls_classes,
            "supervised_task": supervised_task,
            "learnable_pe": learnable_pe,
            "specie": specie,
            "assay": assay,
            "modality": modality,
        }
        for name, value in settings.items():
            setattr(self, name, value)
|
masking.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
MASK_TOKEN = 0
PAD_TOKEN = 1
CLS_TOKEN = 2


def complete_masking(batch, masking_p, n_tokens):
    """Apply BERT-style masking to an input batch for masked language modeling.

    Each eligible position is selected with probability ``masking_p``. Of the
    selected positions roughly 80% are replaced by ``MASK_TOKEN``, 10% by a
    random token id drawn from ``[3, n_tokens)`` and 10% are left unchanged.
    Positions holding ``PAD_TOKEN`` or ``CLS_TOKEN`` are never selected, and
    neither are positions the attention mask marks as not attended.

    Args:
        batch (dict): Input batch containing 'input_ids' and 'attention_mask'.
            'attention_mask' may be ``None``.
        masking_p (float): Probability of selecting a token for masking.
        n_tokens (int): Exclusive upper bound for randomly drawn token ids.

    Returns:
        dict: ``'masked_indices'`` (corrupted copy of the ids),
        ``'attention_mask'`` (passed through), ``'mask'`` (boolean tensor of
        selected positions) and ``'input_ids'`` (the untouched originals).
    """
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    device = input_ids.device

    # Select candidate positions (True = token will be corrupted).
    prob = torch.rand(input_ids.shape, device=device)
    mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN)
    if attention_mask is not None:
        # Padding may be encoded with an id other than PAD_TOKEN; the
        # attention mask is the authoritative marker, so honour it as well.
        mask = mask & attention_mask.bool()

    masked_indices = input_ids.clone()
    num_tokens_to_mask = int(mask.sum().item())

    # Decide per selected token: 80% MASK, 10% random (0.2 * 0.5), 10% keep.
    mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8
    random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask

    # Boolean indexing returns a copy, so edit the gathered values and write
    # them back in one assignment. A chained `t[mask][sub] = ...` would be
    # silently discarded.
    selected = masked_indices[mask]
    selected = torch.where(mask_mask, torch.full_like(selected, MASK_TOKEN), selected)

    random_tokens = torch.randint(
        3, n_tokens,  # Start from 3 to avoid special tokens
        (int(random_mask.sum().item()),),
        device=device,
        dtype=selected.dtype
    )
    selected[random_mask] = random_tokens
    masked_indices[mask] = selected

    return {
        'masked_indices': masked_indices,
        'attention_mask': attention_mask,
        'mask': mask,
        'input_ids': input_ids
    }
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:65fff02ac9bdd978cbb06b158a60d94ec59a96f30e706aed2c56d213d582e603
|
| 3 |
+
size 195851648
|
modeling_nicheformer.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import PreTrainedModel
|
| 4 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
| 5 |
+
from configuration_nicheformer import NicheformerConfig
|
| 6 |
+
from masking import complete_masking, MASK_TOKEN, PAD_TOKEN, CLS_TOKEN
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
class PositionalEncoding(nn.Module):
    """Positional encoding using sine and cosine functions.

    A table of shape ``(1, max_seq_len, d_model)`` is pre-computed once and
    kept as a non-persistent buffer (it is not written to the state dict).
    Even feature columns hold sines, odd columns hold cosines.

    Args:
        d_model: Size of the feature dimension; may be even or odd.
        max_seq_len: Number of positions to pre-compute.
    """

    def __init__(self, d_model: int, max_seq_len: int):
        super().__init__()
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        encoding = torch.zeros(max_seq_len, d_model)
        encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so trim the angle table to the number of odd-indexed columns.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        encoding = encoding.unsqueeze(0)

        self.register_buffer('encoding', encoding, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input tensor.

        The table is sliced to ``x.size(1)`` positions, i.e. dimension 1 of
        ``x`` is treated as the sequence dimension.
        """
        return x + self.encoding[:, :x.size(1)]
|
| 27 |
+
|
| 28 |
+
class NicheformerPreTrainedModel(PreTrainedModel):
    """Common base for all Nicheformer models.

    Binds the models to ``NicheformerConfig`` and supplies the weight
    initialisation applied to linear layers.
    """

    config_class = NicheformerConfig
    base_model_prefix = "nicheformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` modules; every other module is left untouched.

        Weights receive Xavier-normal initialisation and an existing bias is
        set to zero.
        """
        if not isinstance(module, nn.Linear):
            return

        nn.init.xavier_normal_(module.weight)
        bias = module.bias
        if bias is not None:
            nn.init.zeros_(bias)
|
| 40 |
+
|
| 41 |
+
class NicheformerModel(NicheformerPreTrainedModel):
    """Nicheformer encoder: token embedding, positional information and a
    stack of ``nn.TransformerEncoder`` layers.

    Positional information is either a learnable embedding
    (``config.learnable_pe``) or the fixed sinusoidal ``PositionalEncoding``.
    """

    def __init__(self, config: NicheformerConfig):
        super().__init__(config)

        # Core transformer components.
        # NOTE(review): `encoder_layer` stays registered as its own attribute,
        # so its parameters appear in the state dict next to those of
        # `encoder`. It is kept so parameter names stay unchanged; presumably
        # the published checkpoint contains these keys — confirm before
        # removing.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.dim_model,
            nhead=config.nheads,
            dim_feedforward=config.dim_feedforward,
            batch_first=config.batch_first,
            dropout=config.dropout,
            layer_norm_eps=1e-12
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=self.encoder_layer,
            num_layers=config.nlayers,
            enable_nested_tensor=False
        )

        # Embedding layers (index 1 is the padding row of the embedding table).
        self.embeddings = nn.Embedding(
            num_embeddings=config.n_tokens+5,
            embedding_dim=config.dim_model,
            padding_idx=1
        )

        if config.learnable_pe:
            self.positional_embedding = nn.Embedding(
                num_embeddings=config.context_length,
                embedding_dim=config.dim_model
            )
            self.dropout = nn.Dropout(p=config.dropout)
            self.register_buffer('pos', torch.arange(0, config.context_length, dtype=torch.long))
        else:
            self.positional_embedding = PositionalEncoding(
                d_model=config.dim_model,
                max_seq_len=config.context_length
            )

        # Initialize weights
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: Integer tensor of token ids.
            attention_mask: Optional tensor whose truthy entries mark tokens
                to attend to. It is inverted before being passed on as
                ``src_key_padding_mask``.

        Returns:
            The output tensor of the transformer encoder.
        """
        token_embedding = self.embeddings(input_ids)

        if self.config.learnable_pe:
            # Look up only as many positions as the input has, so sequences
            # shorter than `context_length` broadcast correctly. For a
            # full-length input this is the same lookup as before.
            positions = self.pos[: token_embedding.size(-2)]
            pos_embedding = self.positional_embedding(positions.to(token_embedding.device))
            embeddings = self.dropout(token_embedding + pos_embedding)
        else:
            embeddings = self.positional_embedding(token_embedding)

        # src_key_padding_mask uses the opposite convention of attention_mask:
        # True marks positions that must be ignored.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()

        transformer_output = self.encoder(
            embeddings,
            src_key_padding_mask=key_padding_mask,
            is_causal=False
        )

        return transformer_output
|
| 104 |
+
|
| 105 |
+
class NicheformerForMaskedLM(NicheformerPreTrainedModel):
    """Nicheformer encoder topped with a linear head that predicts token ids."""

    def __init__(self, config: NicheformerConfig):
        super().__init__(config)

        self.nicheformer = NicheformerModel(config)
        # The head is built without a bias and then given an explicit
        # zero-initialised bias parameter.
        self.classifier_head = nn.Linear(config.dim_model, config.n_tokens, bias=False)
        self.classifier_head.bias = nn.Parameter(torch.zeros(config.n_tokens))

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None,
        apply_masking=False,
    ):
        """Run the encoder and the prediction head, optionally computing the loss.

        Args:
            input_ids: Token ids fed to the encoder.
            attention_mask: Optional mask forwarded to the encoder.
            labels: Optional target ids for the cross-entropy loss. Replaced
                by the masking targets when ``apply_masking`` is true.
            return_dict: Return a ``MaskedLMOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.
            apply_masking: Corrupt ``input_ids`` with ``complete_masking`` and
                train only on the corrupted positions.

        Returns:
            ``MaskedLMOutput`` (loss, logits, hidden_states) or, when
            ``return_dict`` is false, the tuple ``(logits, encoder_output)``
            preceded by the loss if one was computed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Apply masking if requested (typically during training)
        if apply_masking:
            masked_batch = complete_masking(
                {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask
                },
                self.config.masking_p,
                self.config.n_tokens,
            )
            input_ids = masked_batch['masked_indices']
            originals = masked_batch['input_ids']
            selected = masked_batch['mask']
            # Unselected positions get -100, the index CrossEntropyLoss ignores.
            ignore_value = torch.tensor(-100, device=originals.device)
            labels = torch.where(selected, originals, ignore_value).long()

        transformer_output = self.nicheformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        prediction_scores = self.classifier_head(transformer_output)

        masked_lm_loss = None
        if labels is not None:
            flat_scores = prediction_scores.view(-1, self.config.n_tokens)
            flat_labels = labels.view(-1)
            masked_lm_loss = nn.CrossEntropyLoss()(flat_scores, flat_labels)

        if return_dict:
            return MaskedLMOutput(
                loss=masked_lm_loss,
                logits=prediction_scores,
                hidden_states=transformer_output,
            )

        output = (prediction_scores, transformer_output)
        if masked_lm_loss is None:
            return output
        return (masked_lm_loss,) + output
|
tokenization_nicheformer.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Optional, Union, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
from transformers import PreTrainedTokenizer
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
import torch
|
| 6 |
+
import anndata as ad
|
| 7 |
+
from scipy.sparse import issparse
|
| 8 |
+
import numba
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
# Token IDs must match exactly with the original implementation.
# NOTE(review): masking.py in this same upload declares MASK_TOKEN = 0 and
# PAD_TOKEN = 1, and the model's embedding table uses padding_idx=1, i.e. the
# opposite assignment of ids 0 and 1 — confirm which convention is intended.
PAD_TOKEN = 0
MASK_TOKEN = 1
CLS_TOKEN = 2

# These mappings preserve the exact token IDs from the original implementation.
# Several spellings of the same label map to one shared id.

# Modality label -> token id.
MODALITY_DICT = {
    'dissociated': 3,
    'spatial': 4,
}

# Species label -> token id (common and scientific names share an id).
SPECIES_DICT = {
    'human': 5,
    'Homo sapiens': 5,
    'Mus musculus': 6,
    'mouse': 6,
}

# Assay / technology label -> token id.
TECHNOLOGY_DICT = {
    "merfish": 7,
    "MERFISH": 7,
    "cosmx": 8,
    "NanoString digital spatial profiling": 8,
    "Xenium": 9,
    "10x 5' v2": 10,
    "10x 3' v3": 11,
    "10x 3' v2": 12,
    "10x 5' v1": 13,
    "10x 3' v1": 14,
    "10x 3' transcription profiling": 15,
    "10x transcription profiling": 15,
    "10x 5' transcription profiling": 16,
    "CITE-seq": 17,
    "Smart-seq v4": 18,
}
|
| 47 |
+
|
| 48 |
+
def sf_normalize(X: np.ndarray) -> np.ndarray:
    """Size factor normalize to 10k counts.

    Every row (cell) is rescaled so that its entries sum to 10000. Rows whose
    sum is zero are left as all zeros. The input is never modified.

    Args:
        X: 2-D count matrix (cells x genes), dense ``ndarray`` or SciPy sparse.

    Returns:
        The normalized matrix. Dense input yields a dense array (integer
        input is promoted to float64, floating input keeps its dtype); sparse
        input yields a sparse matrix in the same storage format.
    """
    counts = np.array(X.sum(axis=1)).reshape(-1)
    # avoid zero division error: empty rows get a divisor of 1
    counts += counts == 0.
    # normalize to 10000 counts
    scaling_factor = (10000. / counts).reshape((-1, 1))

    if issparse(X):
        # Row-wise scaling via broadcasting; `multiply` returns a new matrix,
        # so the caller's data stays untouched.
        return X.multiply(scaling_factor).asformat(X.format)

    X = X.copy()
    if not np.issubdtype(X.dtype, np.floating):
        # An integer buffer cannot hold the scaled values in place.
        X = X.astype(np.float64)
    np.multiply(X, scaling_factor, out=X)

    return X
|
| 64 |
+
|
| 65 |
+
@numba.jit(nopython=True, nogil=True)
def _sub_tokenize_data(x: np.ndarray, max_seq_len: int = -1, aux_tokens: int = 30) -> np.ndarray:
    """Tokenize the input gene vector.

    For every row (cell) the column indices of the non-zero entries are
    ordered by decreasing value, shifted by ``aux_tokens`` and written to the
    start of the output row; the rest of the row is filled with 0.

    Args:
        x: 2-D expression matrix (cells x genes).
        max_seq_len: Number of tokens kept per cell. Any value <= 0 means
            "no truncation": the output is as wide as ``x``.
        aux_tokens: Offset added to every gene index.

    Returns:
        Float array of shape ``(n_cells, seq_len)`` holding the token ids.
    """
    # Resolve the effective width once so that allocation, truncation and
    # padding all agree. The previous code mixed `max_seq_len > 0`,
    # `[:max_seq_len]` and `if max_seq_len:`, which broke the default of -1.
    seq_len = max_seq_len if max_seq_len > 0 else x.shape[1]

    scores_final = np.empty((x.shape[0], seq_len))
    for i, cell in enumerate(x):
        nonzero_mask = np.nonzero(cell)[0]
        sorted_indices = nonzero_mask[np.argsort(-cell[nonzero_mask])][:seq_len]
        sorted_indices = sorted_indices + aux_tokens
        scores = np.zeros(seq_len, dtype=np.int32)
        scores[:len(sorted_indices)] = sorted_indices.astype(np.int32)
        scores_final[i, :] = scores
    return scores_final
|
| 80 |
+
|
| 81 |
+
class NicheformerTokenizer(PreTrainedTokenizer):
    """Tokenizer for Nicheformer that handles single-cell data.

    Instead of text, ``__call__`` takes an expression matrix (or an AnnData
    object) and returns the gene-rank token ids, optionally prefixed with
    modality, species and technology tokens.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    modality_dict = MODALITY_DICT
    species_dict = SPECIES_DICT
    technology_dict = TECHNOLOGY_DICT

    def __init__(
        self,
        vocab_file=None,
        max_length: int = 1500,
        aux_tokens: int = 30,
        median_counts_per_gene: Optional[np.ndarray] = None,
        gene_names: Optional[List[str]] = None,
        **kwargs
    ):
        """Build the vocabulary and store the tokenization settings.

        Args:
            vocab_file: Optional JSON file mapping token -> id. When given,
                its content is merged into the base vocabulary.
            max_length: Number of tokens produced per cell.
            aux_tokens: Offset added to gene indices to form gene token ids.
            median_counts_per_gene: Optional per-gene divisor applied after
                size-factor normalization.
            gene_names: Optional gene names; gene ``i`` receives id
                ``i + aux_tokens``.
            **kwargs: Forwarded to ``PreTrainedTokenizer``.
        """
        # Initialize base vocabulary (populated before the base-class
        # constructor is called, as in the original code).
        self._vocabulary = {
            "[PAD]": PAD_TOKEN,
            "[MASK]": MASK_TOKEN,
            "[CLS]": CLS_TOKEN,
        }

        if vocab_file is not None:
            # save_vocabulary() writes UTF-8, so read it back the same way.
            with open(vocab_file, 'r', encoding='utf-8') as f:
                self._vocabulary.update(json.load(f))
        else:
            # NOTE(review): the nesting of this branch was reconstructed from
            # a flattened diff — confirm that the gene tokens and the file
            # write below belong inside this `else`.
            # Add modality tokens
            for name, idx in self.modality_dict.items():
                self._vocabulary[f"[MODALITY_{name}]"] = idx
            # Add species tokens
            for name, idx in self.species_dict.items():
                if name in ["Homo sapiens", "Mus musculus"]:
                    continue  # Skip redundant names
                self._vocabulary[f"[SPECIES_{name}]"] = idx
            # Add technology tokens
            for name, idx in self.technology_dict.items():
                if name in ["MERFISH", "10x transcription profiling"]:
                    continue  # Skip redundant names
                clean_name = name.lower().replace(" ", "_").replace("'", "_")
                self._vocabulary[f"[TECH_{clean_name}]"] = idx

            # Add gene tokens if provided
            if gene_names is not None:
                for i, gene in enumerate(gene_names):
                    self._vocabulary[gene] = i + aux_tokens
                # Save vocabulary (side effect: writes into the current
                # working directory).
                os.makedirs('to_hf', exist_ok=True)
                with open('to_hf/vocab.json', 'w') as f:
                    json.dump(self._vocabulary, f, indent=4)

        super().__init__(**kwargs)

        self.max_length = max_length
        self.aux_tokens = aux_tokens
        self.median_counts_per_gene = median_counts_per_gene
        self.gene_names = gene_names

        # Set up special token mappings
        # NOTE(review): whether assigning these private attributes still
        # registers the special tokens depends on the transformers version —
        # confirm.
        self._pad_token = "[PAD]"
        self._mask_token = "[MASK]"
        self._cls_token = "[CLS]"

    def get_vocab(self) -> Dict[str, int]:
        """Returns a copy of the vocabulary mapping."""
        return self._vocabulary.copy()

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text input (unsupported: always raises NotImplementedError)."""
        # This tokenizer doesn't handle text input directly
        raise NotImplementedError("This tokenizer only works with gene expression data")

    def _convert_token_to_id(self, token: str) -> int:
        """Convert token to ID; unknown tokens map to the [PAD] id."""
        # First check special token mappings
        for mapping in (self.modality_dict, self.species_dict, self.technology_dict):
            if token in mapping:
                return mapping[token]
        # Then check vocabulary
        return self._vocabulary.get(token, self._vocabulary["[PAD]"])

    def _convert_id_to_token(self, index: int) -> str:
        """Convert ID to token; the first matching name wins, unknown ids give "[PAD]"."""
        # Special token mappings first, then the vocabulary.
        for mapping in (self.modality_dict, self.species_dict, self.technology_dict, self._vocabulary):
            for token, idx in mapping.items():
                if idx == index:
                    return token
        return "[PAD]"

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Save the vocabulary as JSON and return the written path in a 1-tuple."""
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocabulary, f, ensure_ascii=False)

        return (vocab_file,)

    def _tokenize_gene_expression(self, x: np.ndarray) -> np.ndarray:
        """Tokenize gene expression matrix.

        Args:
            x: Gene expression matrix (cells x genes)

        Returns:
            Tokenized matrix as int32, ``self.max_length`` tokens per cell
        """
        # Handle sparse input
        if issparse(x):
            x = x.toarray()

        # Normalize and scale
        x = np.nan_to_num(x)
        x = sf_normalize(x)
        if self.median_counts_per_gene is not None:
            median_counts = self.median_counts_per_gene.copy()
            # Genes with a zero median get a divisor of 1.
            median_counts += median_counts == 0
            x = x / median_counts.reshape((1, -1))

        # Convert to tokens
        tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)

        return tokens.astype(np.int32)

    @staticmethod
    def _per_cell_metadata(value, n_cells: int):
        """Return ``value`` as a 1-D object array with one label per cell.

        A single label (e.g. a plain string) is repeated for every cell so it
        is never iterated character by character. Sequences are flattened
        and returned as given; ``None`` is passed through.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return np.full(n_cells, value, dtype=object)
        values = np.asarray(value, dtype=object)
        if values.ndim == 0:
            return np.full(n_cells, values.item(), dtype=object)
        return values.reshape(-1)

    def __call__(
        self,
        adata: Optional[ad.AnnData] = None,
        gene_expression: Optional[Union[np.ndarray, List[float]]] = None,
        modality: Optional[str] = None,
        species: Optional[str] = None,
        technology: Optional[str] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Convert inputs to model inputs.

        Args:
            adata: AnnData object. Missing metadata is read from the
                ``modality``, ``specie`` and ``assay`` columns of ``adata.obs``
                when those exist.
            gene_expression: Gene expression matrix if not using AnnData; a
                1-D vector is treated as a single cell.
            modality: Modality type, one label for all cells or one per cell
            species: Species type, one label for all cells or one per cell
            technology: Technology/assay type, one label for all cells or one
                per cell

        Returns:
            Dictionary with ``input_ids`` (long tensor) and ``attention_mask``
            (True wherever the token id differs from [PAD]).

        Raises:
            ValueError: If neither ``adata`` nor ``gene_expression`` is given.
        """
        if adata is not None:
            # Get expression matrix
            if issparse(adata.X):
                x = adata.X.toarray()
            else:
                x = adata.X

            # Get metadata for each cell if not provided
            if modality is None and 'modality' in adata.obs:
                modality = adata.obs['modality'].values
            if species is None and 'specie' in adata.obs:
                species = adata.obs['specie'].values
            if technology is None and 'assay' in adata.obs:
                technology = adata.obs['assay'].values

        elif gene_expression is not None:
            if issparse(gene_expression):
                x = gene_expression.toarray()
            else:
                x = np.array(gene_expression)
            if len(x.shape) == 1:
                x = x.reshape(1, -1)
        else:
            raise ValueError("Either adata or gene_expression must be provided")

        # Tokenize gene expression
        token_ids = self._tokenize_gene_expression(x)
        n_cells = token_ids.shape[0]
        pad_id = self._vocabulary["[PAD]"]

        # Build one column of special tokens per metadata kind that is
        # present, in the fixed order modality, species, technology.
        # Unknown labels fall back to the [PAD] id.
        special_columns = []
        for values, mapping in (
            (modality, self.modality_dict),
            (species, self.species_dict),
            (technology, self.technology_dict),
        ):
            values = self._per_cell_metadata(values, n_cells)
            if values is None:
                continue
            column = np.zeros(n_cells, dtype=np.int32)
            column[:] = [mapping.get(v, pad_id) for v in values]
            special_columns.append(column)

        if special_columns:
            special_tokens = np.stack(special_columns, axis=1)
            # Drop trailing gene tokens so the total length stays max_length.
            token_ids = np.concatenate([special_tokens, token_ids[:, :(self.max_length - special_tokens.shape[1])]], axis=1)

        # Create attention mask
        attention_mask = (token_ids != pad_id)

        return {
            "input_ids": torch.tensor(token_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask)
        }

    def get_vocab_size(self) -> int:
        """Get vocabulary size.

        With ``gene_names`` set this is ``len(gene_names) + aux_tokens``;
        otherwise it is one more than the largest id found in the special
        token tables or the vocabulary (assumes vocabulary values are
        integer ids).
        """
        if self.gene_names is not None:
            return len(self.gene_names) + self.aux_tokens
        return max(
            max(self.modality_dict.values()),
            max(self.species_dict.values()),
            max(self.technology_dict.values()),
            max(self._vocabulary.values())
        ) + 1

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a sequence of tokens to a string. Not used for gene expression."""
        raise NotImplementedError("This tokenizer only works with gene expression data")

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        """Build model inputs from a sequence by adding special tokens."""
        # For gene expression data, special tokens are handled in __call__
        return token_ids_0

    def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False) -> List[int]:
        """Get list where entries are [1] if a token is [special] else [0]."""
        # Consider tokens < aux_tokens as special
        return [1 if token_id < self.aux_tokens else 0 for token_id in token_ids_0]
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|