""" |
|
|
LookingGlass - A DNA Language Model |
|
|
|
|
|
Pure PyTorch implementation of LookingGlass, a pretrained language model for DNA sequences. |
|
|
Based on AWD-LSTM architecture, originally trained with fastai v1. |
|
|
|
|
|
Paper: Hoarfrost et al., "Deep learning of a bacterial and archaeal universal language |
|
|
of life enables transfer learning and illuminates microbial dark matter", |
|
|
Nature Communications, 2022. |
|
|
|
|
|
Usage: |
|
|
from lookingglass import LookingGlass, LookingGlassTokenizer |
|
|
|
|
|
# Load from HuggingFace Hub |
|
|
model = LookingGlass.from_pretrained('HoarfrostLab/lookingglass-v1') |
|
|
tokenizer = LookingGlassTokenizer() |
|
|
|
|
|
# Or load from local path |
|
|
model = LookingGlass.from_pretrained('./lookingglass-v1') |
|
|
|
|
|
inputs = tokenizer(["GATTACA", "ATCGATCG"], return_tensors=True) |
|
|
embeddings = model.get_embeddings(inputs['input_ids']) # (batch, 104) |
|
|
""" |
|
|
|
|
|
import json
import os
import warnings
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, List, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from huggingface_hub import hf_hub_download
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False


__version__ = "1.1.0"


def _is_hf_hub_id(path: str) -> bool:
    """Check if path looks like a HuggingFace Hub model ID (e.g., 'user/model')."""
    if os.path.exists(path):
        return False
    return '/' in path and not path.startswith(('.', '/'))
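
# Illustrative behaviour of the heuristic above (repo/path names are placeholders,
# and a Hub ID is only recognised when no matching local path exists):
#
#     _is_hf_hub_id('HoarfrostLab/lookingglass-v1')  # -> True  (contains '/')
#     _is_hf_hub_id('./lookingglass-v1')             # -> False (relative path prefix)
#     _is_hf_hub_id('config.json')                   # -> False (no '/')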


def _download_from_hub(repo_id: str, filename: str) -> str:
    """Download a file from HuggingFace Hub and return the local path."""
    if not HF_HUB_AVAILABLE:
        raise ImportError(
            "huggingface_hub is required to load models from the Hub. "
            "Install it with: pip install huggingface_hub"
        )
    return hf_hub_download(repo_id=repo_id, filename=filename)


__all__ = [
    "LookingGlassConfig",
    "LookingGlass",
    "LookingGlassLM",
    "LookingGlassTokenizer",
]


@dataclass
class LookingGlassConfig:
    """
    Configuration for LookingGlass model.

    Default values match the original pretrained LookingGlass model.
    """
    vocab_size: int = 8
    hidden_size: int = 104
    intermediate_size: int = 1152
    num_hidden_layers: int = 3
    pad_token_id: int = 1
    bos_token_id: int = 2
    eos_token_id: int = 3
    bidirectional: bool = False
    output_dropout: float = 0.1
    hidden_dropout: float = 0.15
    input_dropout: float = 0.25
    embed_dropout: float = 0.02
    weight_dropout: float = 0.2
    tie_weights: bool = True
    output_bias: bool = True
    model_type: str = "lookingglass"

    def to_dict(self) -> Dict:
        return asdict(self)

    def save_pretrained(self, save_directory: str):
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "config.json"), 'w') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassConfig":
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "config.json")
            except Exception:
                return cls()
        elif os.path.isdir(pretrained_path):
            config_path = os.path.join(pretrained_path, "config.json")
        else:
            config_path = pretrained_path

        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
            return cls(**{k: v for k, v in config_dict.items() if k in valid_fields})
        return cls()
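
# Example (illustrative; the directory name is a placeholder): override a default,
# persist it, and reload it.
#
#     cfg = LookingGlassConfig(hidden_dropout=0.0)
#     cfg.save_pretrained('./my-lookingglass')     # writes ./my-lookingglass/config.json
#     assert LookingGlassConfig.from_pretrained('./my-lookingglass') == cfg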


VOCAB = ['xxunk', 'xxpad', 'xxbos', 'xxeos', 'G', 'A', 'C', 'T']
VOCAB_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}
ID_TO_VOCAB = {i: tok for i, tok in enumerate(VOCAB)}


class LookingGlassTokenizer:
    """
    Tokenizer for DNA sequences.

    Each nucleotide (G, A, C, T) is a single token. By default, a BOS token is
    prepended to each sequence (matching the original LookingGlass training).

    Special tokens:
        - xxunk (0): Unknown
        - xxpad (1): Padding
        - xxbos (2): Beginning of sequence
        - xxeos (3): End of sequence
    """

    vocab = VOCAB
    vocab_to_id = VOCAB_TO_ID
    id_to_vocab = ID_TO_VOCAB

    def __init__(
        self,
        add_bos_token: bool = True,
        add_eos_token: bool = False,
        padding_side: str = "right",
    ):
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.padding_side = padding_side

        self.unk_token_id = 0
        self.pad_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def encode(self, sequence: str, add_special_tokens: bool = True) -> List[int]:
        """Encode a DNA sequence to token IDs."""
        tokens = []

        if add_special_tokens and self.add_bos_token:
            tokens.append(self.bos_token_id)

        for char in sequence.upper():
            if char in self.vocab_to_id:
                tokens.append(self.vocab_to_id[char])
            elif char.strip():
                tokens.append(self.unk_token_id)

        if add_special_tokens and self.add_eos_token:
            tokens.append(self.eos_token_id)

        return tokens

    def decode(self, token_ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = True) -> str:
        """Decode token IDs back to a DNA sequence."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        special_ids = {0, 1, 2, 3}
        tokens = []
        for tid in token_ids:
            if skip_special_tokens and tid in special_ids:
                continue
            tokens.append(self.id_to_vocab.get(tid, 'xxunk'))
        return ''.join(tokens)
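
    # Example (illustrative): encoding prepends xxbos (id 2) by default, and
    # decoding skips special tokens.
    #
    #     tok = LookingGlassTokenizer()
    #     tok.encode("GATTACA")                   # -> [2, 4, 5, 7, 7, 5, 6, 5]
    #     tok.decode([2, 4, 5, 7, 7, 5, 6, 5])    # -> 'GATTACA'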

    def __call__(
        self,
        sequences: Union[str, List[str]],
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        return_tensors: Union[bool, str] = False,
        return_attention_mask: bool = True,
    ) -> Dict[str, Union[List, torch.Tensor]]:
        """Tokenize DNA sequence(s)."""
        if isinstance(sequences, str):
            sequences = [sequences]
            single = True
        else:
            single = False

        encoded = [self.encode(seq) for seq in sequences]

        if truncation and max_length:
            encoded = [e[:max_length] for e in encoded]

        # Pad when requested, or whenever a batch of sequences must be rectangular.
        if padding or len(encoded) > 1:
            if padding == 'max_length' and max_length:
                pad_len = max_length
            else:
                pad_len = max(len(e) for e in encoded)

            padded = []
            masks = []
            for e in encoded:
                pad_amount = pad_len - len(e)
                if self.padding_side == 'right':
                    mask = [1] * len(e) + [0] * pad_amount
                    e = e + [self.pad_token_id] * pad_amount
                else:
                    mask = [0] * pad_amount + [1] * len(e)
                    e = [self.pad_token_id] * pad_amount + e
                padded.append(e)
                masks.append(mask)
            encoded = padded
        else:
            masks = [[1] * len(e) for e in encoded]

        result = {}
        if return_tensors in ('pt', True):
            result['input_ids'] = torch.tensor(encoded, dtype=torch.long)
            if return_attention_mask:
                result['attention_mask'] = torch.tensor(masks, dtype=torch.long)
        else:
            result['input_ids'] = encoded[0] if single else encoded
            if return_attention_mask:
                result['attention_mask'] = masks[0] if single else masks

        return result
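
    # Example (illustrative): batching two sequences pads to the longest one and
    # returns a matching attention mask.
    #
    #     tok = LookingGlassTokenizer()
    #     out = tok(["GAT", "GATTACA"], return_tensors=True)
    #     out['input_ids'].shape          # -> torch.Size([2, 8])  (xxbos + 7 bases)
    #     out['attention_mask'][0]        # -> tensor([1, 1, 1, 1, 0, 0, 0, 0])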

    def save_pretrained(self, save_directory: str):
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "vocab.json"), 'w') as f:
            json.dump(self.vocab_to_id, f, indent=2)
        with open(os.path.join(save_directory, "tokenizer_config.json"), 'w') as f:
            json.dump({
                "add_bos_token": self.add_bos_token,
                "add_eos_token": self.add_eos_token,
                "padding_side": self.padding_side,
            }, f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassTokenizer":
        kwargs = {}
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "tokenizer_config.json")
                with open(config_path, 'r') as f:
                    kwargs = json.load(f)
            except Exception:
                pass
        else:
            config_path = os.path.join(pretrained_path, "tokenizer_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    kwargs = json.load(f)
        return cls(**kwargs)


def _dropout_mask(x: torch.Tensor, size: Tuple[int, ...], p: float) -> torch.Tensor:
    """Create dropout mask with inverted scaling."""
    return x.new_empty(*size).bernoulli_(1 - p).div_(1 - p)


class _RNNDropout(nn.Module):
    """Dropout with a mask that is shared across the sequence dimension."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.:
            return x
        mask = _dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)
        return x * mask


class _EmbeddingDropout(nn.Module):
    """Dropout applied to entire embedding rows."""

    def __init__(self, embedding: nn.Embedding, p: float):
        super().__init__()
        self.embedding = embedding
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and self.p != 0:
            mask = _dropout_mask(self.embedding.weight.data,
                                 (self.embedding.weight.size(0), 1), self.p)
            masked_weight = self.embedding.weight * mask
        else:
            masked_weight = self.embedding.weight

        padding_idx = self.embedding.padding_idx if self.embedding.padding_idx is not None else -1
        return F.embedding(x, masked_weight, padding_idx,
                           self.embedding.max_norm, self.embedding.norm_type,
                           self.embedding.scale_grad_by_freq, self.embedding.sparse)


class _WeightDropout(nn.Module):
    """DropConnect applied to RNN hidden-to-hidden weights."""

    def __init__(self, module: nn.Module, p: float, layer_names='weight_hh_l0'):
        super().__init__()
        self.module = module
        self.p = p
        self.layer_names = [layer_names] if isinstance(layer_names, str) else layer_names

        for layer in self.layer_names:
            w = getattr(self.module, layer)
            delattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
            setattr(self.module, layer, w.clone())

        if isinstance(self.module, nn.RNNBase):
            self.module.flatten_parameters = lambda: None

    def _set_weights(self):
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            w = F.dropout(raw_w, p=self.p, training=self.training) if self.training else raw_w.clone()
            setattr(self.module, layer, w)

    def forward(self, *args):
        self._set_weights()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            return self.module(*args)


class _AWDLSTMEncoder(nn.Module):
    """AWD-LSTM encoder backbone."""

    _init_range = 0.1

    def __init__(self, config: LookingGlassConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_layers = config.num_hidden_layers
        self.num_directions = 2 if config.bidirectional else 1
        self._batch_size = 1

        # Token embedding with embedding (row) dropout.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         padding_idx=config.pad_token_id)
        self.embed_tokens.weight.data.uniform_(-self._init_range, self._init_range)
        self.embed_dropout = _EmbeddingDropout(self.embed_tokens, config.embed_dropout)

        # Stacked LSTM layers wrapped with weight dropout (DropConnect).
        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            input_size = config.hidden_size if i == 0 else config.intermediate_size
            output_size = (config.intermediate_size if i != config.num_hidden_layers - 1
                           else config.hidden_size) // self.num_directions
            lstm = nn.LSTM(input_size, output_size, num_layers=1,
                           batch_first=True, bidirectional=config.bidirectional)
            self.layers.append(_WeightDropout(lstm, config.weight_dropout))

        # Variational dropout on the embedded inputs and between LSTM layers.
        self.input_dropout = _RNNDropout(config.input_dropout)
        self.hidden_dropout = nn.ModuleList([
            _RNNDropout(config.hidden_dropout) for _ in range(config.num_hidden_layers)
        ])

        self._hidden_state = None
        self.reset()

    def reset(self):
        """Reset LSTM hidden states."""
        self._hidden_state = [self._init_hidden(i) for i in range(self.num_layers)]

    def _init_hidden(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        nh = (self.intermediate_size if layer_idx != self.num_layers - 1
              else self.hidden_size) // self.num_directions
        weight = next(self.parameters())
        return (weight.new_zeros(self.num_directions, self._batch_size, nh),
                weight.new_zeros(self.num_directions, self._batch_size, nh))

    def _resize_hidden(self, batch_size: int):
        new_hidden = []
        for i in range(self.num_layers):
            nh = (self.intermediate_size if i != self.num_layers - 1
                  else self.hidden_size) // self.num_directions
            h, c = self._hidden_state[i]

            if self._batch_size < batch_size:
                h = torch.cat([h, h.new_zeros(self.num_directions, batch_size - self._batch_size, nh)], dim=1)
                c = torch.cat([c, c.new_zeros(self.num_directions, batch_size - self._batch_size, nh)], dim=1)
            elif self._batch_size > batch_size:
                h = h[:, :batch_size].contiguous()
                c = c[:, :batch_size].contiguous()
            new_hidden.append((h, c))

        self._hidden_state = new_hidden
        self._batch_size = batch_size

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Returns hidden states for all positions: (batch, seq_len, hidden_size)"""
        batch_size, seq_len = input_ids.shape

        if batch_size != self._batch_size:
            self._resize_hidden(batch_size)

        hidden = self.input_dropout(self.embed_dropout(input_ids))

        new_hidden = []
        for i, (layer, hdp) in enumerate(zip(self.layers, self.hidden_dropout)):
            hidden, h = layer(hidden, self._hidden_state[i])
            new_hidden.append(h)
            if i != self.num_layers - 1:
                hidden = hdp(hidden)

        self._hidden_state = [(h.detach(), c.detach()) for h, c in new_hidden]
        return hidden
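
# Note (illustrative sketch): the encoder carries LSTM state across forward calls,
# mirroring the original fastai training loop, so independent sequences should be
# preceded by `reset()`; the public `get_embeddings` / `get_hidden_states` helpers
# below do this automatically.
#
#     enc = _AWDLSTMEncoder(LookingGlassConfig()).eval()
#     ids = torch.tensor([[2, 4, 5, 7, 7, 5, 6, 5]])   # xxbos + GATTACA
#     enc.reset()
#     h = enc(ids)                                      # (1, 8, 104)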


class _LMHead(nn.Module):
    """Language modeling head."""

    _init_range = 0.1

    def __init__(self, config: LookingGlassConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__()
        self.output_dropout = _RNNDropout(config.output_dropout)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.output_bias)
        self.decoder.weight.data.uniform_(-self._init_range, self._init_range)

        if config.output_bias:
            self.decoder.bias.data.zero_()

        if embed_tokens is not None and config.tie_weights:
            self.decoder.weight = embed_tokens.weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.output_dropout(hidden_states))


class LookingGlass(nn.Module):
    """
    LookingGlass encoder model.

    Outputs sequence embeddings for downstream tasks (classification, clustering, etc.).
    Uses last-token pooling by default, matching the original LookingGlass.

    Example:
        >>> model = LookingGlass.from_pretrained('lookingglass-v1')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> embeddings = model.get_embeddings(inputs['input_ids'])  # (1, 104)
    """

    config_class = LookingGlassConfig

    def __init__(self, config: Optional[LookingGlassConfig] = None):
        super().__init__()
        self.config = config or LookingGlassConfig()
        self.encoder = _AWDLSTMEncoder(self.config)

    def reset(self):
        """Reset hidden states."""
        self.encoder.reset()

    def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
        """
        Forward pass. Returns last-token embeddings.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Embeddings (batch, hidden_size)
        """
        return self.get_embeddings(input_ids)

    def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get sequence embeddings using last-token pooling (original LookingGlass method).

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Embeddings (batch, hidden_size)
        """
        self.encoder.reset()
        hidden = self.encoder(input_ids)
        return hidden[:, -1]

    def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get hidden states for all positions.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Hidden states (batch, seq_len, hidden_size)
        """
        self.encoder.reset()
        return self.encoder(input_ids)

    def save_pretrained(self, save_directory: str):
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlass":
        config = config or LookingGlassConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            model_path = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(model_path):
            state_dict = torch.load(model_path, map_location='cpu')
            # Drop LM-head weights when loading a full LM checkpoint into the encoder.
            encoder_state_dict = {k: v for k, v in state_dict.items()
                                  if not k.startswith('lm_head.')}
            model.load_state_dict(encoder_state_dict, strict=False)

        return model
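
# Example (illustrative; the checkpoint directory is a placeholder):
#
#     model = LookingGlass.from_pretrained('./lookingglass-v1').eval()
#     tok = LookingGlassTokenizer()
#     batch = tok(["GATTACA", "ATCGATCG"], return_tensors=True)
#     with torch.no_grad():
#         emb = model.get_embeddings(batch['input_ids'])        # (2, 104)
#         states = model.get_hidden_states(batch['input_ids'])  # (2, 9, 104)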


class LookingGlassLM(nn.Module):
    """
    LookingGlass with language modeling head.

    Full model for next-token prediction. Can also extract embeddings.

    Example:
        >>> model = LookingGlassLM.from_pretrained('lookingglass-v1')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> logits = model(inputs['input_ids'])  # (1, 8, 8)
        >>> embeddings = model.get_embeddings(inputs['input_ids'])  # (1, 104)
    """

    config_class = LookingGlassConfig

    def __init__(self, config: Optional[LookingGlassConfig] = None):
        super().__init__()
        self.config = config or LookingGlassConfig()
        self.encoder = _AWDLSTMEncoder(self.config)
        self.lm_head = _LMHead(
            self.config,
            embed_tokens=self.encoder.embed_tokens if self.config.tie_weights else None
        )

    def reset(self):
        """Reset hidden states."""
        self.encoder.reset()

    def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
        """
        Forward pass. Returns logits for next-token prediction.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Logits (batch, seq_len, vocab_size)
        """
        hidden = self.encoder(input_ids)
        return self.lm_head(hidden)

    def get_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get sequence embeddings using last-token pooling.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Embeddings (batch, hidden_size)
        """
        self.encoder.reset()
        hidden = self.encoder(input_ids)
        return hidden[:, -1]

    def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get hidden states for all positions.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Hidden states (batch, seq_len, hidden_size)
        """
        self.encoder.reset()
        return self.encoder(input_ids)

    def save_pretrained(self, save_directory: str):
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlassLM":
        config = config or LookingGlassConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            model_path = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(model_path):
            state_dict = torch.load(model_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)

        return model
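
# Example (illustrative; the checkpoint directory is a placeholder): next-base
# probabilities from the LM head.
#
#     lm = LookingGlassLM.from_pretrained('./lookingglass-v1').eval()
#     tok = LookingGlassTokenizer()
#     ids = tok("GATTAC", return_tensors=True)['input_ids']    # (1, 7), xxbos + 6 bases
#     with torch.no_grad():
#         probs = torch.softmax(lm(ids)[0, -1], dim=-1)        # (8,) over VOCAB
#     # probs[VOCAB_TO_ID['A']] is the predicted probability that 'A' comes next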


def load_original_weights(model: Union[LookingGlass, LookingGlassLM], weights_path: str) -> None:
    """
    Load weights from an original fastai-trained LookingGlass checkpoint.

    Args:
        model: Model to load weights into
        weights_path: Path to LookingGlass.pth or LookingGlass_enc.pth
    """
    checkpoint = torch.load(weights_path, map_location='cpu')

    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    is_lm_model = isinstance(model, LookingGlassLM)

    new_state_dict = {}
    for k, v in state_dict.items():
        # Skip the duplicated (non-raw) weight-dropped matrices saved by fastai.
        if '.module.weight_hh_l0' in k:
            continue

        # '0.' prefix: the AWD-LSTM encoder in the original fastai language model.
        if k.startswith('0.'):
            new_k = k[2:]
            new_k = new_k.replace('encoder.', 'embed_tokens.')
            new_k = new_k.replace('encoder_dp.emb.', 'embed_tokens.')
            new_k = new_k.replace('rnns.', 'layers.')
            new_k = new_k.replace('hidden_dps.', 'hidden_dropout.')
            new_k = new_k.replace('input_dp.', 'input_dropout.')
            new_state_dict['encoder.' + new_k] = v

        # '1.' prefix: the linear decoder head (only relevant for the LM model).
        elif k.startswith('1.') and is_lm_model:
            new_k = k[2:]
            new_k = new_k.replace('output_dp.', 'output_dropout.')
            new_state_dict['lm_head.' + new_k] = v

        # No prefix: encoder-only checkpoints (e.g., LookingGlass_enc.pth).
        else:
            new_k = k.replace('encoder.', 'embed_tokens.')
            new_k = new_k.replace('encoder_dp.emb.', 'embed_tokens.')
            new_k = new_k.replace('rnns.', 'layers.')
            new_k = new_k.replace('hidden_dps.', 'hidden_dropout.')
            new_k = new_k.replace('input_dp.', 'input_dropout.')
            new_state_dict['encoder.' + new_k] = v

    model.load_state_dict(new_state_dict, strict=False)
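
# Usage sketch (file names follow the original LookingGlass release; paths are
# placeholders):
#
#     lm = LookingGlassLM()
#     load_original_weights(lm, 'LookingGlass.pth')         # full language model
#
#     enc = LookingGlass()
#     load_original_weights(enc, 'LookingGlass_enc.pth')    # encoder-only checkpoint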


def convert_checkpoint(input_path: str, output_dir: str) -> None:
    """Convert original checkpoint to new format."""
    config = LookingGlassConfig()
    model = LookingGlassLM(config)
    load_original_weights(model, input_path)
    model.save_pretrained(output_dir)

    tokenizer = LookingGlassTokenizer()
    tokenizer.save_pretrained(output_dir)
    print(f"Saved to {output_dir}")


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='LookingGlass DNA Language Model')
    parser.add_argument('--convert', type=str, help='Convert original weights')
    parser.add_argument('--output', type=str, default='./lookingglass-v1', help='Output directory')
    parser.add_argument('--test', action='store_true', help='Run tests')
    args = parser.parse_args()

    if args.convert:
        convert_checkpoint(args.convert, args.output)

    elif args.test:
        print("Testing LookingGlass...\n")

        tokenizer = LookingGlassTokenizer()
        print(f"Vocab: {tokenizer.vocab}")
        print(f"BOS token added: {tokenizer.add_bos_token}")
        print(f"EOS token added: {tokenizer.add_eos_token}")

        inputs = tokenizer("GATTACA", return_tensors=True)
        print(f"\nTokenized 'GATTACA': {inputs['input_ids']}")
        print(f"Decoded: {tokenizer.decode(inputs['input_ids'][0])}")

        config = LookingGlassConfig()
        print(f"\nConfig: bidirectional={config.bidirectional}")

        encoder = LookingGlass(config)
        print(f"\nLookingGlass params: {sum(p.numel() for p in encoder.parameters()):,}")

        encoder.eval()
        with torch.no_grad():
            emb = encoder.get_embeddings(inputs['input_ids'])
        print(f"Embeddings shape: {emb.shape}")

        lm = LookingGlassLM(config)
        print(f"\nLookingGlassLM params: {sum(p.numel() for p in lm.parameters()):,}")

        lm.eval()
        with torch.no_grad():
            logits = lm(inputs['input_ids'])
            emb = lm.get_embeddings(inputs['input_ids'])
        print(f"Logits shape: {logits.shape}")
        print(f"Embeddings shape: {emb.shape}")

        print("\nAll tests passed!")

    else:
        parser.print_help()