|
""" |
|
Comprehensive testing suite for rmtariq/multilingual-emotion-classifier |
|
This script provides various testing capabilities for the emotion classification model. |
|
|
|
Usage: |
|
python test_model.py --test-type [quick|comprehensive|interactive|benchmark] |
|
|
|
Author: rmtariq |
|
Repository: https://huggingface.co/rmtariq/multilingual-emotion-classifier |
|
""" |
|
|
|
import argparse |
|
import time |
|
from transformers import pipeline |
|
import torch |
|
|
|
class EmotionModelTester:
    """Comprehensive testing suite for the multilingual emotion classifier.

    Wraps a Hugging Face ``text-classification`` pipeline around the given
    model and provides four modes: a quick smoke test, a comprehensive
    category test, an interactive prompt loop, and a throughput benchmark.

    NOTE(review): the original file's emoji output was encoding-corrupted
    (several string literals were even split mid-codepoint across lines,
    making the file unparseable); messages are restored as plain ASCII.
    """

    def __init__(self, model_name: str = "rmtariq/multilingual-emotion-classifier"):
        """Store the model name and eagerly load the pipeline.

        Loading in the constructor makes a bad model name fail fast.

        Args:
            model_name: Hugging Face model id or local path.
        """
        self.model_name = model_name
        self.classifier = None  # set by load_model()
        self.load_model()

    def load_model(self) -> None:
        """Load the text-classification pipeline, using GPU when available.

        Raises:
            Exception: re-raised from the pipeline loader after logging.
        """
        print(f"[LOAD] Loading model: {self.model_name}")
        try:
            # device=0 targets the first CUDA device; -1 forces CPU.
            self.classifier = pipeline(
                "text-classification",
                model=self.model_name,
                device=0 if torch.cuda.is_available() else -1,
            )
            device = "GPU" if torch.cuda.is_available() else "CPU"
            print(f"[OK] Model loaded successfully on {device}")
        except Exception as e:
            print(f"[ERROR] Error loading model: {e}")
            raise

    def _predict(self, text: str):
        """Return ``(label, score)`` for *text*; label is lower-cased."""
        result = self.classifier(text)
        return result[0]["label"].lower(), result[0]["score"]

    def quick_test(self) -> float:
        """Run a small bilingual (English/Malay) smoke test.

        Returns:
            float: fraction of cases whose predicted label matched.
        """
        print("\n[QUICK TEST]")
        print("=" * 50)

        # (text, expected emotion, language tag)
        test_cases = [
            ("I am so happy today!", "happy", "[EN]"),
            ("This makes me really angry!", "anger", "[EN]"),
            ("I love you so much!", "love", "[EN]"),
            ("I'm scared of spiders", "fear", "[EN]"),
            ("This news makes me sad", "sadness", "[EN]"),
            ("What a surprise!", "surprise", "[EN]"),
            ("Saya sangat gembira!", "happy", "[MY]"),
            ("Aku marah dengan keadaan ini", "anger", "[MY]"),
            ("Aku sayang kamu", "love", "[MY]"),
            ("Saya takut dengan ini", "fear", "[MY]"),
            # Previously-problematic Malay "terbaik/baik" phrasings.
            ("Ini adalah hari jadi terbaik", "happy", "[MY]"),
            ("Terbaik!", "happy", "[MY]"),
            ("Ini adalah hari yang baik", "happy", "[MY]"),
        ]

        correct = 0
        total = len(test_cases)
        for i, (text, expected, flag) in enumerate(test_cases, 1):
            predicted, confidence = self._predict(text)
            is_correct = predicted == expected
            if is_correct:
                correct += 1
            status = "[OK]" if is_correct else "[X]"
            print(f"{i:2d}. {status} {flag} '{text[:40]}...'")
            print(f"    -> {predicted} ({confidence:.1%}) [Expected: {expected}]")

        accuracy = correct / total
        print(f"\nQuick Test Results: {accuracy:.1%} ({correct}/{total})")
        if accuracy >= 0.9:
            print("EXCELLENT! Model performing at high level!")
        elif accuracy >= 0.8:
            print("GOOD! Model performing well!")
        else:
            print("NEEDS ATTENTION. Some issues detected.")
        return accuracy

    def comprehensive_test(self) -> float:
        """Run categorised tests and report per-category and overall accuracy.

        Returns:
            float: overall accuracy across every category.
        """
        print("\n[COMPREHENSIVE TEST]")
        print("=" * 50)

        test_categories = {
            "English Basic": [
                ("I feel fantastic today!", "happy"),
                ("I'm furious about this!", "anger"),
                ("I adore this place!", "love"),
                ("I'm terrified of heights", "fear"),
                ("I'm heartbroken", "sadness"),
                ("I can't believe it!", "surprise"),
            ],
            "Malay Basic": [
                ("Gembira sangat hari ini", "happy"),
                ("Marah betul dengan dia", "anger"),
                ("Sayang sangat kat kamu", "love"),
                ("Takut gila dengan benda tu", "fear"),
                ("Sedih betul dengar berita", "sadness"),
                ("Terkejut dengan kejadian", "surprise"),
            ],
            "Malay Fixed Issues": [
                ("Ini adalah hari jadi terbaik", "happy"),
                ("Hari jadi terbaik saya", "happy"),
                ("Terbaik!", "happy"),
                ("Hari yang baik", "happy"),
                ("Pengalaman terbaik", "happy"),
                ("Masa terbaik", "happy"),
            ],
            "Edge Cases": [
                ("Happy birthday!", "happy"),
                ("Best day ever!", "happy"),
                ("Good news!", "happy"),
                ("Selamat hari jadi", "happy"),
                ("Berita baik", "happy"),
                ("Hasil terbaik", "happy"),
            ],
        }

        overall_correct = 0
        overall_total = 0
        for category, cases in test_categories.items():
            print(f"\n{category}:")
            print("-" * 30)

            category_correct = 0
            for text, expected in cases:
                predicted, confidence = self._predict(text)
                is_correct = predicted == expected
                if is_correct:
                    category_correct += 1
                    overall_correct += 1
                overall_total += 1
                status = "[OK]" if is_correct else "[X]"
                print(f"   {status} '{text[:35]}...' -> {predicted} ({confidence:.1%})")

            category_accuracy = category_correct / len(cases)
            print(f"   {category} Accuracy: {category_accuracy:.1%}")

        overall_accuracy = overall_correct / overall_total
        print("\nCOMPREHENSIVE TEST RESULTS:")
        print(f"Overall Accuracy: {overall_accuracy:.1%} ({overall_correct}/{overall_total})")
        return overall_accuracy

    def interactive_test(self) -> None:
        """Classify user-supplied text until 'quit'/'exit'/'q' or Ctrl-C."""
        print("\n[INTERACTIVE TEST MODE]")
        print("=" * 50)
        print("Enter text to classify emotions (type 'quit' to exit)")
        print("Supported emotions: anger, fear, happy, love, sadness, surprise")
        print()

        while True:
            try:
                text = input("Your text: ").strip()
                if text.lower() in ("quit", "exit", "q"):
                    print("Goodbye!")
                    break
                if not text:
                    continue

                predicted, confidence = self._predict(text)
                # Coarse display-only confidence bands.
                if confidence > 0.9:
                    confidence_level = "High"
                elif confidence > 0.7:
                    confidence_level = "Good"
                else:
                    confidence_level = "Low"

                print(f"Result: {predicted}")
                print(f"Confidence: {confidence:.1%}")
                print(f"{confidence_level} confidence!")
                print()
            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                # Keep the prompt loop alive on a single bad prediction.
                print(f"Error: {e}")

    def benchmark_test(self) -> float:
        """Time 100 predictions and report throughput.

        Returns:
            float: predictions per second.
        """
        print("\n[BENCHMARK TEST]")
        print("=" * 50)

        benchmark_texts = [
            "I am so happy today!",
            "This makes me angry!",
            "I love this!",
            "I'm scared!",
            "This is sad news",
            "What a surprise!",
            "Saya gembira!",
            "Aku marah!",
            "Sayang betul!",
            "Takut sangat!",
        ] * 10  # 100 predictions total

        print(f"Running {len(benchmark_texts)} predictions...")

        # perf_counter is monotonic and high-resolution; time.time is not
        # guaranteed suitable for interval timing.
        start_time = time.perf_counter()
        for text in benchmark_texts:
            self.classifier(text)
        total_time = time.perf_counter() - start_time

        avg_time = total_time / len(benchmark_texts)
        predictions_per_second = len(benchmark_texts) / total_time

        print("BENCHMARK RESULTS:")
        print(f"Total time: {total_time:.2f} seconds")
        print(f"Average per prediction: {avg_time * 1000:.1f} ms")
        print(f"Predictions per second: {predictions_per_second:.1f}")

        if predictions_per_second > 10:
            print("EXCELLENT! Very fast performance!")
        elif predictions_per_second > 5:
            print("GOOD! Acceptable performance!")
        else:
            print("SLOW. Consider optimization.")
        return predictions_per_second
