|
--- |
|
tags: |
|
- dnabert |
|
- bacteria |
|
- kmer |
|
- classification |
|
- sequence-modeling |
|
- DNA |
|
library_name: transformers |
|
--- |
|
|
|
# BacteriaCDS-DNABERT-K6-89M |
|
|
|
This model, `BacteriaCDS-DNABERT-K6-89M`, is a **DNA sequence classifier** based on **DNABERT**, fine-tuned for **coding sequence (CDS) classification** in bacterial genomes. It operates on **6-mer tokenized sequences** and has **89M trainable parameters**.
|
|
|
## Model Details |
|
- **Base Model:** DNABERT |
|
- **Task:** Bacterial CDS Classification |
|
- **K-mer Size:** 6 |
|
- **Input Sequence:** Open reading frame (the last 510 nucleotides of the sequence)
|
- **Number of Trainable Parameters:** 89M |
|
- **Max Sequence Length:** 512 |
|
- **Precision Used:** AMP (Automatic Mixed Precision) |
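
Per the input specification above, sequences longer than 510 nucleotides should be trimmed to their final 510 nt before tokenization. A minimal sketch (the helper name is ours, not part of the model's API):

```python
def trim_orf(sequence: str, max_nt: int = 510) -> str:
    # Keep only the last `max_nt` nucleotides of the ORF,
    # matching the model's stated input window.
    return sequence[-max_nt:]

trimmed = trim_orf("A" * 600)
print(len(trimmed))
```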
|
|
|
--- |
|
|
|
### **Install Dependencies** |
|
Ensure you have `transformers` and `torch` installed: |
|
```bash |
|
pip install torch transformers |
|
``` |
|
|
|
### **Load Model & Tokenizer** |
|
```python |
|
import torch |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
|
# Load Model |
|
model_checkpoint = "Genereux-akotenou/BacteriaCDS-DNABERT-K6-89M" |
|
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) |
|
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) |
|
``` |
|
|
|
### **Inference Example** |
|
This model works with 6-mer tokenized sequences. You need to convert raw DNA sequences into k-mer format: |
|
```python |
|
def generate_kmer(sequence: str, k: int, overlap: int = 1):

    # Convert a raw DNA sequence into space-separated k-mers.
    # `overlap` is the step between consecutive k-mer start positions.

    return " ".join(sequence[j:j+k] for j in range(0, len(sequence) - k + 1, overlap))
|
|
|
sequence = "ATGAGAACCAGCCGGAGACCTCCTGCTCGTACATGAAAGGCTCGAGCAGCCGGGCGAGGGCGGTAG" |
|
seq_kmer = generate_kmer(sequence, k=6, overlap=3) |
|
|
|
# Run inference |
|
inputs = tokenizer( |
|
seq_kmer, |
|
return_tensors="pt", |
|
max_length=tokenizer.model_max_length, |
|
padding="max_length", |
|
truncation=True |
|
) |
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
logits = outputs.logits |
|
predicted_class = torch.argmax(logits, dim=-1).item() |
|
``` |
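
The raw logits can be converted into class probabilities with a softmax. Here is a pure-Python equivalent of `torch.softmax` for illustration; note that what each class index means depends on the checkpoint's `id2label` config, which is not documented here:

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of raw scores.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative logits for a binary classification head.
probs = softmax([1.2, -0.3])
predicted_class = max(range(len(probs)), key=probs.__getitem__)
```

With the model loaded as above, the label name can be looked up via `model.config.id2label[predicted_class]`.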
|
|
|
<!-- ### **Citation** |
|
If you use this model in your research, please cite: |
|
```tex |
|
@article{paper title, |
|
title={DNABERT for Bacterial CDS Classification}, |
|
author={Genereux Akotenou, et al.}, |
|
journal={Hugging Face Model Hub}, |
|
year={2024} |
|
} |
|
``` --> |