Final_Assignment / final_classification_test.py
tonthatthienvu's picture
Clean repository without binary files
37cadfb
#!/usr/bin/env python3
"""
Final test for YouTube question classification and tool selection
"""
from question_classifier import QuestionClassifier
def _preview(text, width=80):
    """Return *text* cut to *width* characters, appending "..." only when cut."""
    return text if len(text) <= width else text[:width] + "..."


def _is_youtube_question(question):
    """Return True if *question* refers to YouTube.

    Matches the word/domain "youtube" (case-insensitive) as well as the
    ``youtu.be`` short-link form, which does not contain the substring
    "youtube" and would otherwise be missed.
    """
    lowered = question.lower()
    return "youtube" in lowered or "youtu.be" in lowered


def _mark(ok):
    """Return a check-mark emoji for True and a cross emoji for False."""
    return "✅" if ok else "❌"


def _tools_match(tools, expected_tool, question):
    """Decide whether the selected *tools* satisfy a test case.

    Args:
        tools: list of tool names chosen by the classifier.
        expected_tool: tool that must be present, or None/empty if no
            specific tool is required.
        question: the original question text.

    Returns:
        If *expected_tool* is given, whether it appears in *tools*. Otherwise,
        True unless ``analyze_youtube_video`` was selected for a question that
        does not refer to YouTube.
    """
    if expected_tool:
        return expected_tool in tools
    return "analyze_youtube_video" not in tools or _is_youtube_question(question)


def test_classification():
    """Test that our classification improvements for YouTube questions are working.

    Each case states the expected ``primary_agent`` and, optionally, a tool that
    must appear in ``tools_needed`` of the result of
    ``QuestionClassifier().classify_question``. Per-case results and a summary
    are printed.

    Raises:
        AssertionError: if any case fails. This lets test runners report the
            failure; pytest collects this module because its name matches
            ``*_test.py``.
    """
    classifier = QuestionClassifier()

    test_cases = [
        {
            'question': 'In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video'
        },
        {
            'question': 'Tell me about the video at youtu.be/dQw4w9WgXcQ',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video'
        },
        {
            'question': 'What does Teal\'c say in the YouTube video youtube.com/watch?v=XYZ123?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video'
        },
        {
            'question': 'How many birds appear in this image?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_image_with_gemini'
        },
        {
            'question': 'When was the first Star Wars movie released?',
            'expected_agent': 'research',
            'expected_tool': None
        }
    ]

    print("🧪 Testing Question Classification for YouTube Questions")
    print("=" * 70)

    failed = []  # 1-based numbers of failing cases
    for i, case in enumerate(test_cases, start=1):
        question = case['question']
        print(f"\nTest {i}: {_preview(question)}")

        classification = classifier.classify_question(question)
        # A missing key is reported as a mismatch rather than crashing the run.
        actual_agent = classification.get('primary_agent')
        # Normalise a missing or None tool list to [] so membership checks work.
        tools = classification.get('tools_needed') or []

        agent_correct = actual_agent == case['expected_agent']
        tool_correct = _tools_match(tools, case['expected_tool'], question)

        print(f"Expected agent: {case['expected_agent']}")
        print(f"Actual agent: {actual_agent}")
        print(f"Agent match: {_mark(agent_correct)}")
        print(f"Expected tool: {case['expected_tool']}")
        print(f"Selected tools: {tools}")
        print(f"Tool match: {_mark(tool_correct)}")

        # Informational only (does not affect pass/fail): for YouTube
        # questions, report whether the YouTube tool is listed first.
        if tools and _is_youtube_question(question):
            if tools[0] == 'analyze_youtube_video':
                print("✅ analyze_youtube_video correctly prioritized for YouTube question")
            else:
                print("❌ analyze_youtube_video not prioritized for YouTube question")

        if agent_correct and tool_correct:
            print("✅ TEST PASSED")
        else:
            failed.append(i)
            print("❌ TEST FAILED")

    passed = len(test_cases) - len(failed)
    print("\n" + "=" * 70)
    print(f"Final result: {passed}/{len(test_cases)} tests passed")
    if not failed:
        print("🎉 All tests passed! The classifier is working correctly.")
    else:
        print("⚠️ Some tests failed. Further improvements needed.")

    # Without this assertion the function could never fail under a test runner.
    assert not failed, f"Classification tests failed: {failed}"


if __name__ == "__main__":
    try:
        test_classification()
    except AssertionError as exc:
        # The summary has already been printed. Exit with a non-zero status and
        # a short message instead of a traceback, so shell/CI callers can
        # detect the failure.
        raise SystemExit(str(exc)) from None