|
|
|
def main() -> int:
    """CLI entry point for the testing suite.

    Parses ``--test-type`` and ``--model``, builds an
    :class:`EmotionModelTester`, and dispatches to the selected test mode.

    Returns:
        int: process exit code — 0 on success, 1 on failure.

    NOTE(review): the original's emoji banners were encoding-corrupted;
    restored here as plain ASCII.
    """
    parser = argparse.ArgumentParser(description="Test the multilingual emotion classifier")
    parser.add_argument(
        "--test-type",
        choices=["quick", "comprehensive", "interactive", "benchmark", "all"],
        default="quick",
        help="Type of test to run",
    )
    parser.add_argument(
        "--model",
        default="rmtariq/multilingual-emotion-classifier",
        help="Model name or path",
    )
    args = parser.parse_args()

    print("MULTILINGUAL EMOTION CLASSIFIER TESTING SUITE")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Test Type: {args.test_type}")

    try:
        # Constructing the tester loads the model (fails fast on bad names).
        tester = EmotionModelTester(args.model)

        if args.test_type == "quick":
            tester.quick_test()
        elif args.test_type == "comprehensive":
            tester.comprehensive_test()
        elif args.test_type == "interactive":
            tester.interactive_test()
        elif args.test_type == "benchmark":
            tester.benchmark_test()
        elif args.test_type == "all":
            print("Running all tests...")
            tester.quick_test()
            tester.comprehensive_test()
            tester.benchmark_test()
            # Interactive mode last, since it blocks on user input.
            print("\nStarting interactive mode...")
            tester.interactive_test()
    except Exception as e:
        # Broad catch is deliberate: this is the top-level CLI boundary.
        print(f"Testing failed: {e}")
        return 1

    return 0
|
|
|
if __name__ == "__main__":
    # raise SystemExit instead of the site-injected exit() builtin, which
    # is meant for interactive use and is absent under `python -S`.
    raise SystemExit(main())
|
|