ctheodoris commited on
Commit
ea428cb
1 Parent(s): eb2a04b

move dicts to init

Browse files
geneformer/__init__.py CHANGED
@@ -1,4 +1,10 @@
1
  # ruff: noqa: F401
 
 
 
 
 
 
2
  from . import (
3
  collator_for_classification,
4
  emb_extractor,
@@ -18,4 +24,4 @@ from .pretrainer import GeneformerPretrainer
18
  from .tokenizer import TranscriptomeTokenizer
19
 
20
  from . import classifier # noqa # isort:skip
21
- from .classifier import Classifier # noqa # isort:skip
 
1
  # ruff: noqa: F401
2
+ from pathlib import Path
3
+
4
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
5
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
6
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
7
+
8
  from . import (
9
  collator_for_classification,
10
  emb_extractor,
 
24
  from .tokenizer import TranscriptomeTokenizer
25
 
26
  from . import classifier # noqa # isort:skip
27
+ from .classifier import Classifier # noqa # isort:skip
geneformer/classifier.py CHANGED
@@ -61,7 +61,7 @@ from . import DataCollatorForCellClassification, DataCollatorForGeneClassificati
61
  from . import classifier_utils as cu
62
  from . import evaluation_utils as eu
63
  from . import perturber_utils as pu
64
- from .tokenizer import TOKEN_DICTIONARY_FILE
65
 
66
  sns.set()
67
 
 
61
  from . import classifier_utils as cu
62
  from . import evaluation_utils as eu
63
  from . import perturber_utils as pu
64
+ from . import TOKEN_DICTIONARY_FILE
65
 
66
  sns.set()
67
 
geneformer/collator_for_classification.py CHANGED
@@ -4,6 +4,7 @@ Geneformer collator for gene and cell classification.
4
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
  import numpy as np
 
7
  import torch
8
  import warnings
9
  from enum import Enum
@@ -17,7 +18,11 @@ from transformers import (
17
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
  from transformers.utils.generic import _is_tensorflow, _is_torch
19
 
20
- from .pretrainer import token_dictionary
 
 
 
 
21
 
22
  EncodedInput = List[int]
23
  logger = logging.get_logger(__name__)
 
4
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
  import numpy as np
7
+ import pickle
8
  import torch
9
  import warnings
10
  from enum import Enum
 
18
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
19
  from transformers.utils.generic import _is_tensorflow, _is_torch
20
 
21
+ from . import TOKEN_DICTIONARY_FILE
22
+
23
+ # load token dictionary (Ensembl IDs:token)
24
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
25
+ token_dictionary = pickle.load(f)
26
 
27
  EncodedInput = List[int]
28
  logger = logging.get_logger(__name__)
geneformer/emb_extractor.py CHANGED
@@ -25,7 +25,7 @@ from tdigest import TDigest
25
  from tqdm.auto import trange
26
 
27
  from . import perturber_utils as pu
28
- from .tokenizer import TOKEN_DICTIONARY_FILE
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
25
  from tqdm.auto import trange
26
 
27
  from . import perturber_utils as pu
28
+ from . import TOKEN_DICTIONARY_FILE
29
 
30
  logger = logging.getLogger(__name__)
31
 
geneformer/evaluation_utils.py CHANGED
@@ -21,7 +21,7 @@ from sklearn.metrics import (
21
  from tqdm.auto import trange
22
 
23
  from .emb_extractor import make_colorbar
24
- from .tokenizer import TOKEN_DICTIONARY_FILE
25
 
26
  logger = logging.getLogger(__name__)
27
 
 
21
  from tqdm.auto import trange
22
 
23
  from .emb_extractor import make_colorbar
24
+ from . import TOKEN_DICTIONARY_FILE
25
 
26
  logger = logging.getLogger(__name__)
27
 
geneformer/in_silico_perturber.py CHANGED
@@ -38,21 +38,17 @@ import logging
38
  import os
39
  import pickle
40
  from collections import defaultdict
41
- from typing import List
42
  from multiprocess import set_start_method
43
 
44
- import seaborn as sns
45
  import torch
46
- from datasets import Dataset
47
  from tqdm.auto import trange
48
 
49
  from . import perturber_utils as pu
50
  from .emb_extractor import get_embs
51
- from .perturber_utils import TOKEN_DICTIONARY_FILE
52
-
53
-
54
- sns.set()
55
 
 
56
 
57
  logger = logging.getLogger(__name__)
58
 
 
38
  import os
39
  import pickle
40
  from collections import defaultdict
 
41
  from multiprocess import set_start_method
42
 
 
43
  import torch
44
+ from datasets import Dataset, disable_progress_bars
45
  from tqdm.auto import trange
46
 
47
  from . import perturber_utils as pu
48
  from .emb_extractor import get_embs
49
+ from . import TOKEN_DICTIONARY_FILE
 
 
 
50
 
51
+ disable_progress_bars()
52
 
53
  logger = logging.getLogger(__name__)
54
 
geneformer/in_silico_perturber_stats.py CHANGED
@@ -38,9 +38,7 @@ from sklearn.mixture import GaussianMixture
38
  from tqdm.auto import tqdm, trange
39
 
40
  from .perturber_utils import flatten_list, validate_cell_states_to_model
41
- from .tokenizer import TOKEN_DICTIONARY_FILE
42
-
43
- GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
44
 
45
  logger = logging.getLogger(__name__)
46
 
@@ -673,7 +671,7 @@ class InSilicoPerturberStats:
673
  cell_states_to_model=None,
674
  pickle_suffix="_raw.pickle",
675
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
676
- gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
677
  ):
678
  """
679
  Initialize in silico perturber stats generator.
 
38
  from tqdm.auto import tqdm, trange
39
 
40
  from .perturber_utils import flatten_list, validate_cell_states_to_model
41
+ from . import TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
 
 
42
 
43
  logger = logging.getLogger(__name__)
44
 
 
671
  cell_states_to_model=None,
672
  pickle_suffix="_raw.pickle",
673
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
674
+ gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
675
  ):
676
  """
677
  Initialize in silico perturber stats generator.
geneformer/perturber_utils.py CHANGED
@@ -18,13 +18,9 @@ from transformers import (
18
  BertForTokenClassification,
19
  )
20
 
21
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
22
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
23
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
24
 
25
 
26
- sns.set()
27
-
28
  logger = logging.getLogger(__name__)
29
 
30
 
 
18
  BertForTokenClassification,
19
  )
20
 
21
+ from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
 
 
22
 
23
 
 
 
24
  logger = logging.getLogger(__name__)
25
 
26
 
geneformer/pretrainer.py CHANGED
@@ -32,7 +32,7 @@ from transformers.training_args import ParallelMode
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
35
- from .tokenizer import TOKEN_DICTIONARY_FILE
36
 
37
  logger = logging.get_logger(__name__)
38
  EncodedInput = List[int]
 
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
35
+ from . import TOKEN_DICTIONARY_FILE
36
 
37
  logger = logging.get_logger(__name__)
38
  EncodedInput = List[int]
geneformer/tokenizer.py CHANGED
@@ -52,7 +52,7 @@ import loompy as lp # noqa
52
 
53
  logger = logging.getLogger(__name__)
54
 
55
- from .perturber_utils import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
56
 
57
 
58
  def rank_genes(gene_vector, gene_tokens):
 
52
 
53
  logger = logging.getLogger(__name__)
54
 
55
+ from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
56
 
57
 
58
  def rank_genes(gene_vector, gene_tokens):