license: cc-by-nc-nd-4.0

Overview

This is the Pythia-160m model developed by EleutherAI, fine-tuned using Cell2Sentence on full scRNA-seq cells. Cell2Sentence is a novel method for adapting large language models to single-cell transcriptomics: we transform single-cell RNA sequencing data into sequences of gene names ordered by expression level, termed "cell sentences". For more details, we refer to the paper linked below. This model was trained on the immune tissue dataset from Domínguez et al. using 8 A100 40GB GPUs for approximately 20 hours on the following tasks:

  1. conditional cell generation
  2. unconditional cell generation
  3. cell type prediction
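As a rough illustration of the cell sentence idea, the sketch below ranks one cell's genes by expression and keeps only the expressed ones. It is not the Cell2Sentence library's implementation: the function name expression_to_cell_sentence is hypothetical, ties are assumed to keep their original column order, and any normalization or preprocessing applied before ranking in the paper is omitted.

```python
from typing import Sequence


def expression_to_cell_sentence(expression: Sequence[float], gene_names: Sequence[str]) -> str:
    """Hypothetical helper: turn one cell's expression vector into a cell sentence.

    Genes with zero expression are dropped; the rest are ordered from highest
    to lowest expression and joined with spaces.
    """
    if len(expression) != len(gene_names):
        raise ValueError("expression and gene_names must have the same length")

    # Keep (value, gene) pairs only for genes with nonzero expression
    expressed = [(value, gene) for value, gene in zip(expression, gene_names) if value != 0]

    # Sort by descending expression; Python's sort is stable, so ties keep column order
    expressed.sort(key=lambda pair: pair[0], reverse=True)

    return " ".join(gene for _, gene in expressed)


# Toy example with made-up values
print(expression_to_cell_sentence([0.0, 5.0, 2.0, 7.0], ["CD3E", "MS4A1", "GNLY", "CD74"]))
# CD74 MS4A1 GNLY
```

Only the ordering of gene names is kept; the expression values themselves are not part of the sentence.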

Cell2Sentence Links:

GitHub: https://github.com/vandijklab/cell2sentence-ft
Paper: https://www.biorxiv.org/content/10.1101/2023.09.11.557287v3

Pythia Links:

GitHub: https://github.com/EleutherAI/pythia
Paper: https://arxiv.org/abs/2304.01373
Hugging Face: https://huggingface.co/EleutherAI/pythia-160m

Sample Code

We provide an example of how to use the model to conditionally generate a cell, together with a post-processing function that removes invalid genes and averages the ranks of duplicated genes. To generate full cells, increase the max_length generation parameter to 9200; for full cell generation we recommend an A100 GPU for inference speed and memory capacity. Prompts for unconditional cell generation and cell type prediction are included as well, but we do not include an example cell sentence for formatting the cell type prediction prompt. We refer to the paper and GitHub repository for instructions on how to transform expression vectors into cell sentences.
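The rank-averaging step of the post-processing function can be stated compactly. After uppercasing and vocabulary filtering, a gene that occurs k > 1 times, at 0-indexed positions p_1, ..., p_k of the current gene list, has all of its occurrences removed and a single copy reinserted at

$$
\hat{p} = \left\lfloor \frac{1}{k} \sum_{i=1}^{k} p_i \right\rfloor
$$

(Python's int() truncates, which equals the floor for these nonnegative positions). Positions are recomputed on the list as it is being modified, so duplicated genes are handled one at a time in order of first appearance. For example, with the made-up output "CD74 B2M CD74 XYZ MALAT1", assuming XYZ is not in the vocabulary and the other genes are, filtering gives [CD74, B2M, CD74, MALAT1]; CD74 occurs at positions 0 and 2, so the reinsertion position is 1 and the result is [B2M, CD74, MALAT1].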

import json
import re
from collections import Counter
from typing import List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def post_process_generated_cell_sentences(
    cell_sentence: str, 
    gene_dictionary: List
):
    """
    Post-processing function for generated cell sentences. 
    Invalid genes are removed and ranks of duplicated genes are averaged.

    Arguments:
        cell_sentence:              generated cell sentence string
        gene_dictionary:            list of gene vocabulary (all uppercase)

    Returns:
        post_processed_sentence:    list of gene names after post-processing
    """
    generated_gene_names = cell_sentence.split(" ")
    generated_gene_names = [generated_gene.upper() for generated_gene in generated_gene_names]

    #--- Remove nonsense genes ---#
    generated_gene_names = [gene_name for gene_name in generated_gene_names if gene_name in gene_dictionary]

    #--- Average ranks ---#
    gene_name_to_occurrences = Counter(generated_gene_names)  # get mapping of gene name --> number of occurrences
    post_processed_sentence = generated_gene_names.copy()  # copy of generated gene list

    for gene_name in gene_name_to_occurrences:
        if gene_name_to_occurrences[gene_name] > 1:
            # Find positions of all occurrences of duplicated generated gene in list
            # Note: using post_processed_sentence here; since duplicates are being removed, list will be
            #   getting shorter. Getting indices in original list will no longer be accurate positions
            occurrence_positions = [idx for idx, elem in enumerate(post_processed_sentence) if elem == gene_name]
            average_position = int(sum(occurrence_positions) / len(occurrence_positions))

            # Remove occurrences
            post_processed_sentence = [elem for elem in post_processed_sentence if elem != gene_name]

            # Reinsert gene_name at average position
            post_processed_sentence.insert(average_position, gene_name)
    
    return post_processed_sentence

vocab_path = "pbmc_vocab.json"  # gene vocabulary: JSON list of uppercase gene names

with open(vocab_path, "r") as f:
    gene_dictionary = json.load(f)

model_name = "vandijklab/pythia-160m-c2s"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
).to(torch.device("cuda"))
tokenizer = AutoTokenizer.from_pretrained(model_name)

cell_type = "T Cell"
ccg = f"Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."

# Prompts for the other tasks.
# Unconditional cell generation:
# ucg = "Display a cell's genes by expression level, in descending order."
# Cell type prediction; replace CELL_SENTENCE with a space-separated cell sentence:
# cellsentence = "CELL_SENTENCE"
# ctp = (
#     "Identify the cell type most likely associated with these highly expressed genes listed in descending order. "
#     + cellsentence
#     + "Name the cell type connected to these genes, ranked from highest to lowest expression."
# )

tokens = tokenizer(ccg, return_tensors='pt')
input_ids = tokens['input_ids'].to(torch.device("cuda"))
attention_mask = tokens['attention_mask'].to(torch.device("cuda"))

with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=1024,  # increase to 9200 to generate full cells
        top_k=50,
        top_p=0.95,
    )

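output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Drop the prompt: split on "?", "." or ":" and keep the text after the first delimiter
cell_sentence = "".join(re.split(r"\?|\.|:", output_text)[1:]).strip()
# Remove invalid genes and merge duplicates; returns a list of gene names
processed_genes = post_process_generated_cell_sentences(cell_sentence, gene_dictionary)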
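The prompt templates from the commented lines in the sample code can be gathered into one helper so that switching tasks only changes an argument. This is a hypothetical convenience function, not part of the Cell2Sentence code base. The cell type prediction prompt is concatenated exactly as in the commented template, with no separator after the cell sentence; whether the training data used a separator there is not stated in this card, so check the GitHub repository before relying on it.

```python
from typing import Optional

# Templates copied from the commented prompts in the sample code
CCG_TEMPLATE = "Enumerate the genes in a {cell_type} cell with nonzero expression, from highest to lowest."
UCG_PROMPT = "Display a cell's genes by expression level, in descending order."
CTP_PREFIX = "Identify the cell type most likely associated with these highly expressed genes listed in descending order. "
CTP_SUFFIX = "Name the cell type connected to these genes, ranked from highest to lowest expression."


def build_prompt(task: str, cell_type: Optional[str] = None, cell_sentence: Optional[str] = None) -> str:
    """Hypothetical helper returning the prompt for 'ccg', 'ucg' or 'ctp'."""
    if task == "ccg":  # conditional cell generation
        if cell_type is None:
            raise ValueError("conditional generation needs a cell_type")
        return CCG_TEMPLATE.format(cell_type=cell_type)
    if task == "ucg":  # unconditional cell generation
        return UCG_PROMPT
    if task == "ctp":  # cell type prediction
        if cell_sentence is None:
            raise ValueError("cell type prediction needs a cell_sentence")
        # Concatenated exactly as in the commented template (no separator added)
        return CTP_PREFIX + cell_sentence + CTP_SUFFIX
    raise ValueError(f"unknown task: {task!r}")


print(build_prompt("ccg", cell_type="T Cell"))
# Enumerate the genes in a T Cell cell with nonzero expression, from highest to lowest.
```

For the cell type prediction prompt, which itself contains several periods, the regex split in the sample code would keep part of the prompt in its output. Decoding only the newly generated tokens, for example tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True), is a more robust way to isolate the model's answer.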