| | """
|
| | DomainClassifier: Classifies documents into 7 science domains.
|
| | Uses a simple linear classifier on top of text features.
|
| | """
|
| |
|
| | import re
|
| | from typing import List, Tuple, Optional
|
| | import torch
|
| | import torch.nn as nn
|
| |
|
| |
|
class DomainClassifier(nn.Module):
    """Classifies documents into 7 science domains.

    Domain IDs:
        0: Physics
        1: Mathematics
        2: Chemistry
        3: Biology
        4: Earth Science
        5: Space Science
        6: Zoology
    """

    # Keyword lists backing the rule-based fallback in `classify_text`.
    DOMAIN_KEYWORDS = {
        0: ['physics', 'quantum', 'relativity', 'mechanics', 'thermodynamics', 'electromagnetism'],
        1: ['mathematics', 'algebra', 'calculus', 'geometry', 'topology', 'proof', 'theorem'],
        2: ['chemistry', 'molecular', 'reaction', 'compound', 'element', 'organic'],
        3: ['biology', 'cell', 'gene', 'protein', 'organism', 'evolution'],
        4: ['earth', 'geology', 'climate', 'ocean', 'atmosphere', 'meteorology'],
        5: ['space', 'astronomy', 'planet', 'star', 'galaxy', 'cosmology'],
        6: ['zoology', 'animal', 'species', 'vertebrate', 'invertebrate', 'ecology'],
    }

    def __init__(self, d_model: int, num_domains: int = 7):
        """
        Initialize domain classifier.

        Args:
            d_model: Input embedding dimension.
            num_domains: Number of domains (default 7, matching DOMAIN_KEYWORDS).
        """
        super().__init__()
        self.d_model = d_model
        self.num_domains = num_domains

        # Single linear head over mean-pooled hidden states.
        self.classifier = nn.Linear(d_model, num_domains)

        # Small-std init keeps initial logits near zero (≈ uniform prediction).
        nn.init.normal_(self.classifier.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Classify domain from hidden states.

        Args:
            hidden_states: (batch, seq_len, d_model)
            attention_mask: (batch, seq_len); 1 for real tokens, 0 for padding.

        Returns:
            Domain logits (batch, num_domains)
        """
        if attention_mask is not None:
            # Mean-pool over non-padding positions only. Cast the mask to the
            # hidden dtype so bool/int masks multiply and divide cleanly.
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            summed = (hidden_states * mask).sum(dim=1)
            counts = mask.sum(dim=1)
            # clamp guards against division by zero for fully-masked rows
            pooled = summed / counts.clamp(min=1)
        else:
            pooled = hidden_states.mean(dim=1)

        logits = self.classifier(pooled)
        return logits

    def classify_text(
        self,
        text: str,
    ) -> Tuple[int, float]:
        """
        Rule-based fallback classification from raw text.

        Keywords are matched on word boundaries (with an optional plural
        's'), so e.g. 'star' no longer fires inside 'start' and 'cell'
        no longer fires inside 'excellent', while 'cells'/'proteins'
        still count.

        Args:
            text: Input text string

        Returns:
            (domain_id, confidence) — confidence is the winning domain's
            share of all keyword hits; (0, 0.0) when nothing matches.
        """
        text_lower = text.lower()

        scores = []
        for domain_id, keywords in self.DOMAIN_KEYWORDS.items():
            score = sum(
                1 for kw in keywords
                if re.search(r'\b' + re.escape(kw) + r's?\b', text_lower)
            )
            scores.append(score)

        total = sum(scores)
        if total == 0:
            # No keyword matched anywhere: default to domain 0, zero confidence.
            return 0, 0.0

        best = max(scores)
        # Ties resolve to the lowest domain ID (list.index returns first hit).
        return scores.index(best), best / total

    def compute_loss(
        self,
        logits: torch.Tensor,
        domain_labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute classification loss.

        Args:
            logits: (batch, num_domains)
            domain_labels: (batch,) with integer domain IDs

        Returns:
            Cross-entropy loss (scalar tensor)
        """
        return nn.functional.cross_entropy(logits, domain_labels)
|
| |
|
| |
|
def test_domain_classifier():
    """Smoke-test DomainClassifier: forward shape, text fallback, and loss."""
    dim, n_batch, n_tokens = 512, 4, 128

    model = DomainClassifier(dim)

    # Forward pass on random embeddings must yield one logit per domain.
    features = torch.randn(n_batch, n_tokens, dim)
    logits = model(features)
    print(f"Logits shape: {logits.shape}")
    assert logits.shape == (n_batch, 7)

    # Rule-based fallback on a few representative snippets.
    samples = (
        "The quantum mechanics of particles...",
        "Solving differential equations...",
        "Chemical reactions produce compounds...",
        "Cells contain DNA and proteins...",
    )
    for sample in samples:
        domain, conf = model.classify_text(sample)
        print(f"Text: {sample[:30]}... -> Domain {domain}, conf {conf:.2f}")

    # Cross-entropy loss against dummy labels.
    targets = torch.tensor([0, 1, 2, 3])
    loss = model.compute_loss(logits, targets)
    print(f"Loss: {loss.item():.4f}")

    print("DomainClassifier test passed!")
|
| |
|
| |
|
# Run the smoke test when this module is executed directly.
if __name__ == "__main__":
    test_domain_classifier()
|
| |
|