| """
|
| NumericalReasoningModule: Handles scientific numerical reasoning.
|
| Digit-level number encoding, scientific notation, unit awareness.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import re
|
| from typing import Optional, Tuple, List
|
|
|
|
|
class NumericalReasoningModule(nn.Module):
    """
    Handles scientific numerical reasoning.
    - Digit-level number encoding (each digit gets position-aware embedding)
    - Scientific notation understanding (6.02 × 10²³)
    - Unit awareness (meters, joules, moles, kelvin)
    - Order of magnitude reasoning
    - Significant figures tracking
    """

    def __init__(
        self,
        d_model: int,
        max_digits: int = 20,
        num_units: int = 256,
    ):
        """
        Initialize NumericalReasoningModule.

        Args:
            d_model: Model dimension
            max_digits: Maximum number of digits to encode
            num_units: Number of unit types to embed
        """
        super().__init__()
        self.d_model = d_model
        self.max_digits = max_digits

        # One 64-d embedding per decimal digit 0-9.
        self.digit_embed = nn.Embedding(10, 64)

        # Position of each digit within the fixed-length digit sequence.
        self.position_embed = nn.Embedding(max_digits, 64)

        # Projects concatenated [digit ; position] features (64+64) into model space.
        self.number_proj = nn.Linear(128, d_model)

        # Embedding table for unit symbols, indexed via a stable checksum
        # (see _unit_id) rather than the process-randomized builtin hash().
        self.unit_embed = nn.Embedding(num_units, d_model)

        # Fusion layer for scientific-notation features.
        # NOTE(review): not currently used in forward(); kept for interface
        # compatibility with existing checkpoints/callers.
        self.sci_notation = nn.Linear(d_model * 2, d_model)

        # Exponents are clamped to [-10, 10] -> 21 magnitude buckets.
        self.magnitude_embed = nn.Embedding(21, d_model)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all sub-module weights with N(0, 0.02) and zero biases."""
        for module in [self.digit_embed, self.position_embed, self.number_proj,
                       self.unit_embed, self.sci_notation, self.magnitude_embed]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def _unit_id(self, unit_str: str) -> int:
        """
        Map a unit string to a stable embedding index.

        Uses a deterministic positional checksum instead of builtin hash(),
        which is randomized per process (PYTHONHASHSEED) and would make the
        unit -> embedding mapping change between runs.
        """
        checksum = sum((i + 1) * b for i, b in enumerate(unit_str.encode('utf-8')))
        return checksum % self.unit_embed.num_embeddings

    def encode_number(
        self,
        number_str: str,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Encode a number string using digit-level encoding.

        Args:
            number_str: String representation of number (e.g., "123.45e-6")
            device: Torch device

        Returns:
            Number embedding (d_model,)
        """
        # Keep only the decimal digits; sign, '.' and exponent markers are
        # handled separately (the exponent via magnitude_embed in forward()).
        digits = [int(d) for d in re.findall(r'\d', number_str)]
        if not digits:
            digits = [0]

        # Truncate or pad to a fixed length. NOTE: padding reuses digit id 0,
        # so trailing real zeros are indistinguishable from padding.
        if len(digits) > self.max_digits:
            digits = digits[:self.max_digits]
        else:
            digits = digits + [0] * (self.max_digits - len(digits))

        digits_tensor = torch.tensor(digits, device=device)
        positions = torch.arange(self.max_digits, device=device)

        digit_emb = self.digit_embed(digits_tensor)   # (max_digits, 64)
        pos_emb = self.position_embed(positions)      # (max_digits, 64)

        combined = torch.cat([digit_emb, pos_emb], dim=-1)  # (max_digits, 128)
        number_emb = self.number_proj(combined)             # (max_digits, d_model)

        # Mean-pool over digit positions to a single vector.
        return number_emb.mean(dim=0)

    def detect_numbers(
        self,
        text: str,
    ) -> List[Tuple[str, int, int, Optional[str]]]:
        """
        Detect numbers in text with optional units and scientific notation.

        Handles plain decimals ("123.45"), e-notation ("1.2e-6") and the
        typeset form ("2.998×10^8"), each optionally followed by a unit token.

        Returns:
            List of (number_str, start_char, end_char, unit_str)
        """
        pattern = r'(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?(?:×10\^?[+-]?\d+)?)(?:\s*([a-zA-Z°%]+))?'

        matches = []
        for match in re.finditer(pattern, text):
            number_str = match.group(1)
            unit_str = match.group(2) if match.group(2) else None
            matches.append((number_str, match.start(), match.end(), unit_str))

        return matches

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        number_positions: Optional[List[List[Tuple]]] = None,
    ) -> torch.Tensor:
        """
        Forward pass through numerical reasoning module.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            text: Optional original text strings
            number_positions: Optional per-batch list of
                (start_token, end_token, number_str) or
                (start_token, end_token, number_str, unit_str) entries.
                Both forms are accepted; the 3-tuple form implies no unit.

        Returns:
            Numerical-enhanced representation (batch, seq_len, d_model)
        """
        batch, seq_len, d_model = x.shape
        device = x.device

        # If no explicit positions were supplied, detect numbers in the raw
        # text and map character spans to token spans with a crude
        # ~4-chars-per-token heuristic.
        # TODO(review): replace with real tokenizer character offsets.
        if number_positions is None and text is not None:
            number_positions = []
            for b in range(batch):
                token_nums = []
                for num_str, start_char, end_char, unit_str in self.detect_numbers(text[b]):
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_nums.append((start_tok, end_tok, num_str, unit_str))
                number_positions.append(token_nums)

        output = x.clone()

        if number_positions:
            for b in range(batch):
                nums_b = number_positions[b] if b < len(number_positions) else []

                for entry in nums_b:
                    # Accept both the documented 3-tuple form and the
                    # internal 4-tuple form carrying an optional unit.
                    if len(entry) == 4:
                        start_tok, end_tok, num_str, unit_str = entry
                    else:
                        start_tok, end_tok, num_str = entry
                        unit_str = None

                    if end_tok <= start_tok or start_tok >= seq_len:
                        continue

                    # Clamp the injection position into the sequence.
                    start_tok = min(start_tok, seq_len - 1)

                    number_emb = self.encode_number(num_str, device)

                    # Add a unit embedding when a unit token accompanied the number.
                    if unit_str:
                        unit_id = self._unit_id(unit_str)
                        unit_emb = self.unit_embed(torch.tensor(unit_id, device=device))
                        number_emb = number_emb + unit_emb

                    # Scientific notation: add an order-of-magnitude embedding.
                    if 'e' in num_str.lower() or '×10' in num_str:
                        exp_match = re.search(r'[eE]([+-]?\d+)|×10\^?([+-]?\d+)', num_str)
                        if exp_match:
                            exp = int(exp_match.group(1) or exp_match.group(2))
                            exp = max(-10, min(10, exp))  # clamp to embedding range
                            magnitude_emb = self.magnitude_embed(torch.tensor(exp + 10, device=device))
                            number_emb = number_emb + magnitude_emb

                    # Inject the number representation at its first token.
                    output[b, start_tok, :] += number_emb

        return output

    def compute_numerical_loss(
        self,
        x: torch.Tensor,
        number_mask: torch.Tensor,
        target_values: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute auxiliary loss for numerical reasoning.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            number_mask: Mask for number tokens (batch, seq_len); currently unused
            target_values: Target numeric values (batch, seq_len); currently unused

        Returns:
            Scalar zero tensor on x's device (placeholder — value regression
            is not implemented yet).
        """
        # Return a real tensor (not the Python float 0.0, which the previous
        # version returned) so callers can safely add it to other losses and
        # call .backward() without type errors.
        return torch.zeros((), device=x.device)
|
|
|
|
|
def test_numerical_module():
    """Smoke-test NumericalReasoningModule on a small random batch."""
    module = NumericalReasoningModule(512)

    sample_text = [
        "The speed of light is 2.998×10^8 m/s and Planck's constant is 6.626×10^-34 J·s.",
        "Calculate: 123.45 + 67.89 = ? The answer is 191.34."
    ]
    inputs = torch.randn(2, 128, 512)

    enhanced = module(inputs, text=sample_text)

    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {enhanced.shape}")
    assert enhanced.shape == inputs.shape

    print("NumericalReasoningModule test passed!")
|
|
|
|
|
# Run the smoke test when this file is executed directly (not on import).
if __name__ == "__main__":

    test_numerical_module()
|
|
|