# Vortex-13b-V1/models/science_modules/numerical_module.py
"""
NumericalReasoningModule: Handles scientific numerical reasoning.
Digit-level number encoding, scientific notation, unit awareness.
"""
import re
import zlib
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class NumericalReasoningModule(nn.Module):
    """Scientific numerical reasoning over token representations.

    Enhances hidden states at positions where numbers occur in the text:

    - Digit-level number encoding (each digit gets a position-aware embedding)
    - Scientific-notation understanding (e.g. ``1.23e-4``, ``6.02×10^23``)
    - Unit awareness (hashed unit-embedding table; meters, joules, moles, ...)
    - Order-of-magnitude embedding for exponents clamped to [-10, 10]
    """

    def __init__(
        self,
        d_model: int,
        max_digits: int = 20,
        num_units: int = 256,
    ):
        """Initialize NumericalReasoningModule.

        Args:
            d_model: Model (hidden) dimension.
            max_digits: Maximum number of digits encoded per number; longer
                numbers are truncated, shorter ones zero-padded.
            num_units: Size of the hashed unit-embedding vocabulary.
        """
        super().__init__()
        self.d_model = d_model
        self.max_digits = max_digits
        # Digit embeddings (0-9).
        self.digit_embed = nn.Embedding(10, 64)
        # Digit-position embeddings (ones, tens, hundreds, ...).
        self.position_embed = nn.Embedding(max_digits, 64)
        # Projects concatenated digit+position embedding (64+64) to d_model.
        self.number_proj = nn.Linear(128, d_model)
        # Unit embedding (SI units + common scientific units, hashed IDs).
        self.unit_embed = nn.Embedding(num_units, d_model)
        # Scientific-notation combiner. NOTE(review): declared but not used
        # by forward() in this file — kept for interface compatibility.
        self.sci_notation = nn.Linear(d_model * 2, d_model)
        # Magnitude embedding for powers of ten in [-10, +10] → 21 slots.
        self.magnitude_embed = nn.Embedding(21, d_model)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Normal(0, 0.02) init for all weights; zeros for biases."""
        for module in [self.digit_embed, self.position_embed, self.number_proj,
                       self.unit_embed, self.sci_notation, self.magnitude_embed]:
            if hasattr(module, 'weight'):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode_number(
        self,
        number_str: str,
        device: torch.device,
    ) -> torch.Tensor:
        """Encode a number string with digit-level embeddings.

        Args:
            number_str: String representation of a number
                (e.g. ``"123.45e-6"``). Decimal point, sign, and exponent
                markers are ignored; only the digit characters are encoded.
            device: Torch device for the produced tensors.

        Returns:
            Number embedding of shape ``(d_model,)`` (mean-pooled over digit
            positions).
        """
        # Keep only digit characters; sign/point/exponent are handled
        # elsewhere (magnitude embedding in forward()).
        digits = [int(d) for d in re.findall(r'\d', number_str)]
        if not digits:
            # Degenerate input (no digits at all): encode as a single zero.
            digits = [0]
        # Truncate or right-pad with zeros to a fixed length.
        if len(digits) > self.max_digits:
            digits = digits[:self.max_digits]
        else:
            digits = digits + [0] * (self.max_digits - len(digits))
        digits_tensor = torch.tensor(digits, device=device)          # (max_digits,)
        positions = torch.arange(self.max_digits, device=device)     # (max_digits,)
        digit_emb = self.digit_embed(digits_tensor)                  # (max_digits, 64)
        pos_emb = self.position_embed(positions)                     # (max_digits, 64)
        combined = torch.cat([digit_emb, pos_emb], dim=-1)           # (max_digits, 128)
        number_emb = self.number_proj(combined)                      # (max_digits, d_model)
        # Mean pool over digit positions → one vector per number.
        return number_emb.mean(dim=0)                                # (d_model,)

    def detect_numbers(
        self,
        text: str,
    ) -> List[Tuple[str, int, int, Optional[str]]]:
        """Detect numbers in text with optional units and scientific notation.

        Matches forms like ``123``, ``123.45``, ``1.23e-4``, ``6.02×10^23``,
        and an optional trailing unit token (``100 m``, ``5.0 J``).

        Returns:
            List of ``(number_str, start_char, end_char, unit_str)`` tuples;
            ``unit_str`` is ``None`` when no unit follows the number.
        """
        pattern = r'(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?(?:×10\^?[+-]?\d+)?)(?:\s*([a-zA-Z°%]+))?'
        matches = []
        for match in re.finditer(pattern, text):
            number_str = match.group(1)
            unit_str = match.group(2) if match.group(2) else None
            matches.append((number_str, match.start(), match.end(), unit_str))
        return matches

    def forward(
        self,
        x: torch.Tensor,
        text: Optional[List[str]] = None,
        number_positions: Optional[List[List[Tuple[int, int, str, Optional[str]]]]] = None,
    ) -> torch.Tensor:
        """Add number/unit/magnitude embeddings at detected number spans.

        Args:
            x: Input tensor of shape ``(batch, seq_len, d_model)``.
            text: Optional original text strings, one per batch item; used to
                detect numbers when ``number_positions`` is not given.
            number_positions: Optional per-batch lists of
                ``(start_token, end_token, number_str, unit_str)`` tuples.

        Returns:
            Numerically-enhanced representation, same shape as ``x``.
        """
        batch, seq_len, d_model = x.shape
        device = x.device
        # Detect numbers from raw text when explicit positions are absent.
        if number_positions is None and text is not None:
            number_positions = []
            for b in range(batch):
                numbers = self.detect_numbers(text[b])
                token_nums = []
                for num_str, start_char, end_char, unit_str in numbers:
                    # Crude char→token mapping (~4 chars per token);
                    # a real tokenizer alignment would be more accurate.
                    start_tok = max(0, start_char // 4)
                    end_tok = min(seq_len, end_char // 4 + 1)
                    token_nums.append((start_tok, end_tok, num_str, unit_str))
                number_positions.append(token_nums)
        output = x.clone()
        if number_positions:
            for b in range(batch):
                nums_b = number_positions[b] if b < len(number_positions) else []
                for start_tok, end_tok, num_str, unit_str in nums_b:
                    if end_tok <= start_tok or start_tok >= seq_len:
                        continue
                    # Clamp to sequence bounds.
                    start_tok = min(start_tok, seq_len - 1)
                    end_tok = min(end_tok, seq_len)
                    number_emb = self.encode_number(num_str, device)  # (d_model,)
                    if unit_str:
                        # Hash-based unit ID (a real system would use a unit
                        # vocabulary). crc32 is deterministic across processes,
                        # unlike builtin hash() which is randomized per run.
                        unit_id = zlib.crc32(unit_str.encode("utf-8")) % self.unit_embed.num_embeddings
                        unit_emb = self.unit_embed(torch.tensor(unit_id, device=device))
                        number_emb = number_emb + unit_emb
                    # Magnitude embedding for scientific notation.
                    if 'e' in num_str.lower() or '×10' in num_str:
                        exp_match = re.search(r'[eE]([+-]?\d+)|×10\^?([+-]?\d+)', num_str)
                        if exp_match:
                            exp = int(exp_match.group(1) or exp_match.group(2))
                            exp = max(-10, min(10, exp))  # clamp to table range
                            magnitude_emb = self.magnitude_embed(torch.tensor(exp + 10, device=device))
                            number_emb = number_emb + magnitude_emb
                    # Inject at the first token of the number span only.
                    output[b, start_tok, :] += number_emb
        return output

    def compute_numerical_loss(
        self,
        x: torch.Tensor,
        number_mask: torch.Tensor,
        target_values: torch.Tensor,
    ) -> torch.Tensor:
        """Compute auxiliary loss for numerical reasoning (placeholder).

        Args:
            x: Input tensor ``(batch, seq_len, d_model)``.
            number_mask: Mask for number tokens ``(batch, seq_len)``.
            target_values: Target numeric values ``(batch, seq_len)``.

        Returns:
            A scalar zero tensor. A real implementation would add a value
            prediction head; returning a tensor (not a Python float, as the
            annotation promises) keeps ``total + aux_loss`` and ``.backward()``
            working for callers.
        """
        return torch.zeros((), device=x.device)
def test_numerical_module():
    """Smoke-test NumericalReasoningModule: output shape must equal input shape."""
    hidden_dim, n_batch, n_tokens = 512, 2, 128
    module = NumericalReasoningModule(hidden_dim)
    sample_text = [
        "The speed of light is 2.998×10^8 m/s and Planck's constant is 6.626×10^-34 J·s.",
        "Calculate: 123.45 + 67.89 = ? The answer is 191.34."
    ]
    x = torch.randn(n_batch, n_tokens, hidden_dim)
    enhanced = module(x, text=sample_text)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {enhanced.shape}")
    assert enhanced.shape == x.shape
    print("NumericalReasoningModule test passed!")


if __name__ == "__main__":
    test_numerical_module()