import math
from typing import Literal, Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_dna2vec import DNAEncoderConfig


class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size: int = 4,
        embedding_dim: int = 384,
        dim_feedforward: int = 1536,
        num_heads: int = 12,
        num_layers: int = 6,
        dropout: float = 0.1,
        activation: Literal["relu", "gelu"] = "gelu",
        pos_embedding: Optional[str] = "SinusoidalPositionalEncoding",
        max_position_embeddings: int = 1024,
    ):
        """
        Transformer encoder over DNA token ids.

        Default hyperparameters follow sentence-transformers/all-MiniLM-L6-v2:
        https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.emb_dropout = nn.Dropout(p=dropout)

        if pos_embedding == "SinusoidalPositionalEncoding":
            # Fixed (non-learned) sinusoidal position encodings, as in
            # "Attention Is All You Need": even channels use sin, odd channels cos.
            position = torch.arange(max_position_embeddings).unsqueeze(1)
            div_term = torch.exp(
                torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
            )
            pe = torch.zeros(max_position_embeddings, 1, embedding_dim)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            # Stored as a buffer of shape (1, max_position_embeddings, embedding_dim):
            # it moves with the module across devices but is not a trainable parameter.
            pe = pe.squeeze(1).unsqueeze(0)
            self.register_buffer("positional_embedding", pe)
        else:
            raise ValueError(f"Positional embedding {pos_embedding} not found")

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

        # Pre-norm (norm_first=True) Transformer encoder operating on
        # (batch, seq_len, embedding_dim) tensors (batch_first=True).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.trf_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer, num_layers=num_layers
        )

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        input_ids: (batch, seq_len) token ids.
        attention_mask: optional (batch, seq_len) mask, 1 = real token, 0 = padding.
        Returns (batch, seq_len, embedding_dim) token-level representations.
        """
        # Token embeddings plus positional encodings (truncated to seq_len), then dropout.
        emb = self.emb_dropout(
            self.embedding(input_ids) + self.positional_embedding[:, : input_ids.size(1), :]
        )

        # nn.TransformerEncoder expects a key-padding mask that is True at padded
        # positions, so invert the HF-style attention mask (1 = keep, 0 = pad).
        attn = None
        if attention_mask is not None:
            attn = attention_mask == 0
        out = self.trf_encoder(emb, src_key_padding_mask=attn)
        return out


class DNAEncoder(PreTrainedModel):
    config_class = DNAEncoderConfig

    def __init__(self, config: DNAEncoderConfig):
        super().__init__(config)
        self.config = config
        self.encoder = Encoder(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            dim_feedforward=config.dim_feedforward,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            dropout=config.dropout,
            activation=config.activation,
            max_position_embeddings=config.max_position_embeddings,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.encoder(input_ids, attention_mask)
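

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the model code): illustrates the
# expected input/output shapes of the raw Encoder defined above. Exercising
# DNAEncoder itself would additionally require building a DNAEncoderConfig,
# whose constructor arguments live in configuration_dna2vec and are not shown
# here, so only the plain nn.Module is demonstrated.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = Encoder()  # defaults: vocab_size=4, embedding_dim=384, 6 layers

    # Two sequences of 100 tokens drawn from the 4-letter DNA vocabulary;
    # the second sequence is treated as padded from position 50 onwards.
    input_ids = torch.randint(0, 4, (2, 100))
    attention_mask = torch.ones(2, 100, dtype=torch.long)
    attention_mask[1, 50:] = 0

    hidden_states = encoder(input_ids, attention_mask)
    print(hidden_states.shape)  # expected: torch.Size([2, 100, 384])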