---
license: cc-by-nc-nd-4.0
---
# Overview

This is the Pythia-160m model developed by EleutherAI, fine-tuned using Cell2Sentence on *full* scRNA-seq cells.
Cell2Sentence is a novel method for adapting large language models to single-cell transcriptomics. 
We transform single-cell RNA sequencing data into sequences of gene names ordered by expression level, termed "cell sentences". 
For more details, we refer to the paper linked below. 
This model was trained on the immune tissue dataset from [Domínguez et al.](https://www.science.org/doi/10.1126/science.abl5197) 
using 8 A100 40GB GPUs for approximately 20 hours on the following tasks:
1. conditional cell generation
2. unconditional cell generation
3. cell type prediction
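As a minimal sketch of the cell-sentence transformation described above (the gene names and counts below are illustrative toy values, not taken from the dataset):

```python
# Toy expression vector for one cell (gene name -> raw count);
# these genes and values are illustrative only.
expression = {"MALAT1": 50, "CD3D": 12, "MS4A1": 0, "IL7R": 7, "NKG7": 3}

# Keep nonzero genes, order them by descending expression,
# and join the names into a space-separated "cell sentence".
ranked = sorted(
    (gene for gene, count in expression.items() if count > 0),
    key=lambda gene: -expression[gene],
)
cell_sentence = " ".join(ranked)
print(cell_sentence)  # MALAT1 CD3D IL7R NKG7
```

See the paper and GitHub repository for the full preprocessing pipeline, which operates on normalized expression values.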

## Cell2Sentence Links:  
GitHub: <https://github.com/vandijklab/cell2sentence-ft>  
Paper: <https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3>

## Pythia Links:  
GitHub: <https://github.com/EleutherAI/pythia>  
Paper: <https://arxiv.org/abs/2304.01373>  
Hugging Face: <https://huggingface.co/EleutherAI/pythia-160m>

# Sample Code

We provide an example of how to use the model to conditionally generate a cell, together with a post-processing function that removes duplicate and invalid genes.
In order to generate full cells, the `max_length` generation parameter should be changed to 9200. 
However, we recommend using an A100 GPU for inference speed and memory capacity if full cell generation is required. 
Unconditional cell generation and cell type prediction prompts are included as well, but we do not include an example cell sentence to format the prompt. 
We refer to the paper and GitHub repository for instructions on how to transform expression vectors into cell sentences.

```python
import json
import re
from collections import Counter
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def post_process_generated_cell_sentences(
    cell_sentence: str, 
    gene_dictionary: List
):
    """
    Post-processing function for generated cell sentences. 
    Invalid genes are removed and ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence:              generated cell sentence string
        gene_dictionary:            list of gene vocabulary (all uppercase)

    Returns:
        post_processed_sentence:    list of gene names after post-processing steps
    """
    generated_gene_names = cell_sentence.split(" ")
    generated_gene_names = [generated_gene.upper() for generated_gene in generated_gene_names]

    #--- Remove nonsense genes ---#
    generated_gene_names = [gene_name for gene_name in generated_gene_names if gene_name in gene_dictionary]

    #--- Average ranks ---#
    gene_name_to_occurrences = Counter(generated_gene_names)  # get mapping of gene name --> number of occurrences
    post_processed_sentence = generated_gene_names.copy()  # copy of generated gene list

    for gene_name in gene_name_to_occurrences:
        if gene_name_to_occurrences[gene_name] > 1:
            # Find positions of all occurrences of duplicated generated gene in list
            # Note: using post_processed_sentence here; since duplicates are being removed, list will be
            #   getting shorter. Getting indices in original list will no longer be accurate positions
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))

            # Remove occurrences
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]

            # Reinsert gene_name at average position
            post_processed_sentence.insert(average_position, gene_name)
    
    return post_processed_sentence

genes_path = "pbmc_vocab.json"

with open(genes_path, "r") as f:
    gene_dictionary = json.load(f)

model_name = "vandijklab/pythia-160m-c2s"

model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16, 
        attn_implementation="flash_attention_2"
        ).to(torch.device("cuda"))
tokenizer = AutoTokenizer.from_pretrained(model_name)

cell_type = "T Cell"
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."

# Prompts for other forms of generation.
# ucg = "Display a cell's genes by expression level, in descending order."
# cellsentence = "CELL_SENTENCE"
# ctp = ("Identify the cell type most likely associated with these highly expressed genes listed in descending order. "
#        + cellsentence +
#        " Name the cell type connected to these genes, ranked from highest to lowest expression.")

tokens = tokenizer(ccg, return_tensors='pt')
input_ids = tokens['input_ids'].to(torch.device("cuda"))
attention_mask = tokens['attention_mask'].to(torch.device("cuda"))

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,
        top_k=50,
        top_p=0.95,
    )

output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)
```
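To illustrate the rank-averaging step performed by the post-processing function above on a toy input (the vocabulary and sentence here are illustrative, and the logic below is a compact restatement of the snippet's deduplication loop):

```python
from collections import Counter

gene_dictionary = ["CD3D", "CD3E", "IL7R", "NKG7"]  # toy vocabulary
generated = "cd3d CD3E FOO cd3d IL7R".split(" ")    # toy generated sentence

# Uppercase, then drop genes not in the vocabulary (removes FOO).
genes = [g.upper() for g in generated if g.upper() in gene_dictionary]

# Collapse each duplicated gene to a single occurrence at its average rank.
for gene, count in Counter(genes).items():
    if count > 1:
        positions = [i for i, g in enumerate(genes) if g == gene]
        avg = int(sum(positions) / len(positions))
        genes = [g for g in genes if g != gene]
        genes.insert(avg, gene)

print(genes)  # ['CD3E', 'CD3D', 'IL7R']
```

The duplicated `CD3D` (ranks 0 and 2 after filtering) is reinserted once at its average rank of 1.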