---
datasets:
- genbio-ai/rna-downstream-tasks
base_model:
- genbio-ai/rnafm-1.6b
---

LoRA fine-tuned checkpoint of `genbio-ai/rnafm-1.6b` for splice site prediction.

## How to Use

### Download model

```python
from pathlib import Path

from huggingface_hub import snapshot_download

model_name = "genbio-ai/rnafm-1.6b-csp-acceptor-ckpt"

# Download into ~/genbio_models/<repo id>; the "/" in the repo id creates a nested folder.
genbio_models_path = Path.home().joinpath('genbio_models', model_name)
genbio_models_path.mkdir(parents=True, exist_ok=True)

snapshot_download(repo_id=model_name, local_dir=genbio_models_path)
```
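
The inference step reads a `model.ckpt` file from the download directory. If you want a clear error when the download is incomplete, a small check like the following can help. This is a minimal sketch; `find_checkpoint` is a hypothetical helper, not part of `huggingface_hub` or `genbio_finetune`.

```python
from pathlib import Path


def find_checkpoint(model_dir, filename: str = "model.ckpt") -> Path:
    """Return the checkpoint path inside model_dir, or raise if it is missing.

    Hypothetical helper for illustration only.
    """
    ckpt_path = Path(model_dir) / filename
    if not ckpt_path.is_file():
        # A missing file usually means the snapshot download was interrupted
        # or local_dir pointed somewhere else.
        raise FileNotFoundError(
            f"Expected checkpoint at {ckpt_path}; re-run snapshot_download."
        )
    return ckpt_path
```

Usage: `ckpt_path = find_checkpoint(genbio_models_path)`.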

### Load model for inference

```python
import torch

from genbio_finetune.tasks import SequenceClassification

# genbio_models_path is the download directory from the previous step.
ckpt_path = genbio_models_path.joinpath('model.ckpt')
model = SequenceClassification.load_from_checkpoint(ckpt_path, strict_loading=False).eval()

# Convert raw sequences into a model-ready batch.
collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})

# Disable gradient tracking for inference.
with torch.no_grad():
    logits = model(collated_batch)

print(logits)
print(torch.argmax(logits, dim=-1))  # predicted class index per sequence
```
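
`torch.argmax` returns only the predicted class index. To also inspect per-class probabilities, the sketch below applies a numerically stable softmax to each row. It assumes `logits` has shape `(batch, num_classes)`, as the `argmax(dim=-1)` call implies; `logits_to_predictions` is a hypothetical helper. The meaning of each class index (for example, which index marks a positive splice site) is not documented in this card, so check the training configuration before interpreting labels.

```python
import math
from typing import List, Sequence, Tuple


def logits_to_predictions(
    logits: Sequence[Sequence[float]],
) -> List[Tuple[int, List[float]]]:
    """Map per-sequence logits (e.g. logits.cpu().tolist()) to (class, probabilities).

    Hypothetical helper; assumes one row of class logits per sequence.
    """
    results = []
    for row in logits:
        # Subtract the max logit before exponentiating to avoid overflow.
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        # First index of the largest probability, matching torch.argmax tie-breaking.
        pred = max(range(len(probs)), key=probs.__getitem__)
        results.append((pred, probs))
    return results
```

Usage: `for pred, probs in logits_to_predictions(logits.cpu().tolist()): print(pred, probs)`. Inside PyTorch, `torch.softmax(logits, dim=-1)` computes the same probabilities; the plain-Python version only makes the arithmetic explicit.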