Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Video Analysis Tool for GAIA Agent | |
| Provides video frame extraction and visual analysis capabilities for YouTube videos. | |
| Specifically designed to handle questions requiring visual analysis (e.g., counting objects). | |
| """ | |
| import os | |
| import logging | |
| import tempfile | |
| import subprocess | |
| from typing import Dict, Any, List, Optional, Union | |
| from pathlib import Path | |
| import requests | |
| import re | |
| try: | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| CV2_AVAILABLE = True | |
| except ImportError: | |
| cv2 = None | |
| np = None | |
| Image = None | |
| CV2_AVAILABLE = False | |
| try: | |
| import yt_dlp | |
| YT_DLP_AVAILABLE = True | |
| except ImportError: | |
| YT_DLP_AVAILABLE = False | |
| # Import existing multimodal tools | |
| try: | |
| from agents.mistral_multimodal_agent import OpenSourceMultimodalTools | |
| MULTIMODAL_AVAILABLE = True | |
| except ImportError: | |
| MULTIMODAL_AVAILABLE = False | |
logger = logging.getLogger(__name__)


class VideoAnalysisTool:
    """
    Video Analysis Tool for extracting frames and performing visual analysis.

    Capabilities:
    - Extract frames from YouTube videos
    - Analyze frames using multimodal image analysis
    - Count objects across multiple frames
    - Handle visual questions that require frame-by-frame analysis

    Each capability is optional: it depends on which third-party packages
    imported successfully at module load time (OpenCV/NumPy/Pillow, yt-dlp,
    and the project's multimodal agent tools).
    """

    # Patterns tried in order by extract_video_id(); group 1 is the video ID.
    _VIDEO_ID_PATTERNS = (
        re.compile(r'(?:youtube\.com/watch\?v=|youtu\.be/|youtube\.com/embed/)([^&\n?#]+)'),
        re.compile(r'youtube\.com/watch\?.*v=([^&\n?#]+)'),
        re.compile(r'youtube\.com/(?:shorts|live|v)/([^&\n?#/]+)'),
    )

    # Keywords that switch synthesis into "counting" mode. Matching starts at a
    # word boundary so e.g. "account" or "discount" does not trigger it.
    _COUNTING_QUESTION_RE = re.compile(r'\b(?:count|number|how many|species)')

    # Whole integers in free text. Digits that are part of a decimal such as
    # "12.7" are excluded so a timestamp does not read as counts of 12 and 7.
    _COUNT_NUMBER_RE = re.compile(r'(?<!\d\.)\b(\d+)\b(?!\.\d)')

    # Heuristic plausibility range for an object count reported in one frame.
    _MIN_PLAUSIBLE_COUNT = 1
    _MAX_PLAUSIBLE_COUNT = 20

    # analyze_frame() prefixes every failure message with this text.
    _ERROR_PREFIX = "Error"

    # Number of characters of each frame analysis shown in a counting summary.
    _PREVIEW_CHARS = 100

    def __init__(self):
        """Initialize the tool and record which optional dependencies are usable.

        The availability flags are copied from the module-level import probes.
        The multimodal backend is constructed eagerly, and its flag is cleared
        if construction raises.
        """
        logger.info("🎬 Initializing Video Analysis Tool...")
        # Check dependencies
        self.cv2_available = CV2_AVAILABLE
        self.yt_dlp_available = YT_DLP_AVAILABLE
        self.multimodal_available = MULTIMODAL_AVAILABLE
        # Initialize multimodal tools if available
        self.multimodal_tools = None
        if self.multimodal_available:
            try:
                self.multimodal_tools = OpenSourceMultimodalTools()
                logger.info("✅ Multimodal tools initialized")
            except Exception as e:
                logger.warning(f"⚠️ Multimodal tools initialization failed: {e}")
                self.multimodal_available = False
        # Log capabilities
        capabilities = []
        if self.cv2_available:
            capabilities.append("Frame extraction (OpenCV)")
        if self.yt_dlp_available:
            capabilities.append("YouTube download (yt-dlp)")
        if self.multimodal_available:
            capabilities.append("Image analysis (Multimodal)")
        logger.info(f"📊 Available capabilities: {', '.join(capabilities) or 'none'}")
        # Frame-based analysis needs BOTH OpenCV and yt-dlp, so warn when either is missing.
        if not (self.cv2_available and self.yt_dlp_available):
            logger.warning("⚠️ Limited functionality - install opencv-python and yt-dlp for full capabilities")

    def extract_video_id(self, youtube_url: str) -> Optional[str]:
        """Extract the video ID from a YouTube URL.

        Supports watch?v=, youtu.be/, embed/, shorts/, live/ and v/ URL forms.

        Args:
            youtube_url: YouTube video URL.

        Returns:
            The ID text captured from the URL, or None if no pattern matches
            (or the URL is empty).
        """
        if not youtube_url:
            return None
        for pattern in self._VIDEO_ID_PATTERNS:
            match = pattern.search(youtube_url)
            if match:
                return match.group(1)
        return None

    def download_video(self, youtube_url: str, output_dir: str) -> Optional[str]:
        """
        Download a YouTube video for frame extraction.

        Args:
            youtube_url: YouTube video URL
            output_dir: Directory to save the video

        Returns:
            Path to the downloaded video file, or None on failure (missing
            yt-dlp, unparseable URL, download error, or no file found).
        """
        if not self.yt_dlp_available:
            logger.error("❌ yt-dlp not available for video download")
            return None
        try:
            video_id = self.extract_video_id(youtube_url)
            if not video_id:
                logger.error(f"❌ Could not extract video ID from URL: {youtube_url}")
                return None
            ydl_opts = {
                # Prefer a single-file stream <=720p for faster processing, but
                # fall back to the best single-file stream rather than failing.
                'format': 'best[height<=720]/best',
                'outtmpl': os.path.join(output_dir, f"{video_id}.%(ext)s"),
                # A watch URL that also carries &list=... must not pull the whole playlist.
                'noplaylist': True,
                'quiet': True,
                'no_warnings': True,
            }
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([youtube_url])
            # Find the downloaded file, ignoring yt-dlp's partial-download artifacts.
            for file in sorted(os.listdir(output_dir)):
                if file.startswith(video_id) and not file.endswith(('.part', '.ytdl')):
                    downloaded_path = os.path.join(output_dir, file)
                    logger.info(f"✅ Video downloaded: {downloaded_path}")
                    return downloaded_path
            logger.error("❌ Downloaded video file not found")
            return None
        except Exception as e:
            logger.error(f"❌ Video download failed: {e}")
            return None

    def extract_frames(self, video_path: str, max_frames: int = 10, interval_seconds: float = 5.0) -> List[Any]:
        """
        Extract frames from a video at regular intervals.

        Args:
            video_path: Path to video file
            max_frames: Maximum number of frames to extract
            interval_seconds: Interval between frames in seconds. When the
                container reports no FPS, every 30th frame is sampled instead.

        Returns:
            List of frame arrays converted from OpenCV's BGR order to RGB.
            Returns an empty list if OpenCV is unavailable, the file cannot be
            opened, or extraction raises.
        """
        if not self.cv2_available:
            logger.error("❌ OpenCV not available for frame extraction")
            return []
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                logger.error(f"❌ Could not open video: {video_path}")
                return []
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps if fps > 0 else 0
            logger.info(f"📹 Video info: {duration:.1f}s, {fps:.1f} FPS, {total_frames} frames")

            # Clamp to >= 1 so a tiny interval can never make the modulo below divide by zero.
            frame_interval = max(1, int(fps * interval_seconds)) if fps > 0 else 30
            frames = []
            frame_index = 0
            while len(frames) < max_frames:
                # grab() only advances the stream; retrieve() (decode + convert)
                # runs just for the frames that are actually sampled.
                if not cap.grab():
                    break
                if frame_index % frame_interval == 0:
                    ok, frame = cap.retrieve()
                    if not ok:
                        break
                    # Convert BGR to RGB for PIL compatibility
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    position = f"{frame_index / fps:.1f}s" if fps > 0 else f"frame #{frame_index}"
                    logger.info(f"📸 Extracted frame {len(frames)} at {position}")
                frame_index += 1

            logger.info(f"✅ Extracted {len(frames)} frames from video")
            return frames
        except Exception as e:
            logger.error(f"❌ Frame extraction failed: {e}")
            return []
        finally:
            # Release the capture on every path, including the exception path.
            if cap is not None:
                cap.release()

    def analyze_frame(self, frame: Any, question: str) -> str:
        """
        Analyze a single frame using multimodal image analysis.

        Args:
            frame: Frame array (RGB format)
            question: Question about the frame

        Returns:
            The analysis text, or a message starting with "Error" if analysis
            is unavailable or fails.
        """
        if not self.multimodal_available or not self.multimodal_tools:
            return "Error: Multimodal analysis not available"
        try:
            # Convert numpy array to PIL Image
            pil_image = Image.fromarray(frame)
            # Use multimodal tools for analysis
            result = self.multimodal_tools.analyze_image(pil_image, question)
            # Synthesis does string processing on this value; make sure it is a str.
            return result if isinstance(result, str) else str(result)
        except Exception as e:
            logger.error(f"❌ Frame analysis failed: {e}")
            return f"Error analyzing frame: {e}"

    def analyze_video_for_objects(self, youtube_url: str, question: str, max_frames: int = 10,
                                  interval_seconds: float = 5.0) -> str:
        """
        Analyze a YouTube video for object counting or other visual questions.

        Args:
            youtube_url: YouTube video URL
            question: Question about the video (e.g., "count bird species")
            max_frames: Maximum frames to analyze
            interval_seconds: Sampling interval passed to extract_frames(); it
                also labels the (nominal) timestamp of each analyzed frame.

        Returns:
            Analysis result with object counts or visual information, or an
            "Error: ..." message if any stage fails.
        """
        logger.info(f"🎬 Starting video analysis for: {youtube_url}")
        logger.info(f"❓ Question: {question}")
        # Without image analysis every frame would only produce an error string,
        # so skip the (slow) download entirely.
        if not self.multimodal_available or not self.multimodal_tools:
            return "Error: Multimodal analysis not available"
        with tempfile.TemporaryDirectory() as temp_dir:
            # Step 1: Download video
            video_path = self.download_video(youtube_url, temp_dir)
            if not video_path:
                return "Error: Could not download video for analysis"
            # Step 2: Extract frames
            frames = self.extract_frames(video_path, max_frames=max_frames, interval_seconds=interval_seconds)
            if not frames:
                return "Error: Could not extract frames from video"
            # Step 3: Analyze each frame
            frame_analyses = []
            for i, frame in enumerate(frames):
                logger.info(f"🔍 Analyzing frame {i+1}/{len(frames)}")
                frame_analyses.append({
                    'frame_number': i + 1,
                    # Nominal position; the real one may differ slightly because
                    # the sampling step is rounded to whole frames.
                    'timestamp': f"{i * interval_seconds:.1f}s",
                    'analysis': self.analyze_frame(frame, question),
                })
            # Step 4: Synthesize results
            return self._synthesize_video_analysis(frame_analyses, question)

    def _synthesize_video_analysis(self, frame_analyses: List[Dict[str, Any]], question: str) -> str:
        """
        Synthesize analysis results from multiple frames.

        For counting-style questions, every plausible whole number found in the
        successful frame analyses is collected and the maximum is reported as
        the answer. Otherwise, or if no number is found, the per-frame analyses
        are listed in full.

        Args:
            frame_analyses: Dicts with 'frame_number', 'timestamp' and 'analysis' keys
            question: Original question

        Returns:
            Synthesized answer text.
        """
        if not frame_analyses:
            return "No frames were analyzed"

        if self._COUNTING_QUESTION_RE.search(question.lower()):
            numbers_found = []
            for frame_analysis in frame_analyses:
                analysis_text = str(frame_analysis['analysis'])
                # Failed frames carry error messages whose digits (errno, status
                # codes, ...) are not object counts.
                if analysis_text.startswith(self._ERROR_PREFIX):
                    continue
                for num_str in self._COUNT_NUMBER_RE.findall(analysis_text):
                    try:
                        num = int(num_str)
                    except ValueError:
                        # e.g. a digit run longer than the int string-conversion limit
                        continue
                    if self._MIN_PLAUSIBLE_COUNT <= num <= self._MAX_PLAUSIBLE_COUNT:
                        numbers_found.append(num)

            if numbers_found:
                max_count = max(numbers_found)
                logger.info(f"🔢 Found counts across frames: {numbers_found}, max: {max_count}")
                response_parts = [f"Analysis of {len(frame_analyses)} video frames:", ""]
                for frame_analysis in frame_analyses:
                    response_parts.append(
                        f"Frame {frame_analysis['frame_number']} ({frame_analysis['timestamp']}): "
                        f"{self._preview(str(frame_analysis['analysis']))}"
                    )
                response_parts.extend([
                    "",
                    f"Maximum count detected: {max_count}",
                    f"Answer: {max_count}",
                ])
                return "\n".join(response_parts)

        # For non-counting questions (or when no count was found), list every analysis in full.
        response_parts = [f"Video analysis results ({len(frame_analyses)} frames):", ""]
        for frame_analysis in frame_analyses:
            response_parts.append(
                f"Frame {frame_analysis['frame_number']} ({frame_analysis['timestamp']}): "
                f"{frame_analysis['analysis']}"
            )
        return "\n".join(response_parts)

    @classmethod
    def _preview(cls, text: str) -> str:
        """Return at most _PREVIEW_CHARS characters of text, adding '...' only when it was cut."""
        if len(text) <= cls._PREVIEW_CHARS:
            return text
        return f"{text[:cls._PREVIEW_CHARS]}..."

    def get_capabilities(self) -> Dict[str, bool]:
        """Return which features are currently usable, keyed by feature name."""
        return {
            'video_download': self.yt_dlp_available,
            'frame_extraction': self.cv2_available,
            'image_analysis': self.multimodal_available,
            'full_video_analysis': all([
                self.yt_dlp_available,
                self.cv2_available,
                self.multimodal_available
            ])
        }
# AGNO Tool Integration
def analyze_youtube_video(url: str, question: str) -> str:
    """
    Answer a question about a YouTube video (AGNO-compatible entry point).

    A new VideoAnalysisTool is constructed on every call, and the question is
    delegated to its analyze_video_for_objects() method with the default
    frame settings.

    Args:
        url: YouTube video URL
        question: Question about the video

    Returns:
        The analysis text produced by the tool.
    """
    return VideoAnalysisTool().analyze_video_for_objects(url, question)
if __name__ == "__main__":
    # Manual smoke test: report the available capabilities, then run a sample
    # analysis only when every optional dependency is present.
    demo_tool = VideoAnalysisTool()

    print("🎬 Video Analysis Tool Test")
    print("=" * 50)
    print(f"Capabilities: {demo_tool.get_capabilities()}")

    # Sample question: bird species visible on camera at the same time.
    sample_url = "https://www.youtube.com/watch?v=LivXCYZAYYM"
    sample_question = "What is the highest number of bird species to be on camera simultaneously?"

    print("\n🧪 Testing with:")
    print(f"URL: {sample_url}")
    print(f"Question: {sample_question}")

    if not demo_tool.get_capabilities()['full_video_analysis']:
        print("\n⚠️ Cannot run full test - missing dependencies")
        print("Install: pip install opencv-python yt-dlp")
    else:
        outcome = demo_tool.analyze_video_for_objects(sample_url, sample_question, max_frames=5)
        print(f"\n📊 Result:\n{outcome}")