"""
Test script for RetNet Explicitness Classifier
Usage: python test_model.py
"""

import json
import time
from collections import Counter

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from model import ProductionRetNet


class RetNetExplicitnessClassifier:
    """Easy-to-use interface for RetNet explicitness classification.

    Wraps a ``ProductionRetNet`` model and the GPT-2 tokenizer behind two
    calls: :meth:`classify` for a single text and :meth:`classify_batch` for
    many. Each result is a plain dict with the keys ``text``,
    ``predicted_class``, ``confidence`` and ``probabilities``.
    """

    def __init__(self, model_path=None, device='auto', config_path='config.json'):
        """Initialize the classifier.

        Args:
            model_path: Path to the trained model file. When ``None``, the
                ``model_file`` entry of the config is used, falling back to
                ``'model.safetensors'``.
            device: Device to run on ('auto', 'cpu', 'cuda', 'mps'). 'auto'
                picks CUDA if available, then MPS, then CPU.
            config_path: Path to the JSON config file. Defaults to
                ``'config.json'`` in the current working directory.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        if model_path is None:
            model_path = self.config.get('model_file', 'model.safetensors')

        self.device = self._resolve_device(device)
        print(f"Using device: {self.device}")

        # Reuse EOS as the padding token so that padding=True works when
        # texts of different lengths are tokenized together.
        self.tokenizer = AutoTokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = self._load_model(model_path)
        self.labels = self.config['labels']

    @staticmethod
    def _resolve_device(device):
        """Translate the ``device`` argument into a concrete device name.

        Any value other than 'auto' is returned unchanged.
        """
        if device != 'auto':
            return device
        if torch.cuda.is_available():
            return 'cuda'
        # torch.backends.mps is absent on older torch builds; look it up
        # defensively instead of raising AttributeError during auto-detection.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return 'mps'
        return 'cpu'

    def _load_model(self, model_path):
        """Load the RetNet model.

        Builds ``ProductionRetNet`` from the hyperparameters stored in the
        config, loads the safetensors weights from ``model_path``, moves the
        model to ``self.device`` and switches it to eval mode.
        """
        model = ProductionRetNet(
            vocab_size=self.config['vocab_size'],
            dim=self.config['model_dim'],
            num_layers=self.config['num_layers'],
            num_heads=self.config['num_heads'],
            num_classes=self.config['num_classes'],
            max_length=self.config['max_length'],
        )

        # Imported here so safetensors is only required when weights are
        # actually loaded.
        from safetensors.torch import load_file
        state_dict = load_file(model_path, device=self.device)
        model.load_state_dict(state_dict)

        model.to(self.device)
        model.eval()

        return model

    def _classify_texts(self, batch):
        """Run one tokenize -> forward -> softmax pass over a list of texts.

        Args:
            batch: Non-empty list of strings, processed as a single batch.

        Returns:
            list: One result dict per input text, in input order.
        """
        inputs = self.tokenizer(
            batch,
            truncation=True,
            padding=True,
            max_length=self.config['max_length'],
            return_tensors='pt',
        )

        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            probabilities = F.softmax(logits, dim=-1)

        # One device-to-host transfer per batch instead of one per row.
        rows = probabilities.cpu().tolist()

        results = []
        for text, row in zip(batch, rows):
            # First index holding the maximum, i.e. argmax semantics.
            pred_id = max(range(len(row)), key=row.__getitem__)
            results.append({
                'text': text,
                'predicted_class': self.labels[pred_id],
                'confidence': row[pred_id],
                'probabilities': {
                    label: row[k] for k, label in enumerate(self.labels)
                },
            })
        return results

    def classify(self, text):
        """Classify a single text.

        Args:
            text: Input text to classify

        Returns:
            dict: Classification results with label, confidence, and all probabilities
        """
        return self._classify_texts([text])[0]

    def classify_batch(self, texts, batch_size=32):
        """Classify multiple texts efficiently.

        Args:
            texts: List (or other iterable) of input texts. A bare string is
                treated as a single text rather than iterated per character.
            batch_size: Maximum number of texts sent through the model per
                forward pass.

        Returns:
            list: List of classification results, in input order

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        texts = [texts] if isinstance(texts, str) else list(texts)

        results = []
        for start in range(0, len(texts), batch_size):
            results.extend(self._classify_texts(texts[start:start + batch_size]))
        return results


def main():
    """Test the RetNet classifier with example texts.

    Builds a classifier with its default arguments, classifies a fixed list
    of sample sentences one at a time, times ``classify_batch`` over the same
    list, and prints the prediction distribution together with the model
    metadata stored in the classifier's config.
    """
    print("🧪 Testing RetNet Explicitness Classifier")
    print("=" * 60)

    classifier = RetNetExplicitnessClassifier()

    # Samples range from neutral prose to sexual, violent and profane
    # content, ending with a content warning.
    test_texts = [
        "The morning sun cast long shadows across the peaceful meadow as birds sang in the trees.",
        "She felt a spark of attraction as their eyes met across the crowded room.",
        "The romance novel described their passionate night together in tasteful detail.",
        "His hands explored every inch of her naked body as she moaned with pleasure.",
        "The killer slowly twisted the knife deeper into his victim's chest.",
        "What the fuck is wrong with you, you goddamn idiot?",
        "Warning: This content contains explicit sexual material and violence.",
    ]

    print(f"Testing {len(test_texts)} example texts...\n")

    print("Single Text Classification:")
    print("-" * 40)

    for number, text in enumerate(test_texts, start=1):
        result = classifier.classify(text)
        print(f"\n{number}. Text: {result['text']}")
        print(f" Prediction: {result['predicted_class']}")
        print(f" Confidence: {result['confidence']:.3f}")

    print("\n⚡ Batch Classification Performance:")
    print("-" * 40)

    # perf_counter is monotonic and high-resolution, which suits short
    # interval timing better than wall-clock time.time().
    start_time = time.perf_counter()
    batch_results = classifier.classify_batch(test_texts)
    elapsed_time = time.perf_counter() - start_time

    # Guard against a zero reading on coarse clocks.
    texts_per_sec = (
        len(test_texts) / elapsed_time if elapsed_time > 0 else float('inf')
    )

    print(f"Processed {len(test_texts)} texts in {elapsed_time:.3f}s")
    print(f"Speed: {texts_per_sec:.1f} texts/second")

    pred_counts = Counter(r['predicted_class'] for r in batch_results)

    print("\nPrediction Distribution:")
    for label, count in sorted(pred_counts.items()):
        print(f" {label}: {count}")

    performance = classifier.config['performance']
    training = classifier.config['training']

    print("\n🤖 Model Information:")
    print(f" Parameters: {performance['parameters']:,}")
    print(f" Holdout F1: {performance['holdout_macro_f1']:.3f}")
    print(f" Holdout Accuracy: {performance['holdout_accuracy']:.3f}")
    print(f" Training Time: {training['training_time_hours']:.1f} hours")

    print("\n✅ RetNet classifier test completed!")


# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()