# Vortex-13b-V1 / data / domain_classifier.py
# Source: "Upload Vortex model" by Zandy-Wandy (commit 5c43f61, verified)
"""
DomainClassifier: Classifies documents into 7 science domains.
Uses a simple linear classifier on top of text features.
"""
import re
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
class DomainClassifier(nn.Module):
    """
    Classifies documents into 7 science domains:
        0: Physics
        1: Mathematics
        2: Chemistry
        3: Biology
        4: Earth Science
        5: Space Science
        6: Zoology

    A single linear layer maps mean-pooled hidden states to domain logits;
    ``classify_text`` is a keyword-based fallback that needs no trained weights.
    """
    # Domain keywords for rule-based fallback classification.
    DOMAIN_KEYWORDS = {
        0: ['physics', 'quantum', 'relativity', 'mechanics', 'thermodynamics', 'electromagnetism'],
        1: ['mathematics', 'algebra', 'calculus', 'geometry', 'topology', 'proof', 'theorem'],
        2: ['chemistry', 'molecular', 'reaction', 'compound', 'element', 'organic'],
        3: ['biology', 'cell', 'gene', 'protein', 'organism', 'evolution'],
        4: ['earth', 'geology', 'climate', 'ocean', 'atmosphere', 'meteorology'],
        5: ['space', 'astronomy', 'planet', 'star', 'galaxy', 'cosmology'],
        6: ['zoology', 'animal', 'species', 'vertebrate', 'invertebrate', 'ecology'],
    }

    def __init__(self, d_model: int, num_domains: int = 7):
        """
        Initialize domain classifier.

        Args:
            d_model: Input embedding dimension
            num_domains: Number of domains (default 7)
        """
        super().__init__()
        self.d_model = d_model
        self.num_domains = num_domains
        # Simple linear classifier head on pooled embeddings.
        self.classifier = nn.Linear(d_model, num_domains)
        # Small-std init keeps initial logits near zero (roughly uniform prior).
        nn.init.normal_(self.classifier.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Classify domain from hidden states.

        Args:
            hidden_states: (batch, seq_len, d_model)
            attention_mask: (batch, seq_len); 1 for real tokens, 0 for padding

        Returns:
            Domain logits (batch, num_domains)
        """
        # Mean pooling over sequence (masked).
        if attention_mask is not None:
            # Cast the (typically integer) mask to the activation dtype so the
            # multiply and divide stay in the model's precision — avoids dtype
            # promotion surprises under fp16/bf16.
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
            summed = (hidden_states * mask).sum(dim=1)
            counts = mask.sum(dim=1)
            # clamp guards against division by zero for fully-masked rows.
            pooled = summed / counts.clamp(min=1)
        else:
            pooled = hidden_states.mean(dim=1)
        # Classify
        return self.classifier(pooled)

    def classify_text(
        self,
        text: str,
    ) -> Tuple[int, float]:
        """
        Rule-based fallback classification from raw text.

        Keywords are matched on word boundaries (fix: plain substring matching
        produced false positives such as 'star' in 'start' or 'cell' in
        'excellent').

        Args:
            text: Input text string

        Returns:
            (domain_id, confidence) where confidence is this domain's share of
            all keyword hits across domains; (0, 0.0) when nothing matches.
        """
        text_lower = text.lower()
        # Count whole-word keyword matches per domain (dict preserves key
        # insertion order 0..6, so list index == domain id).
        scores = [
            sum(1 for kw in keywords
                if re.search(r'\b' + re.escape(kw) + r'\b', text_lower))
            for keywords in self.DOMAIN_KEYWORDS.values()
        ]
        total = sum(scores)
        if total == 0:
            return 0, 0.0  # Unknown -> default to physics
        best_domain = scores.index(max(scores))
        confidence = max(scores) / total
        return best_domain, confidence

    def compute_loss(
        self,
        logits: torch.Tensor,
        domain_labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute classification loss.

        Args:
            logits: (batch, num_domains)
            domain_labels: (batch,) with integer domain IDs

        Returns:
            Scalar cross-entropy loss
        """
        return nn.functional.cross_entropy(logits, domain_labels)
def test_domain_classifier():
    """Smoke-test DomainClassifier: forward shape, text fallback, loss."""
    d_model, batch_size, seq_len = 512, 4, 128
    clf = DomainClassifier(d_model)

    # Forward pass on random activations must yield (batch, 7) logits.
    hidden = torch.randn(batch_size, seq_len, d_model)
    logits = clf(hidden)
    print(f"Logits shape: {logits.shape}")
    assert logits.shape == (batch_size, 7)

    # Rule-based text classification on a few sample sentences.
    samples = (
        "The quantum mechanics of particles...",
        "Solving differential equations...",
        "Chemical reactions produce compounds...",
        "Cells contain DNA and proteins...",
    )
    for text in samples:
        domain, conf = clf.classify_text(text)
        print(f"Text: {text[:30]}... -> Domain {domain}, conf {conf:.2f}")

    # Cross-entropy loss against dummy labels.
    labels = torch.tensor([0, 1, 2, 3])
    loss = clf.compute_loss(logits, labels)
    print(f"Loss: {loss.item():.4f}")
    print("DomainClassifier test passed!")
# Run the smoke test when executed directly as a script.
if __name__ == "__main__":
    test_domain_classifier()