# Source file: AI_Text_Authenticator/processors/domain_classifier.py
# Repository page residue (not part of the module): author satyaki-mitra,
# commit "wip" (d92e2aa), file size 13.8 kB
# DEPENDENCIES
from typing import Dict
from typing import List
from typing import Tuple
from loguru import logger
from typing import Optional
from dataclasses import dataclass
from config.threshold_config import Domain
from models.model_manager import get_model_manager
from config.threshold_config import interpolate_thresholds
from config.threshold_config import get_threshold_for_domain
@dataclass
class DomainPrediction:
    """
    Outcome of a single domain classification run
    """
    # Highest-scoring domain
    primary_domain: Domain
    # Runner-up domain, or None when no runner-up was kept
    secondary_domain: Optional[Domain]
    # Confidence attached to the prediction by the classifier
    confidence: float
    # Score per domain, keyed by the domain's string value
    domain_scores: Dict[str, float]
class DomainClassifier:
    """
    Classifies text into domains using zero-shot classification

    A primary model is always tried first. An optional fallback model is consulted when the
    primary result is below the requested confidence, or when the primary model raises.
    """
    # Candidate label phrases per domain.
    # NOTE: only the FIRST phrase of each list is currently sent to the classifier
    # (see _classify_with_model); the remaining phrases are not used by this class.
    DOMAIN_LABELS = {Domain.ACADEMIC      : ["academic paper", "research article", "scientific paper", "scholarly writing", "thesis", "dissertation", "academic research"],
                     Domain.CREATIVE      : ["creative writing", "fiction", "story", "narrative", "poetry", "literary work", "imaginative writing"],
                     Domain.AI_ML         : ["artificial intelligence", "machine learning", "neural networks", "data science", "AI research", "deep learning"],
                     Domain.SOFTWARE_DEV  : ["software development", "programming", "coding", "software engineering", "web development", "application development"],
                     Domain.TECHNICAL_DOC : ["technical documentation", "user manual", "API documentation", "technical guide", "system documentation"],
                     Domain.ENGINEERING   : ["engineering document", "technical design", "engineering analysis", "mechanical engineering", "electrical engineering"],
                     Domain.SCIENCE       : ["scientific research", "physics", "chemistry", "biology", "scientific study", "experimental results"],
                     Domain.BUSINESS      : ["business document", "corporate communication", "business report", "professional writing", "executive summary"],
                     Domain.JOURNALISM    : ["news article", "journalism", "press release", "news report", "media content", "reporting"],
                     Domain.SOCIAL_MEDIA  : ["social media post", "casual writing", "online content", "informal text", "social media content"],
                     Domain.BLOG_PERSONAL : ["personal blog", "personal writing", "lifestyle blog", "personal experience", "opinion piece", "diary entry"],
                     Domain.LEGAL         : ["legal document", "contract", "legal writing", "law", "legal agreement", "legal analysis"],
                     Domain.MEDICAL       : ["medical document", "healthcare", "clinical", "medical report", "health information", "medical research"],
                     Domain.MARKETING     : ["marketing content", "advertising", "brand content", "promotional writing", "sales copy", "marketing material"],
                     Domain.TUTORIAL      : ["tutorial", "how-to guide", "instructional content", "step-by-step guide", "educational guide", "learning material"],
                     Domain.GENERAL       : ["general content", "everyday writing", "common text", "standard writing", "normal text", "general information"],
                    }


    def __init__(self):
        # Models are loaded lazily by initialize(); until then both classifiers are None
        self.model_manager       = get_model_manager()
        self.primary_classifier  = None
        self.fallback_classifier = None
        self.is_initialized      = False


    def initialize(self) -> bool:
        """
        Initialize the domain classifier with zero-shot models

        Returns:
        --------
            { bool } : True when the primary model was loaded; False when loading it raised.
                       A fallback model that fails to load is only logged and left as None.
        """
        try:
            logger.info("Initializing domain classifier...")

            # Load primary domain classifier (zero-shot)
            self.primary_classifier = self.model_manager.load_model(model_name = "domain_classifier")

            # The fallback is optional: failing to load it must not fail initialization
            try:
                self.fallback_classifier = self.model_manager.load_model(model_name = "domain_classifier_fallback")
                logger.info("Fallback classifier loaded successfully")

            except Exception as e:
                logger.warning(f"Could not load fallback classifier: {repr(e)}")
                self.fallback_classifier = None

            self.is_initialized = True
            logger.success("Domain classifier initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to initialize domain classifier: {repr(e)}")
            return False


    def classify(self, text: str, top_k: int = 2, min_confidence: float = 0.3) -> DomainPrediction:
        """
        Classify text into domain using zero-shot classification

        Arguments:
        ----------
            text           { str }   : Input text

            top_k          { int }   : Number of top domains to consider (accepted for interface
                                       compatibility; currently not used by this method)

            min_confidence { float } : Minimum confidence at which the primary result is returned
                                       without consulting the fallback model

        Returns:
        --------
            { DomainPrediction } : Primary result when it is confident enough; otherwise the more
                                   confident of primary/fallback; the default prediction when
                                   initialization fails or no model produced a result
        """
        if not self.is_initialized:
            logger.warning("Domain classifier not initialized, initializing now...")
            if not self.initialize():
                return self._get_default_prediction()

        # Primary and fallback failures are handled separately, so a fallback error can never
        # discard a primary result that was computed successfully
        primary_result = None

        try:
            primary_result = self._classify_with_model(text       = text,
                                                       classifier = self.primary_classifier,
                                                       model_type = "primary",
                                                      )
        except Exception as e:
            logger.error(f"Error in domain classification: {repr(e)}")

        # If primary result meets confidence threshold, return it
        if ((primary_result is not None) and (primary_result.confidence >= min_confidence)):
            return primary_result

        # Primary failed or is low confidence: consult the fallback (at most once) if available
        if self.fallback_classifier:
            if primary_result is None:
                logger.info("Trying fallback classifier after primary failure...")
            else:
                logger.info("Primary classifier low confidence, trying fallback model...")

            try:
                fallback_result = self._classify_with_model(text       = text,
                                                            classifier = self.fallback_classifier,
                                                            model_type = "fallback",
                                                           )
            except Exception as fallback_error:
                if primary_result is None:
                    logger.error(f"Fallback classifier also failed: {repr(fallback_error)}")
                else:
                    logger.error(f"Fallback classifier failed, keeping primary result: {repr(fallback_error)}")
            else:
                # Without a primary result the fallback wins outright; otherwise it must be
                # strictly more confident than the primary
                if ((primary_result is None) or (fallback_result.confidence > primary_result.confidence)):
                    return fallback_result

        # Return primary result even if low confidence
        if primary_result is not None:
            return primary_result

        # Both models failed (or primary failed and there is no fallback), return default
        return self._get_default_prediction()


    def _classify_with_model(self, text: str, classifier, model_type: str) -> DomainPrediction:
        """
        Classify using a zero-shot classification model

        Arguments:
        ----------
            text       { str } : Raw input text (preprocessed here before classification)

            classifier         : Callable invoked as classifier(text, candidate_labels=...,
                                 multi_label=..., hypothesis_template=...) whose result exposes
                                 'labels' and 'scores' sequences; presumably a Hugging Face
                                 zero-shot-classification pipeline - TODO confirm

            model_type { str } : Name used only in the log message (e.g. "primary", "fallback")

        Returns:
        --------
            { DomainPrediction } : Prediction built from the classifier scores
        """
        # Preprocess text
        processed_text  = self._preprocess_text(text)

        # One candidate label per domain: the first phrase of each DOMAIN_LABELS entry
        label_to_domain = {labels[0]: domain for domain, labels in self.DOMAIN_LABELS.items()}
        all_labels      = list(label_to_domain)

        # Perform zero-shot classification
        result          = classifier(processed_text,
                                     candidate_labels    = all_labels,
                                     multi_label         = False,
                                     hypothesis_template = "This text is about {}.",
                                    )

        # Group scores by domain value, then average (one score per domain with the labels above)
        domain_scores   = {}

        for label, score in zip(result['labels'], result['scores']):
            domain_scores.setdefault(label_to_domain[label].value, []).append(score)

        avg_domain_scores = {domain: sum(scores) / len(scores) for domain, scores in domain_scores.items()}

        # Sort by score, best first
        sorted_domains    = sorted(avg_domain_scores.items(), key = lambda x: x[1], reverse = True)

        # Get primary and secondary domains
        primary_domain_str, primary_score = sorted_domains[0]
        primary_domain                    = Domain(primary_domain_str)

        secondary_domain = None
        secondary_score  = 0.0

        # A runner-up is only kept when its score is at least 0.1
        if ((len(sorted_domains) > 1) and (sorted_domains[1][1] >= 0.1)):
            secondary_domain = Domain(sorted_domains[1][0])
            secondary_score  = sorted_domains[1][1]

        # Calculate confidence
        confidence = primary_score

        # If we have mixed domains with close scores, adjust confidence
        if (secondary_domain and (primary_score < 0.7) and (secondary_score > 0.3)):
            # primary_score >= secondary_score > 0.3 here, so the division is safe
            score_ratio = secondary_score / primary_score

            # Secondary is at least 60% of primary
            if (score_ratio > 0.6):
                # Lower confidence for mixed domains
                confidence = (primary_score + secondary_score) / 2 * 0.8
                logger.info(f"Mixed domain detected: {primary_domain.value} + {secondary_domain.value}, will use interpolated thresholds")

        # If primary score is low and we have a secondary, it's uncertain
        elif ((primary_score < 0.5) and secondary_domain):
            # Reduce confidence
            confidence *= 0.8

        logger.info(f"{model_type.capitalize()} model classified domain: {primary_domain.value} (confidence: {confidence:.3f})")

        return DomainPrediction(primary_domain   = primary_domain,
                                secondary_domain = secondary_domain,
                                confidence       = confidence,
                                domain_scores    = avg_domain_scores,
                               )


    def _preprocess_text(self, text: str) -> str:
        """
        Preprocess text for classification

        Keeps at most the first 400 whitespace-separated words (re-joined with single spaces),
        strips surrounding whitespace, and substitutes the placeholder "general content" when
        nothing is left
        """
        # Truncate to reasonable length
        words = text.split()

        if (len(words) > 400):
            text = ' '.join(words[:400])

        # Clean up text
        text = text.strip()

        # The classifier is never handed an empty string
        if not text:
            return "general content"

        return text


    def _get_default_prediction(self) -> DomainPrediction:
        """
        Get default prediction when classification fails

        Returns:
        --------
            { DomainPrediction } : GENERAL domain, no secondary domain, confidence 0.5
        """
        return DomainPrediction(primary_domain   = Domain.GENERAL,
                                secondary_domain = None,
                                confidence       = 0.5,
                                domain_scores    = {Domain.GENERAL.value: 1.0},
                               )


    def get_adaptive_thresholds(self, domain_prediction: DomainPrediction):
        """
        Get adaptive thresholds based on domain prediction

        Arguments:
        ----------
            domain_prediction { DomainPrediction } : Prediction to derive thresholds for

        Returns:
        --------
            Whatever get_threshold_for_domain / interpolate_thresholds return (defined in
            config.threshold_config; the concrete type is not visible from this module)
        """
        # Mixed domains: blend primary and secondary thresholds, weighted by their scores
        if domain_prediction.secondary_domain:
            primary_score   = domain_prediction.domain_scores.get(domain_prediction.primary_domain.value, 0)
            secondary_score = domain_prediction.domain_scores.get(domain_prediction.secondary_domain.value, 0)
            total_score     = primary_score + secondary_score

            if (total_score > 0):
                weight1 = primary_score / total_score
            else:
                # No usable scores: fall back to the prediction confidence as the weight
                weight1 = domain_prediction.confidence

            return interpolate_thresholds(domain1 = domain_prediction.primary_domain,
                                          domain2 = domain_prediction.secondary_domain,
                                          weight1 = weight1,
                                         )

        # Single domain but low confidence: blend towards the GENERAL thresholds
        if (domain_prediction.confidence < 0.6):
            return interpolate_thresholds(domain1 = domain_prediction.primary_domain,
                                          domain2 = Domain.GENERAL,
                                          weight1 = domain_prediction.confidence,
                                         )

        # Single domain with confidence of at least 0.6: use that domain's thresholds directly
        return get_threshold_for_domain(domain_prediction.primary_domain)


    def cleanup(self):
        """
        Clean up resources

        Drops the references to both classifiers and marks the instance as uninitialized, so the
        next classify() call triggers initialize() again
        """
        self.primary_classifier  = None
        self.fallback_classifier = None
        self.is_initialized      = False
# Names exported by `from <module> import *`
__all__ = ["DomainClassifier", "DomainPrediction"]