"""
LookingGlass - A DNA Language Model
Pure PyTorch implementation of LookingGlass, a pretrained language model for DNA sequences.
Based on AWD-LSTM architecture, originally trained with fastai v1.
Paper: Hoarfrost et al., "Deep learning of a bacterial and archaeal universal language
of life enables transfer learning and illuminates microbial dark matter",
Nature Communications, 2022.
Usage:
from lookingglass import LookingGlass, LookingGlassTokenizer
# Load from HuggingFace Hub
model = LookingGlass.from_pretrained('HoarfrostLab/lookingglass-v1')
tokenizer = LookingGlassTokenizer()
# Or load from local path
model = LookingGlass.from_pretrained('./lookingglass-v1')
inputs = tokenizer(["GATTACA", "ATCGATCG"], return_tensors=True)
embeddings = model.get_embeddings(inputs['input_ids']) # (batch, 104)
"""
import json
import os
import warnings
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, List, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from huggingface_hub import hf_hub_download
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
__version__ = "1.1.0"
def _is_hf_hub_id(path: str) -> bool:
"""Check if path looks like a HuggingFace Hub model ID (e.g., 'user/model')."""
if os.path.exists(path):
return False
return '/' in path and not path.startswith(('.', '/'))
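# Illustrative behavior (assuming none of these exist as local paths):
#   _is_hf_hub_id('HoarfrostLab/lookingglass-v1')  -> True   (user/repo Hub ID)
#   _is_hf_hub_id('./lookingglass-v1')             -> False  (relative local path)
#   _is_hf_hub_id('/models/lookingglass-v1')       -> False  (absolute local path)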
def _download_from_hub(repo_id: str, filename: str) -> str:
"""Download a file from HuggingFace Hub and return the local path."""
if not HF_HUB_AVAILABLE:
raise ImportError(
"huggingface_hub is required to load models from the Hub. "
"Install it with: pip install huggingface_hub"
)
return hf_hub_download(repo_id=repo_id, filename=filename)
__all__ = [
"LookingGlassConfig",
"LookingGlass",
"LookingGlassLM",
"LookingGlassTokenizer",
]
# =============================================================================
# Configuration
# =============================================================================
@dataclass
class LookingGlassConfig:
"""
Configuration for LookingGlass model.
Default values match the original pretrained LookingGlass model.
"""
vocab_size: int = 8
hidden_size: int = 104 # embedding/output size
intermediate_size: int = 1152 # LSTM hidden size
num_hidden_layers: int = 3
pad_token_id: int = 1
bos_token_id: int = 2
eos_token_id: int = 3
bidirectional: bool = False # original LG is unidirectional
output_dropout: float = 0.1
hidden_dropout: float = 0.15
input_dropout: float = 0.25
embed_dropout: float = 0.02
weight_dropout: float = 0.2
tie_weights: bool = True
output_bias: bool = True
model_type: str = "lookingglass"
def to_dict(self) -> Dict:
return asdict(self)
def save_pretrained(self, save_directory: str):
os.makedirs(save_directory, exist_ok=True)
with open(os.path.join(save_directory, "config.json"), 'w') as f:
json.dump(self.to_dict(), f, indent=2)
@classmethod
def from_pretrained(cls, pretrained_path: str) -> "LookingGlassConfig":
if _is_hf_hub_id(pretrained_path):
try:
config_path = _download_from_hub(pretrained_path, "config.json")
except Exception:
return cls()
elif os.path.isdir(pretrained_path):
config_path = os.path.join(pretrained_path, "config.json")
else:
config_path = pretrained_path
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config_dict = json.load(f)
valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
return cls(**{k: v for k, v in config_dict.items() if k in valid_fields})
return cls()
# =============================================================================
# Tokenizer
# =============================================================================
VOCAB = ['xxunk', 'xxpad', 'xxbos', 'xxeos', 'G', 'A', 'C', 'T']
VOCAB_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}
ID_TO_VOCAB = {i: tok for i, tok in enumerate(VOCAB)}
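# Resulting ID assignment: xxunk=0, xxpad=1, xxbos=2, xxeos=3, G=4, A=5, C=6, T=7.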
class LookingGlassTokenizer:
"""
Tokenizer for DNA sequences.
    Each nucleotide (G, A, C, T) is a single token. By default, a BOS token is added
    at the start of each sequence (matching the original LookingGlass training setup).
Special tokens:
- xxunk (0): Unknown
- xxpad (1): Padding
- xxbos (2): Beginning of sequence
- xxeos (3): End of sequence
"""
vocab = VOCAB
vocab_to_id = VOCAB_TO_ID
id_to_vocab = ID_TO_VOCAB
def __init__(
self,
add_bos_token: bool = True, # original LG uses BOS
add_eos_token: bool = False, # original LG does not use EOS
padding_side: str = "right",
):
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.padding_side = padding_side
self.unk_token_id = 0
self.pad_token_id = 1
self.bos_token_id = 2
self.eos_token_id = 3
@property
def vocab_size(self) -> int:
return len(self.vocab)
def encode(self, sequence: str, add_special_tokens: bool = True) -> List[int]:
"""Encode a DNA sequence to token IDs."""
tokens = []
if add_special_tokens and self.add_bos_token:
tokens.append(self.bos_token_id)
for char in sequence.upper():
if char in self.vocab_to_id:
tokens.append(self.vocab_to_id[char])
elif char.strip():
tokens.append(self.unk_token_id)
if add_special_tokens and self.add_eos_token:
tokens.append(self.eos_token_id)
return tokens
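    # Illustrative sketch with the default settings (BOS added, no EOS):
    #   encode("GATTACA") -> [2, 4, 5, 7, 7, 5, 6, 5]   # xxbos, G, A, T, T, A, C, A
    #   decode([2, 4, 5, 7, 7, 5, 6, 5]) -> 'GATTACA'   # special tokens skipped by default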
def decode(self, token_ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = True) -> str:
"""Decode token IDs back to DNA sequence."""
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.tolist()
special_ids = {0, 1, 2, 3}
tokens = []
for tid in token_ids:
if skip_special_tokens and tid in special_ids:
continue
tokens.append(self.id_to_vocab.get(tid, 'xxunk'))
return ''.join(tokens)
def __call__(
self,
sequences: Union[str, List[str]],
padding: Union[bool, str] = False,
max_length: Optional[int] = None,
truncation: bool = False,
return_tensors: Union[bool, str] = False,
return_attention_mask: bool = True,
    ) -> Dict[str, Union[torch.Tensor, List]]:
"""Tokenize DNA sequence(s)."""
if isinstance(sequences, str):
sequences = [sequences]
single = True
else:
single = False
encoded = [self.encode(seq) for seq in sequences]
if truncation and max_length:
encoded = [e[:max_length] for e in encoded]
# Padding
if padding or len(encoded) > 1:
if padding == 'max_length' and max_length:
pad_len = max_length
else:
pad_len = max(len(e) for e in encoded)
padded = []
masks = []
for e in encoded:
pad_amount = pad_len - len(e)
                if self.padding_side == 'right':
                    mask = [1] * len(e) + [0] * pad_amount
                    e = e + [self.pad_token_id] * pad_amount
                else:
                    # Build the mask before prepending pads so its length matches pad_len.
                    mask = [0] * pad_amount + [1] * len(e)
                    e = [self.pad_token_id] * pad_amount + e
padded.append(e)
masks.append(mask)
encoded = padded
else:
masks = [[1] * len(e) for e in encoded]
result = {}
if return_tensors in ('pt', True):
result['input_ids'] = torch.tensor(encoded, dtype=torch.long)
if return_attention_mask:
result['attention_mask'] = torch.tensor(masks, dtype=torch.long)
else:
result['input_ids'] = encoded[0] if single else encoded
if return_attention_mask:
result['attention_mask'] = masks[0] if single else masks
return result
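    # Illustrative sketch of batched use (shapes follow from the default BOS-only setup):
    #   tok = LookingGlassTokenizer()
    #   batch = tok(["GATTACA", "ATCG"], return_tensors=True)
    #   batch['input_ids'].shape       -> torch.Size([2, 8])  # padded to the longest sequence
    #   batch['attention_mask'].shape  -> torch.Size([2, 8])  # 1 = real token, 0 = xxpad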
def save_pretrained(self, save_directory: str):
os.makedirs(save_directory, exist_ok=True)
with open(os.path.join(save_directory, "vocab.json"), 'w') as f:
json.dump(self.vocab_to_id, f, indent=2)
with open(os.path.join(save_directory, "tokenizer_config.json"), 'w') as f:
json.dump({
"add_bos_token": self.add_bos_token,
"add_eos_token": self.add_eos_token,
"padding_side": self.padding_side,
}, f, indent=2)
@classmethod
def from_pretrained(cls, pretrained_path: str) -> "LookingGlassTokenizer":
kwargs = {}
if _is_hf_hub_id(pretrained_path):
try:
config_path = _download_from_hub(pretrained_path, "tokenizer_config.json")
with open(config_path, 'r') as f:
kwargs = json.load(f)
except Exception:
pass
else:
config_path = os.path.join(pretrained_path, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path, 'r') as f:
kwargs = json.load(f)
return cls(**kwargs)
# =============================================================================
# Model Components
# =============================================================================
def _dropout_mask(x: torch.Tensor, size: Tuple[int, ...], p: float) -> torch.Tensor:
"""Create dropout mask with inverted scaling."""
return x.new_empty(*size).bernoulli_(1 - p).div_(1 - p)
class _RNNDropout(nn.Module):
"""Dropout consistent across sequence dimension."""
def __init__(self, p: float = 0.5):
super().__init__()
self.p = p
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.training or self.p == 0.:
return x
mask = _dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)
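        # The (batch, 1, hidden) mask broadcasts over the sequence dimension, so the
        # same units are dropped at every timestep (variational / "locked" dropout).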
return x * mask
class _EmbeddingDropout(nn.Module):
"""Dropout applied to entire embedding rows."""
def __init__(self, embedding: nn.Embedding, p: float):
super().__init__()
self.embedding = embedding
self.p = p
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training and self.p != 0:
mask = _dropout_mask(self.embedding.weight.data,
(self.embedding.weight.size(0), 1), self.p)
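            # The (vocab_size, 1) mask zeroes whole embedding rows, i.e. entire tokens
            # are dropped from the vocabulary for this forward pass.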
masked_weight = self.embedding.weight * mask
else:
masked_weight = self.embedding.weight
padding_idx = self.embedding.padding_idx if self.embedding.padding_idx is not None else -1
return F.embedding(x, masked_weight, padding_idx,
self.embedding.max_norm, self.embedding.norm_type,
self.embedding.scale_grad_by_freq, self.embedding.sparse)
class _WeightDropout(nn.Module):
"""DropConnect applied to RNN hidden-to-hidden weights."""
def __init__(self, module: nn.Module, p: float, layer_names='weight_hh_l0'):
super().__init__()
self.module = module
self.p = p
self.layer_names = [layer_names] if isinstance(layer_names, str) else layer_names
for layer in self.layer_names:
w = getattr(self.module, layer)
delattr(self.module, layer)
self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
setattr(self.module, layer, w.clone())
if isinstance(self.module, nn.RNNBase):
self.module.flatten_parameters = lambda: None
def _set_weights(self):
for layer in self.layer_names:
raw_w = getattr(self, f'{layer}_raw')
w = F.dropout(raw_w, p=self.p, training=self.training) if self.training else raw_w.clone()
setattr(self.module, layer, w)
def forward(self, *args):
self._set_weights()
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return self.module(*args)
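# Illustrative sketch: the encoder below wraps each nn.LSTM as
# _WeightDropout(lstm, config.weight_dropout) (p=0.2 with the default config), which
# re-samples a dropout mask over weight_hh_l0 on every training forward pass;
# at eval time the raw recurrent weights are used unchanged.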
class _AWDLSTMEncoder(nn.Module):
"""AWD-LSTM encoder backbone."""
_init_range = 0.1
def __init__(self, config: LookingGlassConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.num_layers = config.num_hidden_layers
self.num_directions = 2 if config.bidirectional else 1
self._batch_size = 1
# Embedding
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
padding_idx=config.pad_token_id)
self.embed_tokens.weight.data.uniform_(-self._init_range, self._init_range)
self.embed_dropout = _EmbeddingDropout(self.embed_tokens, config.embed_dropout)
# LSTM layers
self.layers = nn.ModuleList()
for i in range(config.num_hidden_layers):
input_size = config.hidden_size if i == 0 else config.intermediate_size
output_size = (config.intermediate_size if i != config.num_hidden_layers - 1
else config.hidden_size) // self.num_directions
lstm = nn.LSTM(input_size, output_size, num_layers=1,
batch_first=True, bidirectional=config.bidirectional)
self.layers.append(_WeightDropout(lstm, config.weight_dropout))
# Dropout
self.input_dropout = _RNNDropout(config.input_dropout)
self.hidden_dropout = nn.ModuleList([
_RNNDropout(config.hidden_dropout) for _ in range(config.num_hidden_layers)
])
self._hidden_state = None
self.reset()
def reset(self):
"""Reset LSTM hidden states."""
self._hidden_state = [self._init_hidden(i) for i in range(self.num_layers)]
def _init_hidden(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
nh = (self.intermediate_size if layer_idx != self.num_layers - 1
else self.hidden_size) // self.num_directions
weight = next(self.parameters())
return (weight.new_zeros(self.num_directions, self._batch_size, nh),
weight.new_zeros(self.num_directions, self._batch_size, nh))
def _resize_hidden(self, batch_size: int):
new_hidden = []
for i in range(self.num_layers):
nh = (self.intermediate_size if i != self.num_layers - 1
else self.hidden_size) // self.num_directions
h, c = self._hidden_state[i]
if self._batch_size < batch_size:
h = torch.cat([h, h.new_zeros(self.num_directions, batch_size - self._batch_size, nh)], dim=1)
c = torch.cat([c, c.new_zeros(self.num_directions, batch_size - self._batch_size, nh)], dim=1)
elif self._batch_size > batch_size:
h = h[:, :batch_size].contiguous()
c = c[:, :batch_size].contiguous()
new_hidden.append((h, c))
self._hidden_state = new_hidden
self._batch_size = batch_size
def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
"""Returns hidden states for all positions: (batch, seq_len, hidden_size)"""
batch_size, seq_len = input_ids.shape
if batch_size != self._batch_size:
self._resize_hidden(batch_size)
hidden = self.input_dropout(self.embed_dropout(input_ids))
new_hidden = []
for i, (layer, hdp) in enumerate(zip(self.layers, self.hidden_dropout)):
hidden, h = layer(hidden, self._hidden_state[i])
new_hidden.append(h)
if i != self.num_layers - 1:
hidden = hdp(hidden)
self._hidden_state = [(h.detach(), c.detach()) for h, c in new_hidden]
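        # With the default config, `hidden` is (batch, seq_len, 104) here: the final LSTM
        # layer projects back down to hidden_size so outputs match the embedding width.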
return hidden
class _LMHead(nn.Module):
"""Language modeling head."""
_init_range = 0.1
def __init__(self, config: LookingGlassConfig, embed_tokens: Optional[nn.Embedding] = None):
super().__init__()
self.output_dropout = _RNNDropout(config.output_dropout)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.output_bias)
self.decoder.weight.data.uniform_(-self._init_range, self._init_range)
if config.output_bias:
self.decoder.bias.data.zero_()
if embed_tokens is not None and config.tie_weights:
self.decoder.weight = embed_tokens.weight
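            # Weight tying: the decoder reuses the 8 x 104 embedding matrix, so the
            # output projection and the input embedding share parameters.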
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.decoder(self.output_dropout(hidden_states))
# =============================================================================
# Models
# =============================================================================
class LookingGlass(nn.Module):
"""
LookingGlass encoder model.
Outputs sequence embeddings for downstream tasks (classification, clustering, etc.).
    Uses the last-token embedding by default, matching the original LookingGlass model.
Example:
>>> model = LookingGlass.from_pretrained('lookingglass-v1')
>>> tokenizer = LookingGlassTokenizer()
>>> inputs = tokenizer("GATTACA", return_tensors=True)
>>> embeddings = model.get_embeddings(inputs['input_ids']) # (1, 104)
"""
config_class = LookingGlassConfig
def __init__(self, config: Optional[LookingGlassConfig] = None):
super().__init__()
self.config = config or LookingGlassConfig()
self.encoder = _AWDLSTMEncoder(self.config)
def reset(self):
"""Reset hidden states."""
self.encoder.reset()
def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
"""
Forward pass. Returns last-token embeddings.
Args:
input_ids: Token indices (batch, seq_len)
Returns:
Embeddings (batch, hidden_size)
"""
return self.get_embeddings(input_ids)
def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
"""
Get sequence embeddings using last-token pooling (original LG method).
Resets hidden state before encoding for deterministic results.
Args:
input_ids: Token indices (batch, seq_len)
Returns:
Embeddings (batch, hidden_size)
"""
self.encoder.reset()
hidden = self.encoder(input_ids) # (batch, seq_len, hidden_size)
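        # Note: `hidden[:, -1]` is purely positional; with right-padded batches the last
        # position of a shorter sequence is xxpad, so for exact parity with the original
        # model use equal-length sequences or pool per-sequence from get_hidden_states().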
return hidden[:, -1] # last token
def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
"""
Get hidden states for all positions.
Resets hidden state before encoding for deterministic results.
Args:
input_ids: Token indices (batch, seq_len)
Returns:
Hidden states (batch, seq_len, hidden_size)
"""
self.encoder.reset()
return self.encoder(input_ids)
def save_pretrained(self, save_directory: str):
os.makedirs(save_directory, exist_ok=True)
self.config.save_pretrained(save_directory)
torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
@classmethod
def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlass":
config = config or LookingGlassConfig.from_pretrained(pretrained_path)
model = cls(config)
if _is_hf_hub_id(pretrained_path):
model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
else:
model_path = os.path.join(pretrained_path, "pytorch_model.bin")
if os.path.exists(model_path):
state_dict = torch.load(model_path, map_location='cpu')
# Only load encoder weights
encoder_state_dict = {k: v for k, v in state_dict.items()
if not k.startswith('lm_head.')}
model.load_state_dict(encoder_state_dict, strict=False)
return model
class LookingGlassLM(nn.Module):
"""
LookingGlass with language modeling head.
Full model for next-token prediction. Can also extract embeddings.
Example:
>>> model = LookingGlassLM.from_pretrained('lookingglass-v1')
>>> tokenizer = LookingGlassTokenizer()
>>> inputs = tokenizer("GATTACA", return_tensors=True)
>>> logits = model(inputs['input_ids']) # (1, 8, 8)
>>> embeddings = model.get_embeddings(inputs['input_ids']) # (1, 104)
"""
config_class = LookingGlassConfig
def __init__(self, config: Optional[LookingGlassConfig] = None):
super().__init__()
self.config = config or LookingGlassConfig()
self.encoder = _AWDLSTMEncoder(self.config)
self.lm_head = _LMHead(
self.config,
embed_tokens=self.encoder.embed_tokens if self.config.tie_weights else None
)
def reset(self):
"""Reset hidden states."""
self.encoder.reset()
def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
"""
Forward pass. Returns logits for next-token prediction.
Args:
input_ids: Token indices (batch, seq_len)
Returns:
Logits (batch, seq_len, vocab_size)
"""
hidden = self.encoder(input_ids)
return self.lm_head(hidden)
def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
"""
Get sequence embeddings using last-token pooling.
Resets hidden state before encoding for deterministic results.
Args:
input_ids: Token indices (batch, seq_len)
Returns:
Embeddings (batch, hidden_size)
"""
self.encoder.reset()
hidden = self.encoder(input_ids)
return hidden[:, -1]
def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
"""
Get hidden states for all positions.
Resets hidden state before encoding for deterministic results.
Args:
input_ids: Token indices (batch, seq_len)
Returns:
Hidden states (batch, seq_len, hidden_size)
"""
self.encoder.reset()
return self.encoder(input_ids)
def save_pretrained(self, save_directory: str):
os.makedirs(save_directory, exist_ok=True)
self.config.save_pretrained(save_directory)
torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
@classmethod
def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlassLM":
config = config or LookingGlassConfig.from_pretrained(pretrained_path)
model = cls(config)
if _is_hf_hub_id(pretrained_path):
model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
else:
model_path = os.path.join(pretrained_path, "pytorch_model.bin")
if os.path.exists(model_path):
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict, strict=False)
return model
# =============================================================================
# Weight Loading
# =============================================================================
def load_original_weights(model: Union[LookingGlass, LookingGlassLM], weights_path: str) -> None:
"""
Load weights from original fastai-trained LookingGlass checkpoint.
Args:
model: Model to load weights into
weights_path: Path to LookingGlass.pth or LookingGlass_enc.pth
"""
checkpoint = torch.load(weights_path, map_location='cpu')
if 'model' in checkpoint:
state_dict = checkpoint['model']
else:
state_dict = checkpoint
is_lm_model = isinstance(model, LookingGlassLM)
new_state_dict = {}
for k, v in state_dict.items():
if '.module.weight_hh_l0' in k:
continue
if k.startswith('0.'):
new_k = k[2:]
new_k = new_k.replace('encoder.', 'embed_tokens.')
new_k = new_k.replace('encoder_dp.emb.', 'embed_tokens.')
new_k = new_k.replace('rnns.', 'layers.')
new_k = new_k.replace('hidden_dps.', 'hidden_dropout.')
new_k = new_k.replace('input_dp.', 'input_dropout.')
new_state_dict['encoder.' + new_k] = v
elif k.startswith('1.') and is_lm_model:
new_k = k[2:]
new_k = new_k.replace('output_dp.', 'output_dropout.')
new_state_dict['lm_head.' + new_k] = v
else:
new_k = k.replace('encoder.', 'embed_tokens.')
new_k = new_k.replace('encoder_dp.emb.', 'embed_tokens.')
new_k = new_k.replace('rnns.', 'layers.')
new_k = new_k.replace('hidden_dps.', 'hidden_dropout.')
new_k = new_k.replace('input_dp.', 'input_dropout.')
new_state_dict['encoder.' + new_k] = v
model.load_state_dict(new_state_dict, strict=False)
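# Illustrative sketch (assuming a local copy of the original fastai checkpoint):
#   model = LookingGlassLM(LookingGlassConfig())
#   load_original_weights(model, 'LookingGlass.pth')   # example path, not bundled here
#   model.save_pretrained('./lookingglass-v1')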
def convert_checkpoint(input_path: str, output_dir: str) -> None:
"""Convert original checkpoint to new format."""
config = LookingGlassConfig()
model = LookingGlassLM(config)
load_original_weights(model, input_path)
model.save_pretrained(output_dir)
tokenizer = LookingGlassTokenizer()
tokenizer.save_pretrained(output_dir)
print(f"Saved to {output_dir}")
# =============================================================================
# CLI
# =============================================================================
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='LookingGlass DNA Language Model')
parser.add_argument('--convert', type=str, help='Convert original weights')
parser.add_argument('--output', type=str, default='./lookingglass-v1', help='Output directory')
parser.add_argument('--test', action='store_true', help='Run tests')
args = parser.parse_args()
if args.convert:
convert_checkpoint(args.convert, args.output)
elif args.test:
print("Testing LookingGlass...\n")
tokenizer = LookingGlassTokenizer()
print(f"Vocab: {tokenizer.vocab}")
print(f"BOS token added: {tokenizer.add_bos_token}")
print(f"EOS token added: {tokenizer.add_eos_token}")
inputs = tokenizer("GATTACA", return_tensors=True)
print(f"\nTokenized 'GATTACA': {inputs['input_ids']}")
print(f"Decoded: {tokenizer.decode(inputs['input_ids'][0])}")
config = LookingGlassConfig()
print(f"\nConfig: bidirectional={config.bidirectional}")
# Test LookingGlass (encoder)
encoder = LookingGlass(config)
print(f"\nLookingGlass params: {sum(p.numel() for p in encoder.parameters()):,}")
encoder.eval()
with torch.no_grad():
emb = encoder.get_embeddings(inputs['input_ids'])
print(f"Embeddings shape: {emb.shape}")
# Test LookingGlassLM
lm = LookingGlassLM(config)
print(f"\nLookingGlassLM params: {sum(p.numel() for p in lm.parameters()):,}")
lm.eval()
with torch.no_grad():
logits = lm(inputs['input_ids'])
emb = lm.get_embeddings(inputs['input_ids'])
print(f"Logits shape: {logits.shape}")
print(f"Embeddings shape: {emb.shape}")
print("\nAll tests passed!")
else:
parser.print_help()