Upload folder using huggingface_hub
Browse files- README.md +88 -0
- config.json +22 -0
- lookingglass.py +780 -0
- lookingglass_classifier.py +258 -0
- pytorch_model.bin +3 -0
README.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
tags:
|
| 5 |
+
- biology
|
| 6 |
+
- dna
|
| 7 |
+
- genomics
|
| 8 |
+
- metagenomics
|
| 9 |
+
- classifier
|
| 10 |
+
- awd-lstm
|
| 11 |
+
- transfer-learning
|
| 12 |
+
license: mit
|
| 13 |
+
pipeline_tag: text-classification
|
| 14 |
+
library_name: pytorch
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# LookingGlass Functional Classifier
|
| 18 |
+
|
| 19 |
+
Classifies DNA reads into one of 1274 experimentally-validated functional annotations with 81.5% accuracy.
|
| 20 |
+
|
| 21 |
+
This is a **pure PyTorch implementation** fine-tuned from the LookingGlass base model.
|
| 22 |
+
|
| 23 |
+
## Links
|
| 24 |
+
|
| 25 |
+
- **Paper**: [Deep learning of a bacterial and archaeal universal language of life enables transfer learning and illuminates microbial dark matter](https://doi.org/10.1038/s41467-022-30070-8) (Nature Communications, 2022)
|
| 26 |
+
- **GitHub**: [ahoarfrost/LookingGlass](https://github.com/ahoarfrost/LookingGlass)
|
| 27 |
+
- **Base Model**: [HoarfrostLab/lookingglass-v1](https://huggingface.co/HoarfrostLab/lookingglass-v1)
|
| 28 |
+
|
| 29 |
+
## Citation
|
| 30 |
+
|
| 31 |
+
```bibtex
|
| 32 |
+
@article{hoarfrost2022deep,
|
| 33 |
+
title={Deep learning of a bacterial and archaeal universal language of life
|
| 34 |
+
enables transfer learning and illuminates microbial dark matter},
|
| 35 |
+
author={Hoarfrost, Adrienne and Aptekmann, Ariel and Farfanuk, Gaetan and Bromberg, Yana},
|
| 36 |
+
journal={Nature Communications},
|
| 37 |
+
volume={13},
|
| 38 |
+
number={1},
|
| 39 |
+
pages={2606},
|
| 40 |
+
year={2022},
|
| 41 |
+
publisher={Nature Publishing Group}
|
| 42 |
+
}
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Model
|
| 46 |
+
|
| 47 |
+
| | |
|
| 48 |
+
|---|---|
|
| 49 |
+
| Architecture | LookingGlass encoder + classification head |
|
| 50 |
+
| Encoder | AWD-LSTM (3-layer, unidirectional) |
|
| 51 |
+
| Classes | 1274 functional annotation classes |
|
| 52 |
+
| Parameters | ~17M |
|
| 53 |
+
|
| 54 |
+
## Installation
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
pip install torch
|
| 58 |
+
git clone https://huggingface.co/HoarfrostLab/LGv1_FunctionalClassifier
|
| 59 |
+
cd LGv1_FunctionalClassifier
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## Usage
|
| 63 |
+
|
| 64 |
+
```python
|
| 65 |
+
from lookingglass_classifier import LookingGlassClassifier, LookingGlassTokenizer
|
| 66 |
+
|
| 67 |
+
model = LookingGlassClassifier.from_pretrained('.')
|
| 68 |
+
tokenizer = LookingGlassTokenizer()
|
| 69 |
+
model.eval()
|
| 70 |
+
|
| 71 |
+
inputs = tokenizer(["GATTACA", "ATCGATCGATCG"], return_tensors=True)
|
| 72 |
+
|
| 73 |
+
# Get predictions
|
| 74 |
+
predictions = model.predict(inputs['input_ids'])
|
| 75 |
+
print(predictions) # tensor([class_idx, class_idx])
|
| 76 |
+
|
| 77 |
+
# Get probabilities
|
| 78 |
+
probs = model.predict_proba(inputs['input_ids'])
|
| 79 |
+
print(probs.shape) # torch.Size([2, 1274])
|
| 80 |
+
|
| 81 |
+
# Get raw logits
|
| 82 |
+
logits = model(inputs['input_ids'])
|
| 83 |
+
print(logits.shape) # torch.Size([2, 1274])
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
## License
|
| 87 |
+
|
| 88 |
+
MIT License
|
config.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 8,
|
| 3 |
+
"hidden_size": 104,
|
| 4 |
+
"intermediate_size": 1152,
|
| 5 |
+
"num_hidden_layers": 3,
|
| 6 |
+
"pad_token_id": 1,
|
| 7 |
+
"bos_token_id": 2,
|
| 8 |
+
"eos_token_id": 3,
|
| 9 |
+
"bidirectional": false,
|
| 10 |
+
"output_dropout": 0.1,
|
| 11 |
+
"hidden_dropout": 0.15,
|
| 12 |
+
"input_dropout": 0.25,
|
| 13 |
+
"embed_dropout": 0.02,
|
| 14 |
+
"weight_dropout": 0.2,
|
| 15 |
+
"tie_weights": true,
|
| 16 |
+
"output_bias": true,
|
| 17 |
+
"model_type": "lookingglass",
|
| 18 |
+
"num_classes": 1274,
|
| 19 |
+
"classifier_hidden": 50,
|
| 20 |
+
"classifier_dropout": 0.0,
|
| 21 |
+
"class_names": []
|
| 22 |
+
}
|
lookingglass.py
ADDED
|
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LookingGlass - A DNA Language Model
|
| 3 |
+
|
| 4 |
+
Pure PyTorch implementation of LookingGlass, a pretrained language model for DNA sequences.
|
| 5 |
+
Based on AWD-LSTM architecture, originally trained with fastai v1.
|
| 6 |
+
|
| 7 |
+
Paper: Hoarfrost et al., "Deep learning of a bacterial and archaeal universal language
|
| 8 |
+
of life enables transfer learning and illuminates microbial dark matter",
|
| 9 |
+
Nature Communications, 2022.
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
from lookingglass import LookingGlass, LookingGlassTokenizer
|
| 13 |
+
|
| 14 |
+
# Load from HuggingFace Hub
|
| 15 |
+
model = LookingGlass.from_pretrained('HoarfrostLab/lookingglass-v1')
|
| 16 |
+
tokenizer = LookingGlassTokenizer()
|
| 17 |
+
|
| 18 |
+
# Or load from local path
|
| 19 |
+
model = LookingGlass.from_pretrained('./lookingglass-v1')
|
| 20 |
+
|
| 21 |
+
inputs = tokenizer(["GATTACA", "ATCGATCG"], return_tensors=True)
|
| 22 |
+
embeddings = model.get_embeddings(inputs['input_ids']) # (batch, 104)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
import warnings
|
| 28 |
+
from dataclasses import dataclass, asdict
|
| 29 |
+
from typing import Optional, Tuple, List, Dict, Union
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from huggingface_hub import hf_hub_download
|
| 37 |
+
HF_HUB_AVAILABLE = True
|
| 38 |
+
except ImportError:
|
| 39 |
+
HF_HUB_AVAILABLE = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
__version__ = "1.1.0"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _is_hf_hub_id(path: str) -> bool:
|
| 46 |
+
"""Check if path looks like a HuggingFace Hub model ID (e.g., 'user/model')."""
|
| 47 |
+
if os.path.exists(path):
|
| 48 |
+
return False
|
| 49 |
+
return '/' in path and not path.startswith(('.', '/'))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _download_from_hub(repo_id: str, filename: str) -> str:
    """Download a file from HuggingFace Hub and return the local path.

    Raises:
        ImportError: if the optional huggingface_hub package is missing.
    """
    if not HF_HUB_AVAILABLE:
        message = (
            "huggingface_hub is required to load models from the Hub. "
            "Install it with: pip install huggingface_hub"
        )
        raise ImportError(message)
    return hf_hub_download(repo_id=repo_id, filename=filename)
|
| 60 |
+
__all__ = [
|
| 61 |
+
"LookingGlassConfig",
|
| 62 |
+
"LookingGlass",
|
| 63 |
+
"LookingGlassLM",
|
| 64 |
+
"LookingGlassTokenizer",
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# =============================================================================
|
| 69 |
+
# Configuration
|
| 70 |
+
# =============================================================================
|
| 71 |
+
|
| 72 |
+
@dataclass
class LookingGlassConfig:
    """
    Configuration for LookingGlass model.

    Default values match the original pretrained LookingGlass model.
    """
    vocab_size: int = 8
    hidden_size: int = 104          # embedding/output size
    intermediate_size: int = 1152   # LSTM hidden size
    num_hidden_layers: int = 3
    pad_token_id: int = 1
    bos_token_id: int = 2
    eos_token_id: int = 3
    bidirectional: bool = False     # original LG is unidirectional
    output_dropout: float = 0.1
    hidden_dropout: float = 0.15
    input_dropout: float = 0.25
    embed_dropout: float = 0.02
    weight_dropout: float = 0.2
    tie_weights: bool = True
    output_bias: bool = True
    model_type: str = "lookingglass"

    def to_dict(self) -> Dict:
        """Return the configuration as a plain dictionary."""
        return asdict(self)

    def save_pretrained(self, save_directory: str):
        """Write this configuration to <save_directory>/config.json."""
        os.makedirs(save_directory, exist_ok=True)
        config_file = os.path.join(save_directory, "config.json")
        with open(config_file, 'w') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassConfig":
        """Load a config from a Hub ID, a directory, or a direct file path.

        Unknown keys in the JSON are ignored; falls back to all-default
        values when no config.json can be located.
        """
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "config.json")
            except Exception:
                # Best-effort: missing remote config means defaults.
                return cls()
        elif os.path.isdir(pretrained_path):
            config_path = os.path.join(pretrained_path, "config.json")
        else:
            config_path = pretrained_path

        if not os.path.exists(config_path):
            return cls()
        with open(config_path, 'r') as f:
            loaded = json.load(f)
        # Drop keys that are not dataclass fields (e.g. classifier extras).
        known = set(cls.__dataclass_fields__)
        return cls(**{key: val for key, val in loaded.items() if key in known})
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# =============================================================================
|
| 125 |
+
# Tokenizer
|
| 126 |
+
# =============================================================================
|
| 127 |
+
|
| 128 |
+
VOCAB = ['xxunk', 'xxpad', 'xxbos', 'xxeos', 'G', 'A', 'C', 'T']
VOCAB_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}
ID_TO_VOCAB = {i: tok for i, tok in enumerate(VOCAB)}


class LookingGlassTokenizer:
    """
    Tokenizer for DNA sequences.

    Each nucleotide (G, A, C, T) is a single token. By default, adds BOS token
    at the start of each sequence (matching original LookingGlass training).

    Special tokens:
        - xxunk (0): Unknown
        - xxpad (1): Padding
        - xxbos (2): Beginning of sequence
        - xxeos (3): End of sequence
    """

    vocab = VOCAB
    vocab_to_id = VOCAB_TO_ID
    id_to_vocab = ID_TO_VOCAB

    def __init__(
        self,
        add_bos_token: bool = True,   # original LG uses BOS
        add_eos_token: bool = False,  # original LG does not use EOS
        padding_side: str = "right",
    ):
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.padding_side = padding_side

        self.unk_token_id = 0
        self.pad_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self.vocab)

    def encode(self, sequence: str, add_special_tokens: bool = True) -> List[int]:
        """Encode a DNA sequence to token IDs.

        Whitespace is silently dropped; any other character not in the
        vocabulary maps to the unknown token.
        """
        tokens = []

        if add_special_tokens and self.add_bos_token:
            tokens.append(self.bos_token_id)

        for char in sequence.upper():
            if char in self.vocab_to_id:
                tokens.append(self.vocab_to_id[char])
            elif char.strip():  # non-whitespace unknown character
                tokens.append(self.unk_token_id)

        if add_special_tokens and self.add_eos_token:
            tokens.append(self.eos_token_id)

        return tokens

    def decode(self, token_ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = True) -> str:
        """Decode token IDs back to DNA sequence."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        special_ids = {0, 1, 2, 3}
        tokens = []
        for tid in token_ids:
            if skip_special_tokens and tid in special_ids:
                continue
            tokens.append(self.id_to_vocab.get(tid, 'xxunk'))
        return ''.join(tokens)

    def __call__(
        self,
        sequences: Union[str, List[str]],
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        return_tensors: Union[bool, str] = False,
        return_attention_mask: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Tokenize DNA sequence(s).

        Args:
            sequences: A DNA string or list of DNA strings.
            padding: False, True, or 'max_length'. Batches of unequal length
                are always padded so they can be stacked.
            max_length: Target length for truncation / 'max_length' padding.
            truncation: Truncate sequences longer than max_length.
            return_tensors: True or 'pt' to return torch.long tensors.
            return_attention_mask: Include a mask (1 = real token, 0 = pad).

        Returns:
            Dict with 'input_ids' and optionally 'attention_mask'. For a
            single string input without tensors, values are flat lists.
        """
        if isinstance(sequences, str):
            sequences = [sequences]
            single = True
        else:
            single = False

        encoded = [self.encode(seq) for seq in sequences]

        if truncation and max_length:
            encoded = [e[:max_length] for e in encoded]

        # Padding (multi-sequence batches are always padded)
        if padding or len(encoded) > 1:
            if padding == 'max_length' and max_length:
                pad_len = max_length
            else:
                pad_len = max(len(e) for e in encoded)

            padded = []
            masks = []
            for e in encoded:
                real_len = len(e)
                # Clamp so a sequence longer than 'max_length' padding
                # (without truncation) is left untouched.
                pad_amount = max(pad_len - real_len, 0)
                if self.padding_side == 'right':
                    e = e + [self.pad_token_id] * pad_amount
                    mask = [1] * real_len + [0] * pad_amount
                else:
                    e = [self.pad_token_id] * pad_amount + e
                    # BUGFIX: the mask must cover the *unpadded* length.
                    # Previously this used len(e) after padding, producing a
                    # mask longer than the sequence when padding on the left.
                    mask = [0] * pad_amount + [1] * real_len
                padded.append(e)
                masks.append(mask)
            encoded = padded
        else:
            masks = [[1] * len(e) for e in encoded]

        result = {}
        if return_tensors in ('pt', True):
            result['input_ids'] = torch.tensor(encoded, dtype=torch.long)
            if return_attention_mask:
                result['attention_mask'] = torch.tensor(masks, dtype=torch.long)
        else:
            result['input_ids'] = encoded[0] if single else encoded
            if return_attention_mask:
                result['attention_mask'] = masks[0] if single else masks

        return result

    def save_pretrained(self, save_directory: str):
        """Write vocab.json and tokenizer_config.json to save_directory."""
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "vocab.json"), 'w') as f:
            json.dump(self.vocab_to_id, f, indent=2)
        with open(os.path.join(save_directory, "tokenizer_config.json"), 'w') as f:
            json.dump({
                "add_bos_token": self.add_bos_token,
                "add_eos_token": self.add_eos_token,
                "padding_side": self.padding_side,
            }, f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassTokenizer":
        """Load tokenizer settings from a Hub ID or local directory.

        Missing config files fall back to default settings.
        """
        kwargs = {}
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "tokenizer_config.json")
                with open(config_path, 'r') as f:
                    kwargs = json.load(f)
            except Exception:
                pass  # best-effort: use defaults
        else:
            config_path = os.path.join(pretrained_path, "tokenizer_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    kwargs = json.load(f)
        return cls(**kwargs)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# =============================================================================
|
| 287 |
+
# Model Components
|
| 288 |
+
# =============================================================================
|
| 289 |
+
|
| 290 |
+
def _dropout_mask(x: torch.Tensor, size: Tuple[int, ...], p: float) -> torch.Tensor:
|
| 291 |
+
"""Create dropout mask with inverted scaling."""
|
| 292 |
+
return x.new_empty(*size).bernoulli_(1 - p).div_(1 - p)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class _RNNDropout(nn.Module):
|
| 296 |
+
"""Dropout consistent across sequence dimension."""
|
| 297 |
+
|
| 298 |
+
def __init__(self, p: float = 0.5):
|
| 299 |
+
super().__init__()
|
| 300 |
+
self.p = p
|
| 301 |
+
|
| 302 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 303 |
+
if not self.training or self.p == 0.:
|
| 304 |
+
return x
|
| 305 |
+
mask = _dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)
|
| 306 |
+
return x * mask
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class _EmbeddingDropout(nn.Module):
|
| 310 |
+
"""Dropout applied to entire embedding rows."""
|
| 311 |
+
|
| 312 |
+
def __init__(self, embedding: nn.Embedding, p: float):
|
| 313 |
+
super().__init__()
|
| 314 |
+
self.embedding = embedding
|
| 315 |
+
self.p = p
|
| 316 |
+
|
| 317 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 318 |
+
if self.training and self.p != 0:
|
| 319 |
+
mask = _dropout_mask(self.embedding.weight.data,
|
| 320 |
+
(self.embedding.weight.size(0), 1), self.p)
|
| 321 |
+
masked_weight = self.embedding.weight * mask
|
| 322 |
+
else:
|
| 323 |
+
masked_weight = self.embedding.weight
|
| 324 |
+
|
| 325 |
+
padding_idx = self.embedding.padding_idx if self.embedding.padding_idx is not None else -1
|
| 326 |
+
return F.embedding(x, masked_weight, padding_idx,
|
| 327 |
+
self.embedding.max_norm, self.embedding.norm_type,
|
| 328 |
+
self.embedding.scale_grad_by_freq, self.embedding.sparse)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class _WeightDropout(nn.Module):
    """DropConnect applied to RNN hidden-to-hidden weights.

    Wraps an RNN module and, before each forward pass, rebuilds the named
    weight tensors from a stored "raw" parameter — with dropout applied in
    training mode. The raw parameter lives on this wrapper, so it is what
    the optimizer updates; the wrapped module only ever holds a derived
    plain tensor.
    """

    def __init__(self, module: nn.Module, p: float, layer_names='weight_hh_l0'):
        super().__init__()
        self.module = module
        self.p = p
        # Accept a single name or a list of names.
        self.layer_names = [layer_names] if isinstance(layer_names, str) else layer_names

        for layer in self.layer_names:
            # Parameter surgery: detach the real parameter from the wrapped
            # module, re-register it here as '<name>_raw', and leave a plain
            # (non-Parameter) clone on the module so its forward still works.
            # The delattr is required before setattr: nn.Module would refuse
            # to overwrite a registered Parameter with a plain tensor.
            w = getattr(self.module, layer)
            delattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            setattr(self.module, layer, w.clone())

        if isinstance(self.module, nn.RNNBase):
            # flatten_parameters() would try to re-fuse the weights into a
            # single cuDNN buffer; neutralize it since the weight tensors are
            # recreated on every forward pass.
            self.module.flatten_parameters = lambda: None

    def _set_weights(self):
        # Refresh the wrapped module's weights from the raw parameters,
        # applying dropout (DropConnect) only in training mode.
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            w = F.dropout(raw_w, p=self.p, training=self.training) if self.training else raw_w.clone()
            setattr(self.module, layer, w)

    def forward(self, *args):
        self._set_weights()
        with warnings.catch_warnings():
            # Suppress PyTorch's "weights are not part of a contiguous chunk"
            # UserWarning — expected given the per-call weight replacement.
            warnings.simplefilter("ignore", category=UserWarning)
            return self.module(*args)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class _AWDLSTMEncoder(nn.Module):
    """AWD-LSTM encoder backbone.

    Stack of weight-dropped LSTM layers over an embedding, with embedding-,
    input-, and hidden-dropout as in Merity et al.'s AWD-LSTM. Maintains a
    persistent hidden state across calls (truncated BPTT style); callers
    that want deterministic per-sequence encoding should call reset() first.
    """

    # Uniform init range for the embedding weights.
    _init_range = 0.1

    def __init__(self, config: LookingGlassConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_layers = config.num_hidden_layers
        self.num_directions = 2 if config.bidirectional else 1
        # Hidden state is lazily resized to the incoming batch size.
        self._batch_size = 1

        # Embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         padding_idx=config.pad_token_id)
        self.embed_tokens.weight.data.uniform_(-self._init_range, self._init_range)
        self.embed_dropout = _EmbeddingDropout(self.embed_tokens, config.embed_dropout)

        # LSTM layers: first layer takes the embedding size; the last layer
        # projects back down to hidden_size so outputs can tie with the
        # embedding matrix in the LM head.
        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            input_size = config.hidden_size if i == 0 else config.intermediate_size
            output_size = (config.intermediate_size if i != config.num_hidden_layers - 1
                           else config.hidden_size) // self.num_directions
            lstm = nn.LSTM(input_size, output_size, num_layers=1,
                           batch_first=True, bidirectional=config.bidirectional)
            self.layers.append(_WeightDropout(lstm, config.weight_dropout))

        # Dropout: one between embedding and first LSTM, one after each
        # non-final LSTM layer.
        self.input_dropout = _RNNDropout(config.input_dropout)
        self.hidden_dropout = nn.ModuleList([
            _RNNDropout(config.hidden_dropout) for _ in range(config.num_hidden_layers)
        ])

        self._hidden_state = None
        self.reset()

    def reset(self):
        """Reset LSTM hidden states."""
        self._hidden_state = [self._init_hidden(i) for i in range(self.num_layers)]

    def _init_hidden(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Zero (h, c) pair sized for the given layer; allocated via an
        # existing parameter so device/dtype follow the model.
        nh = (self.intermediate_size if layer_idx != self.num_layers - 1
              else self.hidden_size) // self.num_directions
        weight = next(self.parameters())
        return (weight.new_zeros(self.num_directions, self._batch_size, nh),
                weight.new_zeros(self.num_directions, self._batch_size, nh))

    def _resize_hidden(self, batch_size: int):
        # Grow the cached hidden state with zeros, or truncate it, so it
        # matches a new batch size without discarding existing state.
        new_hidden = []
        for i in range(self.num_layers):
            nh = (self.intermediate_size if i != self.num_layers - 1
                  else self.hidden_size) // self.num_directions
            h, c = self._hidden_state[i]

            if self._batch_size < batch_size:
                h = torch.cat([h, h.new_zeros(self.num_directions, batch_size - self._batch_size, nh)], dim=1)
                c = torch.cat([c, c.new_zeros(self.num_directions, batch_size - self._batch_size, nh)], dim=1)
            elif self._batch_size > batch_size:
                h = h[:, :batch_size].contiguous()
                c = c[:, :batch_size].contiguous()
            new_hidden.append((h, c))

        self._hidden_state = new_hidden
        self._batch_size = batch_size

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Returns hidden states for all positions: (batch, seq_len, hidden_size)"""
        batch_size, seq_len = input_ids.shape

        if batch_size != self._batch_size:
            self._resize_hidden(batch_size)

        hidden = self.input_dropout(self.embed_dropout(input_ids))

        new_hidden = []
        for i, (layer, hdp) in enumerate(zip(self.layers, self.hidden_dropout)):
            hidden, h = layer(hidden, self._hidden_state[i])
            new_hidden.append(h)
            # No dropout after the final layer; its output feeds the head.
            if i != self.num_layers - 1:
                hidden = hdp(hidden)

        # Detach so gradients do not flow across calls (truncated BPTT).
        self._hidden_state = [(h.detach(), c.detach()) for h, c in new_hidden]
        return hidden
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class _LMHead(nn.Module):
    """Language modeling head: output dropout followed by a linear decoder.

    When weight tying is enabled, the decoder shares its weight matrix with
    the token embedding (standard AWD-LSTM practice).
    """

    # Uniform init range for the decoder weights.
    _init_range = 0.1

    def __init__(self, config: LookingGlassConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__()
        self.output_dropout = _RNNDropout(config.output_dropout)

        decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.output_bias)
        decoder.weight.data.uniform_(-self._init_range, self._init_range)
        if config.output_bias:
            decoder.bias.data.zero_()
        # Tying replaces the freshly initialized weight with the embedding's.
        if embed_tokens is not None and config.tie_weights:
            decoder.weight = embed_tokens.weight
        self.decoder = decoder

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits."""
        dropped = self.output_dropout(hidden_states)
        return self.decoder(dropped)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
# =============================================================================
|
| 472 |
+
# Models
|
| 473 |
+
# =============================================================================
|
| 474 |
+
|
| 475 |
+
class LookingGlass(nn.Module):
    """
    LookingGlass encoder model.

    Outputs sequence embeddings for downstream tasks (classification, clustering, etc.).
    Uses last-token embedding by default, matching original LookingGlass.

    Example:
        >>> model = LookingGlass.from_pretrained('lookingglass-v1')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> embeddings = model.get_embeddings(inputs['input_ids'])  # (1, 104)
    """

    config_class = LookingGlassConfig

    def __init__(self, config: Optional[LookingGlassConfig] = None):
        super().__init__()
        self.config = config or LookingGlassConfig()
        self.encoder = _AWDLSTMEncoder(self.config)

    def reset(self):
        """Reset hidden states."""
        self.encoder.reset()

    def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
        """
        Forward pass. Returns last-token embeddings.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Embeddings (batch, hidden_size)
        """
        return self.get_embeddings(input_ids)

    def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get sequence embeddings using last-token pooling (original LG method).

        Resets hidden state before encoding for deterministic results.

        NOTE(review): last-token pooling assumes sequences are unpadded or
        right-aligned; with right padding the last position is a pad token —
        confirm against caller usage.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Embeddings (batch, hidden_size)
        """
        self.encoder.reset()
        hidden = self.encoder(input_ids)  # (batch, seq_len, hidden_size)
        return hidden[:, -1]  # last token

    def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get hidden states for all positions.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Hidden states (batch, seq_len, hidden_size)
        """
        self.encoder.reset()
        return self.encoder(input_ids)

    def save_pretrained(self, save_directory: str):
        """Save config.json and pytorch_model.bin to save_directory."""
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlass":
        """Load an encoder from a Hub ID or local directory.

        LM-head weights in the checkpoint are skipped; missing weights leave
        the model randomly initialized (strict=False).
        """
        config = config or LookingGlassConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            model_path = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(model_path):
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources; consider weights_only=True
            # if the minimum supported torch version allows it.
            state_dict = torch.load(model_path, map_location='cpu')
            # Only load encoder weights
            encoder_state_dict = {k: v for k, v in state_dict.items()
                                  if not k.startswith('lm_head.')}
            model.load_state_dict(encoder_state_dict, strict=False)

        return model
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
class LookingGlassLM(nn.Module):
    """
    LookingGlass with language modeling head.

    Full model for next-token prediction. Can also extract embeddings.

    Example:
        >>> model = LookingGlassLM.from_pretrained('lookingglass-v1')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> logits = model(inputs['input_ids'])  # (1, 8, 8)
        >>> embeddings = model.get_embeddings(inputs['input_ids'])  # (1, 104)
    """

    config_class = LookingGlassConfig

    def __init__(self, config: Optional[LookingGlassConfig] = None):
        super().__init__()
        self.config = config or LookingGlassConfig()
        self.encoder = _AWDLSTMEncoder(self.config)
        # When tie_weights is set, the output projection shares the input
        # embedding matrix (standard AWD-LSTM weight tying).
        self.lm_head = _LMHead(
            self.config,
            embed_tokens=self.encoder.embed_tokens if self.config.tie_weights else None
        )

    def reset(self):
        """Reset hidden states."""
        self.encoder.reset()

    def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
        """
        Forward pass. Returns logits for next-token prediction.

        NOTE(review): unlike get_embeddings(), forward() does not reset the
        recurrent hidden state — presumably intentional so consecutive calls
        can continue one long sequence (truncated BPTT); confirm before
        relying on call-to-call independence.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Logits (batch, seq_len, vocab_size)
        """
        hidden = self.encoder(input_ids)
        return self.lm_head(hidden)

    def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get sequence embeddings using last-token pooling.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Embeddings (batch, hidden_size)
        """
        self.encoder.reset()
        hidden = self.encoder(input_ids)
        return hidden[:, -1]

    def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get hidden states for all positions.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Hidden states (batch, seq_len, hidden_size)
        """
        self.encoder.reset()
        return self.encoder(input_ids)

    def save_pretrained(self, save_directory: str):
        """Write config.json and pytorch_model.bin into *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlassLM":
        """
        Load a full LM (encoder + head) from a local directory or HF Hub id.

        Args:
            pretrained_path: Local directory or HF Hub repo id.
            config: Optional explicit config; otherwise read from the path.

        Returns:
            LookingGlassLM with loaded weights (or random init if the
            weights file is missing — a warning is emitted in that case).
        """
        import warnings

        config = config or LookingGlassConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            model_path = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(model_path):
            # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
            state_dict = torch.load(model_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
        else:
            # Previously this silently returned a randomly-initialized model.
            warnings.warn(
                f"No weights found at {model_path}; returning a randomly "
                f"initialized LookingGlassLM model."
            )

        return model
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
# =============================================================================
|
| 664 |
+
# Weight Loading
|
| 665 |
+
# =============================================================================
|
| 666 |
+
|
| 667 |
+
def load_original_weights(model: Union[LookingGlass, LookingGlassLM], weights_path: str) -> None:
    """
    Load weights from an original fastai-trained LookingGlass checkpoint.

    Renames fastai AWD-LSTM module paths to this implementation's names and
    loads them with ``strict=False`` (so extra/missing keys are tolerated).

    Args:
        model: Model to load weights into (encoder-only or full LM).
        weights_path: Path to LookingGlass.pth or LookingGlass_enc.pth.
    """
    checkpoint = torch.load(weights_path, map_location='cpu')

    # fastai checkpoints wrap the state dict under a 'model' key.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    is_lm_model = isinstance(model, LookingGlassLM)

    def _remap_encoder_key(key: str) -> str:
        # Translate fastai submodule names to ours. Order matters:
        # 'encoder.' must not clobber 'encoder_dp.' (it doesn't — the
        # replacement requires the trailing dot immediately after 'encoder').
        key = key.replace('encoder.', 'embed_tokens.')
        key = key.replace('encoder_dp.emb.', 'embed_tokens.')
        key = key.replace('rnns.', 'layers.')
        key = key.replace('hidden_dps.', 'hidden_dropout.')
        key = key.replace('input_dp.', 'input_dropout.')
        return key

    new_state_dict = {}
    for k, v in state_dict.items():
        if '.module.weight_hh_l0' in k:
            # Skip this entry: presumably the WeightDropout duplicate of the
            # recurrent weight (the *_raw copy is kept) — TODO confirm.
            continue

        if k.startswith('0.'):
            # '0.' = encoder half of the fastai SequentialRNN.
            new_state_dict['encoder.' + _remap_encoder_key(k[2:])] = v
        elif k.startswith('1.') and is_lm_model:
            # '1.' = LM decoder head; only meaningful for LookingGlassLM.
            new_k = k[2:].replace('output_dp.', 'output_dropout.')
            new_state_dict['lm_head.' + new_k] = v
        else:
            # Unprefixed keys (encoder-only checkpoints) — same remap,
            # without stripping a prefix.
            new_state_dict['encoder.' + _remap_encoder_key(k)] = v

    model.load_state_dict(new_state_dict, strict=False)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def convert_checkpoint(input_path: str, output_dir: str) -> None:
    """Convert an original fastai checkpoint into the pure-PyTorch layout.

    Loads the original weights into a fresh ``LookingGlassLM``, then writes
    config.json, pytorch_model.bin, and the tokenizer files to *output_dir*.
    """
    cfg = LookingGlassConfig()
    lm = LookingGlassLM(cfg)
    load_original_weights(lm, input_path)
    lm.save_pretrained(output_dir)

    tok = LookingGlassTokenizer()
    tok.save_pretrained(output_dir)
    print(f"Saved to {output_dir}")
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
# =============================================================================
|
| 727 |
+
# CLI
|
| 728 |
+
# =============================================================================
|
| 729 |
+
|
| 730 |
+
if __name__ == '__main__':
    import argparse

    # Tiny CLI: either convert an original checkpoint or run a smoke test.
    parser = argparse.ArgumentParser(description='LookingGlass DNA Language Model')
    parser.add_argument('--convert', type=str, help='Convert original weights')
    parser.add_argument('--output', type=str, default='./lookingglass-v1', help='Output directory')
    parser.add_argument('--test', action='store_true', help='Run tests')
    cli = parser.parse_args()

    if cli.convert:
        convert_checkpoint(cli.convert, cli.output)

    elif cli.test:
        print("Testing LookingGlass...\n")

        # Tokenizer round trip.
        tok = LookingGlassTokenizer()
        print(f"Vocab: {tok.vocab}")
        print(f"BOS token added: {tok.add_bos_token}")
        print(f"EOS token added: {tok.add_eos_token}")

        batch = tok("GATTACA", return_tensors=True)
        print(f"\nTokenized 'GATTACA': {batch['input_ids']}")
        print(f"Decoded: {tok.decode(batch['input_ids'][0])}")

        cfg = LookingGlassConfig()
        print(f"\nConfig: bidirectional={cfg.bidirectional}")

        # Encoder-only model: embeddings shape check.
        enc = LookingGlass(cfg)
        print(f"\nLookingGlass params: {sum(p.numel() for p in enc.parameters()):,}")

        enc.eval()
        with torch.no_grad():
            vec = enc.get_embeddings(batch['input_ids'])
            print(f"Embeddings shape: {vec.shape}")

        # Full LM: logits + embeddings shape check.
        lm_model = LookingGlassLM(cfg)
        print(f"\nLookingGlassLM params: {sum(p.numel() for p in lm_model.parameters()):,}")

        lm_model.eval()
        with torch.no_grad():
            out = lm_model(batch['input_ids'])
            vec = lm_model.get_embeddings(batch['input_ids'])
        print(f"Logits shape: {out.shape}")
        print(f"Embeddings shape: {vec.shape}")

        print("\nAll tests passed!")

    else:
        parser.print_help()
|
lookingglass_classifier.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LookingGlass Classifiers - Fine-tuned DNA sequence classifiers
|
| 3 |
+
|
| 4 |
+
Pure PyTorch implementation of LookingGlass classifiers from the paper.
|
| 5 |
+
Uses LookingGlass encoder with classification head.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
from lookingglass_classifier import LookingGlassClassifier, LookingGlassTokenizer
|
| 9 |
+
|
| 10 |
+
model = LookingGlassClassifier.from_pretrained('.')
|
| 11 |
+
tokenizer = LookingGlassTokenizer()
|
| 12 |
+
|
| 13 |
+
inputs = tokenizer(["GATTACA"], return_tensors=True)
|
| 14 |
+
logits = model(inputs['input_ids']) # (batch, num_classes)
|
| 15 |
+
predictions = logits.argmax(dim=-1)
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
from dataclasses import dataclass, asdict, field
|
| 21 |
+
from typing import Optional, List
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
|
| 26 |
+
from lookingglass import (
|
| 27 |
+
LookingGlassConfig,
|
| 28 |
+
LookingGlassTokenizer,
|
| 29 |
+
_AWDLSTMEncoder,
|
| 30 |
+
_is_hf_hub_id,
|
| 31 |
+
_download_from_hub,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
__version__ = "1.1.0"
|
| 35 |
+
__all__ = ["LookingGlassClassifierConfig", "LookingGlassClassifier", "LookingGlassTokenizer"]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class LookingGlassClassifierConfig(LookingGlassConfig):
    """Configuration for LookingGlass classifier.

    Extends the base encoder config with classification-head settings.
    """
    num_classes: int = 2
    classifier_hidden: int = 50
    classifier_dropout: float = 0.0
    class_names: List[str] = field(default_factory=list)

    def save_pretrained(self, save_directory: str):
        """Serialize this config as config.json under *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, "config.json")
        with open(path, 'w') as fh:
            json.dump(self.to_dict(), fh, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassClassifierConfig":
        """Load a config from a HF Hub id, a directory, or a direct file path.

        Falls back to defaults when no config.json can be found/downloaded.
        """
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "config.json")
            except Exception:
                return cls()
        elif os.path.isdir(pretrained_path):
            config_path = os.path.join(pretrained_path, "config.json")
        else:
            # Treat anything else as a direct path to a config.json file.
            config_path = pretrained_path

        if not os.path.exists(config_path):
            return cls()

        with open(config_path, 'r') as fh:
            loaded = json.load(fh)
        # Keep only keys that correspond to declared dataclass fields.
        known = set(cls.__dataclass_fields__)
        kwargs = {key: val for key, val in loaded.items() if key in known}
        return cls(**kwargs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class LookingGlassClassifier(nn.Module):
    """
    LookingGlass with classification head.

    Uses concat pooling (max + mean + last) followed by classification layers.

    Example:
        >>> model = LookingGlassClassifier.from_pretrained('.')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> logits = model(inputs['input_ids'])  # (1, num_classes)
        >>> prediction = logits.argmax(dim=-1)
    """

    def __init__(self, config: Optional[LookingGlassClassifierConfig] = None):
        super().__init__()
        self.config = config or LookingGlassClassifierConfig()
        self.encoder = _AWDLSTMEncoder(self.config)

        # Concat pooling: max + mean + last = 3 * hidden_size
        pooled_size = 3 * self.config.hidden_size

        # Classification head: BatchNorm -> Linear -> ReLU -> BatchNorm -> Linear
        # NOTE: BatchNorm1d raises on batch_size == 1 in training mode; use
        # batches of at least 2 during fine-tuning (inference is unaffected).
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(pooled_size),
            nn.Dropout(self.config.classifier_dropout),
            nn.Linear(pooled_size, self.config.classifier_hidden),
            nn.ReLU(),
            nn.BatchNorm1d(self.config.classifier_hidden),
            nn.Dropout(self.config.classifier_dropout),
            nn.Linear(self.config.classifier_hidden, self.config.num_classes),
        )

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Forward pass returning classification logits.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Logits (batch, num_classes)
        """
        self.encoder.reset()
        hidden = self.encoder(input_ids)  # (batch, seq_len, hidden_size)

        # Concat pooling: max, mean, last
        max_pool = hidden.max(dim=1).values
        mean_pool = hidden.mean(dim=1)
        last_pool = hidden[:, -1]
        pooled = torch.cat([max_pool, mean_pool, last_pool], dim=-1)

        return self.classifier(pooled)

    def predict(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return predicted class indices (batch,)."""
        logits = self.forward(input_ids)
        return logits.argmax(dim=-1)

    def predict_proba(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return softmax class probabilities (batch, num_classes)."""
        logits = self.forward(input_ids)
        return torch.softmax(logits, dim=-1)

    def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Get sequence embeddings (last token) from the encoder."""
        self.encoder.reset()
        hidden = self.encoder(input_ids)
        return hidden[:, -1]

    def save_pretrained(self, save_directory: str):
        """Write config.json and pytorch_model.bin into *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(
        cls, pretrained_path: str, config: Optional[LookingGlassClassifierConfig] = None
    ) -> "LookingGlassClassifier":
        """
        Load a classifier from a local directory or a HF Hub repo id.

        Args:
            pretrained_path: Local directory or HF Hub repo id.
            config: Optional explicit config; otherwise read from the path.

        Returns:
            LookingGlassClassifier with loaded weights (or random init if the
            weights file is missing — a warning is emitted in that case).
        """
        import warnings

        config = config or LookingGlassClassifierConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            model_path = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(model_path):
            # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
            state_dict = torch.load(model_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
        else:
            # Previously this silently returned a randomly-initialized model.
            warnings.warn(
                f"No weights found at {model_path}; returning a randomly "
                f"initialized LookingGlassClassifier."
            )

        return model
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def convert_classifier_weights(
    original_path: str,
    output_dir: str,
    num_classes: int,
    class_names: Optional[List[str]] = None,
):
    """
    Convert original fastai classifier weights to pure PyTorch format.

    Builds a source-key -> target-key map (embedding, three AWD-LSTM layers,
    and the 7-module classifier head), copies whatever keys are present, and
    writes the converted model + config to *output_dir*.

    Args:
        original_path: Path to original .pth file
        output_dir: Output directory for converted model
        num_classes: Number of output classes
        class_names: Optional list of class names
    """
    print(f"Loading weights from {original_path}...")
    raw = torch.load(original_path, map_location='cpu')
    if 'model' in raw:
        raw = raw['model']

    # Config + freshly initialized target model.
    cfg = LookingGlassClassifierConfig(
        num_classes=num_classes,
        classifier_hidden=50,
        class_names=class_names or [],
    )
    converted_model = LookingGlassClassifier(cfg)

    # ---- source -> target key map ----------------------------------------
    key_map = {
        '0.module.encoder.weight': 'encoder.embed_tokens.weight',
        '0.module.encoder_dp.emb.weight': 'encoder.embed_dropout.embedding.weight',
    }

    # AWD-LSTM layers: same parameter suffix on both sides.
    rnn_suffixes = (
        'weight_hh_l0_raw',
        'module.weight_ih_l0',
        'module.weight_hh_l0',
        'module.bias_ih_l0',
        'module.bias_hh_l0',
    )
    for layer in range(3):
        for suffix in rnn_suffixes:
            key_map[f'0.module.rnns.{layer}.{suffix}'] = f'encoder.layers.{layer}.{suffix}'

    # Classifier head: fastai '1.layers.N' maps to our Sequential index N.
    # Indices 0 and 4 are BatchNorm1d (with running stats); 2 and 6 are Linear.
    bn_params = ('weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked')
    for idx in (0, 4):
        for param in bn_params:
            key_map[f'1.layers.{idx}.{param}'] = f'classifier.{idx}.{param}'
    for idx in (2, 6):
        for param in ('weight', 'bias'):
            key_map[f'1.layers.{idx}.{param}'] = f'classifier.{idx}.{param}'
    # -----------------------------------------------------------------------

    converted_state = {dst: raw[src] for src, dst in key_map.items() if src in raw}

    # Load (tolerating absent keys) and save in the new layout.
    converted_model.load_state_dict(converted_state, strict=False)

    os.makedirs(output_dir, exist_ok=True)
    cfg.save_pretrained(output_dir)
    torch.save(converted_model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

    print(f"Saved to {output_dir}")
    return converted_model
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
    import argparse

    # CLI entry point: convert an original fastai classifier checkpoint.
    cli = argparse.ArgumentParser(description="Convert LookingGlass classifier weights")
    cli.add_argument("--input", required=True, help="Path to original .pth file")
    cli.add_argument("--output", required=True, help="Output directory")
    cli.add_argument("--num-classes", type=int, required=True, help="Number of classes")
    cli.add_argument("--class-names", nargs="+", help="Class names")

    opts = cli.parse_args()
    convert_classifier_weights(opts.input, opts.output, opts.num_classes, opts.class_names)
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6104341f1668118250861bdcc9dd17b26af44de9b84e27e2a1bb3ee0f74f67fb
|
| 3 |
+
size 68121438
|