Isoformer / README.md
isoformer-anonymous's picture
Update README.md
5e9fe45 verified
|
raw
history blame
1.42 kB

The snippet below shows how to retrieve embeddings and gene-expression predictions for a given DNA, RNA, and protein sequence.

from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np
import torch

# Load the tokenizer and the model. trust_remote_code=True is needed because the
# repository provides its own tokenizer/model code; it executes code downloaded
# from the Hub, so only enable it for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained("isoformer-anonymous/Isoformer", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("isoformer-anonymous/Isoformer", trust_remote_code=True)

# Example inputs: a batch containing one protein, one RNA and one DNA sequence.
protein_sequences = ["RSRSRSRSRSRSRSRSRSRSRL" * 9]
rna_sequences = ["ATTCCGGTTTTCA" * 9]
# A random DNA sequence of 196,608 characters drawn from "ATCGN", generated with
# a fixed seed so the example is reproducible.
sequence_length = 196_608
rng = np.random.default_rng(seed=0)
dna_sequences = ["".join(rng.choice(list("ATCGN"), size=(sequence_length,)))]

# One tokenizer call encodes all three modalities; the result is indexed below
# as [0] -> DNA, [1] -> RNA, [2] -> protein.
torch_tokens = tokenizer(
    dna_input=dna_sequences, rna_input=rna_sequences, protein_input=protein_sequences
)
dna_torch_tokens = torch.tensor(torch_tokens[0]["input_ids"])
rna_torch_tokens = torch.tensor(torch_tokens[1]["input_ids"])
protein_torch_tokens = torch.tensor(torch_tokens[2]["input_ids"])

# Inference only: no_grad() stops autograd from recording a graph for the very
# long DNA input. Calling the module itself (rather than .forward) ensures any
# registered hooks run.
# NOTE(review): the attention masks treat token id 1 as padding for RNA and
# protein — confirm this matches the tokenizer's pad token id.
with torch.no_grad():
    torch_output = model(
        tensor_dna=dna_torch_tokens,
        tensor_rna=rna_torch_tokens,
        tensor_protein=protein_torch_tokens,
        attention_mask_rna=rna_torch_tokens != 1,
        attention_mask_protein=protein_torch_tokens != 1,
    )

print(f"Gene expression predictions: {torch_output['gene_expression_predictions']}")
print(f"Final DNA embedding: {torch_output['final_dna_embeddings']}")