| """
|
| MolecularModule: Domain knowledge for chemistry and biology.
|
| Element embeddings, SMILES understanding, bond types, amino acids.
|
| """
|
|
|
| import re
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional, Tuple, List
|
|
|
|
|
class MolecularModule(nn.Module):
    """Domain knowledge for chemistry and biology.

    - All 118 elements as learned embeddings with a 12-value property table
      (atomic number, mass, electronegativity, valence electrons, ...)
    - SMILES string understanding for molecular structures
    - Bond type awareness (covalent, ionic, hydrogen, van der Waals)
    - Amino acid sequence understanding for biology/zoology
    - Molecular formula -> property reasoning
    """

    # Symbol -> atomic number for the elements resolvable in a formula.
    # Hoisted to a class constant: previously this dict literal was rebuilt
    # on every iteration of the element loop in encode_molecule().
    _ELEMENT_MAP = {
        'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5, 'C': 6, 'N': 7, 'O': 8,
        'F': 9, 'Ne': 10, 'Na': 11, 'Mg': 12, 'Al': 13, 'Si': 14, 'P': 15,
        'S': 16, 'Cl': 17, 'Ar': 18, 'K': 19, 'Ca': 20,
    }

    # "ElementSymbol + optional count" tokens of a formula, e.g. "C6H12O6".
    # Compiled once instead of re-parsed on every encode_molecule() call.
    _FORMULA_TOKEN_RE = re.compile(r'([A-Z][a-z]?)(\d*)')

    def __init__(self, d_model: int, num_elements: int = 118):
        """Initialize MolecularModule.

        Args:
            d_model: Model dimension. Must be divisible by 8
                (the attention head count below).
            num_elements: Number of chemical elements (default 118).
        """
        super().__init__()
        self.d_model = d_model
        self.num_elements = num_elements

        # Index 0 is reserved for "unknown element"; 1..num_elements map to
        # atomic numbers.
        self.element_embed = nn.Embedding(num_elements + 1, d_model)

        # Projects the 12 per-element scalar properties into model space.
        self.property_proj = nn.Linear(12, d_model)

        # Up to 8 bond types (covalent single/double/triple, ionic, hydrogen,
        # van der Waals, ...). Not referenced in forward() as written, but kept
        # as part of the module's learned state.
        self.bond_embed = nn.Embedding(8, d_model)

        # 20 standard amino acids plus extra codes (ambiguity/stop symbols).
        self.amino_acid_vocab = 25
        self.amino_embed = nn.Embedding(self.amino_acid_vocab, d_model)

        # Self-attention applied over tokens inside a detected SMILES span.
        self.mol_attention = nn.MultiheadAttention(
            d_model,
            num_heads=8,
            batch_first=True,
            dropout=0.1,
        )

        # Auxiliary head: predict the 12 element properties back from an
        # element embedding (see compute_property_loss).
        self.property_head = nn.Linear(d_model, 12)

        self._initialize_weights()
        self._init_element_properties()

    def _initialize_weights(self):
        """Initialize embedding/linear weights to N(0, 0.02) and biases to 0.

        MultiheadAttention is silently skipped by the hasattr checks (it
        stores parameters as in_proj_weight / out_proj rather than .weight /
        .bias) and therefore keeps PyTorch's default initialization.
        """
        for module in [self.element_embed, self.property_proj, self.bond_embed,
                       self.amino_embed, self.mol_attention, self.property_head]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def _init_element_properties(self):
        """Build the (num_elements + 1, 12) element property table.

        Row z holds the properties of the element with atomic number z; row 0
        ("unknown") and every element not listed below stay all-zero. The 12
        columns appear to be [Z, atomic mass, electronegativity, valence
        electrons, period, group, atomic radius, ionization energy, electron
        affinity, density, melting point, boiling point] -- units are not
        consistent across rows (H's melt/boil look like kelvin, C's do not);
        TODO confirm and normalize.
        """
        properties = torch.zeros(self.num_elements + 1, 12)

        # Approximate values for H, C, N, O only.
        element_data = {
            1: [1, 1.008, 2.20, 1, 1, 1, 25, 1312, 72.8, 0.0000899, 14, 20],
            6: [6, 12.011, 2.55, 4, 2, 14, 70, 1086, 153.9, 2.267, 3550, 4027],
            7: [7, 14.007, 3.04, 5, 2, 15, 65, 1402, 7.0, 0.0012506, 63, 77],
            8: [8, 15.999, 3.44, 6, 2, 16, 60, 1314, 141.0, 0.001429, 55, 90],
        }
        for z, props in element_data.items():
            properties[z] = torch.tensor(props)

        # Buffer, not Parameter: moves with .to(device) but is never trained.
        self.register_buffer("element_properties", properties)

    def detect_molecular_spans(
        self,
        text: str,
    ) -> List[Tuple[int, int, str]]:
        """Detect molecular/chemical spans in text (heuristic).

        Args:
            text: Input text string.

        Returns:
            List of (start_char, end_char, span_type) with span_type one of
            "formula", "smiles", "amino_acid". Spans may overlap; no
            de-duplication is attempted.
        """
        spans = []

        # Chemical formulas: runs of Element[count] units, e.g. "C6H12O6".
        # Keep multi-char matches, or single chars only when uppercase.
        formula_pattern = r'\b([A-Z][a-z]?\d*)+(?:[A-Z][a-z]?\d*)*\b'
        for match in re.finditer(formula_pattern, text):
            span = match.group()
            if len(span) > 1 or span.isupper():
                spans.append((match.start(), match.end(), "formula"))

        # SMILES candidates: whitespace-delimited tokens containing structural
        # punctuation. An advancing search cursor maps each token to its OWN
        # occurrence -- the previous plain text.find(word) returned the first
        # occurrence for every repeat, producing duplicated wrong spans.
        smiles_hints = ['=', '#', '@', '[', ']', '(', ')']
        search_from = 0
        for word in re.findall(r'\S+', text):
            pos = text.find(word, search_from)
            if pos >= 0:
                search_from = pos + len(word)
            if pos >= 0 and len(word) > 3 and any(hint in word for hint in smiles_hints):
                spans.append((pos, pos + len(word), "smiles"))

        # Amino acid sequences: 6+ consecutive one-letter residue codes.
        # Matching on text.upper() preserves character offsets of `text`.
        aa_pattern = r'\b([ACDEFGHIKLMNPQRSTVWY]{6,})\b'
        for match in re.finditer(aa_pattern, text.upper()):
            spans.append((match.start(), match.end(), "amino_acid"))

        return spans

    def encode_molecule(
        self,
        formula: str,
    ) -> torch.Tensor:
        """Encode a molecular formula into a single embedding.

        Each element's learned embedding is fused with its projected property
        vector, then all elements are averaged weighted by atom count.

        Args:
            formula: Chemical formula string (e.g., "C6H12O6").

        Returns:
            Molecule embedding of shape (d_model,); all-zero when the formula
            contains no parseable element symbols.
        """
        matches = self._FORMULA_TOKEN_RE.findall(formula)

        device = self.element_embed.weight.device
        embeddings = []
        weights = []

        for element, count_str in matches:
            # Unknown symbols map to index 0 ("unknown element").
            z = self._ELEMENT_MAP.get(element, 0)
            # A missing count means one atom (e.g., the "O" in "H2O").
            count = int(count_str) if count_str else 1

            elem_emb = self.element_embed(torch.tensor(z, device=device))

            # Fuse the learned embedding with projected physical properties.
            props = self.element_properties[z].unsqueeze(0)
            props_emb = self.property_proj(props).squeeze(0)

            embeddings.append(elem_emb + props_emb)
            weights.append(count)

        if not embeddings:
            return torch.zeros(self.d_model, device=device)

        # Count-weighted average over the per-element embeddings.
        embeddings = torch.stack(embeddings)
        weights = torch.tensor(weights, dtype=torch.float32, device=device)
        weights = weights / weights.sum()

        return (embeddings * weights.unsqueeze(-1)).sum(dim=0)

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        molecular_spans: Optional[List[List[Tuple[int, int, str]]]] = None,
    ) -> torch.Tensor:
        """Add molecular-domain signal to detected spans of the input.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            text: Optional original text strings, one per batch element.
            molecular_spans: Optional pre-computed token spans per batch
                element, as (start_tok, end_tok, span_type) tuples.

        Returns:
            Molecular-enhanced representation (batch, seq_len, d_model).
        """
        batch, seq_len, d_model = x.shape
        device = x.device

        # Derive token spans from raw text when none were supplied.
        if molecular_spans is None and text is not None:
            molecular_spans = []
            for b in range(batch):
                spans = self.detect_molecular_spans(text[b])

                token_spans = []
                for start_char, end_char, span_type in spans:
                    # Rough char->token mapping assuming ~4 characters per
                    # token -- TODO confirm against the actual tokenizer.
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_spans.append((start_tok, end_tok, span_type))
                molecular_spans.append(token_spans)

        output = x.clone()

        if molecular_spans:
            for b in range(batch):
                spans_b = molecular_spans[b] if b < len(molecular_spans) else []

                for start_tok, end_tok, span_type in spans_b:
                    if end_tok <= start_tok:
                        continue

                    span_slice = x[b, start_tok:end_tok, :]

                    if span_type == "formula":
                        if text:
                            # Approximate inverse of the char->token mapping
                            # above; may over-slice around the formula.
                            formula = text[b][start_tok*4:end_tok*4]
                            mol_emb = self.encode_molecule(formula)
                        else:
                            # NOTE(review): random placeholder when no text is
                            # available -- nondeterministic; confirm intended.
                            mol_emb = torch.randn(d_model, device=device)

                        # Injected at the first token of the span only.
                        output[b, start_tok, :] += mol_emb

                    elif span_type == "amino_acid":
                        # NOTE(review): placeholder -- random residue ids are
                        # embedded instead of the actual sequence characters.
                        seq_len_span = end_tok - start_tok
                        aa_ids = torch.randint(0, 20, (seq_len_span,), device=device)
                        aa_emb = self.amino_embed(aa_ids)
                        output[b, start_tok:end_tok, :] += aa_emb

                    elif span_type == "smiles":
                        # Self-attention over the span; needs >= 2 tokens.
                        seq_len_span = end_tok - start_tok
                        if seq_len_span > 1:
                            attn_out, _ = self.mol_attention(
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                                span_slice.unsqueeze(0),
                            )
                            output[b, start_tok:end_tok, :] += attn_out.squeeze(0)

        return output

    def compute_property_loss(
        self,
        x: torch.Tensor,
        element_ids: torch.Tensor,
        target_properties: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the auxiliary MSE loss for element property prediction.

        Args:
            x: Input tensor (batch, seq_len, d_model). Currently unused;
                kept for interface stability.
            element_ids: Element IDs (batch, seq_len), values in
                [0, num_elements].
            target_properties: Target property values (batch, seq_len, 12).

        Returns:
            Scalar MSE loss between predicted and target properties.
        """
        elem_emb = self.element_embed(element_ids)
        pred_props = self.property_head(elem_emb)
        loss = F.mse_loss(pred_props, target_properties)
        return loss
|
|
|
|
|
def test_molecular_module():
    """Smoke-test MolecularModule: forward pass preserves input shape."""
    d_model, batch_size, seq_len = 512, 2, 128

    module = MolecularModule(d_model)

    x = torch.randn(batch_size, seq_len, d_model)
    sample_texts = [
        "Water is H2O. The DNA sequence is ACGTACGTACGT.",
        "Proteins are made of amino acids like ACDEFGH. Benzene is C6H6."
    ]

    output = module(x, text=sample_texts)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape
    print("MolecularModule test passed!")
|
|
|
|
|
| if __name__ == "__main__":
|
| test_molecular_module()
|
|
|