Tokenization with model alignment
tokenization_nicheformer.py CHANGED (+22, -0)
@@ -8,6 +8,7 @@ from scipy.sparse import issparse
 import numba
 import os
 import json
+from huggingface_hub import hf_hub_download
 
 # Token IDs must match exactly with the original implementation
 PAD_TOKEN = 0
@@ -88,6 +89,19 @@ class NicheformerTokenizer(PreTrainedTokenizer):
     species_dict = SPECIES_DICT
     technology_dict = TECHNOLOGY_DICT
 
+    def _load_reference_model(self):
+        """Load reference model for gene alignment."""
+        try:
+            # Get the model name or path from the tokenizer
+            repo_id = self.name_or_path if hasattr(self, "name_or_path") else "aletlvl/Nicheformer"
+
+            # Download the reference model if not already cached
+            model_path = hf_hub_download(repo_id=repo_id, filename="model.h5ad")
+            return ad.read_h5ad(model_path)
+        except Exception as e:
+            print(f"Warning: Could not load reference model: {e}")
+            return None
+
     def __init__(
         self,
         vocab_file=None,
@@ -138,6 +152,7 @@ class NicheformerTokenizer(PreTrainedTokenizer):
         self.aux_tokens = aux_tokens
         self.median_counts_per_gene = median_counts_per_gene
         self.gene_names = gene_names
+        self.name_or_path = kwargs.get('name_or_path', 'aletlvl/Nicheformer')
 
         # Set up special token mappings
         self._pad_token = "[PAD]"
@@ -243,6 +258,13 @@ class NicheformerTokenizer(PreTrainedTokenizer):
             Dictionary with model inputs
         """
         if adata is not None:
+            # Align with reference model if needed
+            reference_model = self._load_reference_model()
+            if reference_model is not None:
+                # Concatenate and then remove the reference
+                adata = ad.concat([reference_model, adata], join='outer', axis=0)
+                adata = adata[1:]
+
             # Get expression matrix
             if issparse(adata.X):
                 x = adata.X.toarray()