Plant foundation DNA large language models

The plant DNA large language models (LLMs) comprise a series of foundation models based on different model architectures, pre-trained on various plant reference genomes.
All the models have a comparable size, between 90 MB and 150 MB; a BPE tokenizer is used for tokenization, with a vocabulary of 8000 tokens.
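To illustrate what BPE tokenization does to DNA, here is a minimal, generic sketch of byte-pair-encoding training on nucleotide strings: start from single bases and repeatedly merge the most frequent adjacent pair. It is not the model's actual tokenizer, which ships with the model and has 8000 tokens; the toy corpus, the number of merges and the tie-breaking rule (first pair seen wins) are illustrative assumptions.

```python
from collections import Counter


def get_pair_counts(corpus):
    """Count adjacent symbol pairs across all tokenized sequences."""
    counts = Counter()
    for symbols in corpus:
        for pair in zip(symbols, symbols[1:]):
            counts[pair] += 1
    return counts


def merge_pair(symbols, pair):
    """Replace every non-overlapping occurrence of `pair` with one merged symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged


def train_bpe(sequences, num_merges):
    """Learn up to `num_merges` merge rules, starting from single nucleotides."""
    corpus = [list(seq) for seq in sequences]
    merges = []
    for _ in range(num_merges):
        counts = get_pair_counts(corpus)
        if not counts:
            break
        best = max(counts, key=counts.get)  # most frequent pair (ties: first seen)
        merges.append(best)
        corpus = [merge_pair(symbols, best) for symbols in corpus]
    return merges


def tokenize(seq, merges):
    """Apply the learned merges in the order they were learned."""
    symbols = list(seq)
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return symbols


toy_corpus = ["ATATATGC", "ATATGCGC", "GCGCATAT"]  # hypothetical toy data
merges = train_bpe(toy_corpus, num_merges=3)
print(tokenize("ATATGC", merges))  # ['ATAT', 'GC']
```

Frequent motifs end up as single tokens, so a fixed-size vocabulary covers longer stretches of sequence than one-token-per-base encoding would.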

Developed by: zhangtaolab

Model Sources

Architecture

The model is based on the Google Gemma architecture, with a modified configuration and a tokenizer tailored to DNA sequences.

How to use

Install the runtime libraries first:

pip install transformers torch

Here is a simple inference example:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = 'plant-dnagemma'
# load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(f'zhangtaolab/{model_name}', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(f'zhangtaolab/{model_name}', trust_remote_code=True)

# example sequences and tokenization
sequences = ['ATATACGGCCGNC', 'GGGTATCGCTTCCGAC']
tokens = tokenizer(sequences, padding="longest")['input_ids']
print(f"Tokenized sequences: {tokenizer.batch_decode(tokens)}")

# inference (uses the GPU if one is available)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()
inputs = tokenizer(sequences, truncation=True, padding='max_length', max_length=512,
                   return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():  # no gradients are needed for inference
    outs = model(
        **inputs,
        output_hidden_states=True
    )

# get the final-layer embeddings and prediction logits
# (tensors must be moved to the CPU before converting them to NumPy arrays)
embeddings = outs['hidden_states'][-1].cpu().numpy()
logits = outs['logits'].cpu().numpy()
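The final-layer `embeddings` have shape (batch, sequence_length, hidden_size) and, because of `padding='max_length'`, include padding positions. One common way to get a single vector per sequence is to average only over real tokens using the attention mask. This is an assumption, not necessarily what the authors do downstream. The sketch below uses random NumPy arrays as stand-ins. With the inference code, you would pass `embeddings` and `inputs['attention_mask'].cpu().numpy()`.

```python
import numpy as np


def mean_pool(embeddings, attention_mask):
    """Average token embeddings over positions where attention_mask == 1.

    embeddings: float array of shape (batch, seq_len, hidden)
    attention_mask: 0/1 array of shape (batch, seq_len)
    returns: float array of shape (batch, hidden)
    """
    mask = attention_mask[..., None].astype(embeddings.dtype)  # (batch, seq_len, 1)
    summed = (embeddings * mask).sum(axis=1)                   # sum over real tokens only
    counts = np.clip(mask.sum(axis=1), 1e-9, None)             # avoid division by zero
    return summed / counts


# Stand-in data: 2 sequences, 5 positions, hidden size 4 (hypothetical sizes).
rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(2, 5, 4)).astype(np.float32)
fake_mask = np.array([[1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1]])
pooled = mean_pool(fake_embeddings, fake_mask)
print(pooled.shape)  # (2, 4)
```

Masking makes the result independent of which side the tokenizer pads on and of how much padding was added.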

Training data

We pre-train the model with a causal language modeling (CausalLM) objective; tokenized sequences have a maximum length of 512 tokens.
The detailed training procedure can be found in our manuscript.
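In causal language modeling, each position is trained to predict the next token. The logits at position t are scored against the token at position t+1, and padding targets are ignored. The NumPy sketch below computes that shifted cross-entropy on random stand-in data. It illustrates the general objective, not the authors' training code; the array sizes and the tiny vocabulary are assumptions.

```python
import numpy as np


def causal_lm_loss(logits, input_ids, attention_mask):
    """Mean next-token cross-entropy.

    logits: (batch, seq_len, vocab) unnormalized scores
    input_ids: (batch, seq_len) integer token ids
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    """
    # Position t predicts token t+1: drop the last logit and the first target.
    shift_logits = logits[:, :-1, :]
    targets = input_ids[:, 1:]
    target_mask = attention_mask[:, 1:].astype(bool)

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Log-probability assigned to each true next token.
    token_log_probs = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]

    # Average negative log-likelihood over non-padding targets only.
    return -token_log_probs[target_mask].mean()


rng = np.random.default_rng(0)
vocab_size = 8  # tiny stand-in; the real vocabulary has 8000 tokens
fake_logits = rng.normal(size=(2, 6, vocab_size))
fake_ids = rng.integers(0, vocab_size, size=(2, 6))
fake_mask = np.array([[1, 1, 1, 1, 1, 1],
                      [1, 1, 1, 1, 0, 0]])
loss = causal_lm_loss(fake_logits, fake_ids, fake_mask)
print(f"loss = {loss:.3f}")
```

With uniform logits the loss equals log(vocab_size), a handy sanity check. A confident, correct model drives it toward zero.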

Hardware

The model was pre-trained on an NVIDIA RTX 4090 GPU (24 GB).

Weights format: Safetensors
Model size: 152M params
Tensor type: F32