# bawolf's picture
# wip
# c850c95
# raw
# history blame
# 3.59 kB
import torch
import numpy as np
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from src.utils.utils import get_latest_run_dir, get_latest_model_path, get_config
from src.models.model import load_model
from src.dataset.video_utils import create_transform, extract_frames
def setup_model(run_dir=None):
    """Load the trained model and its configuration for inference.

    Args:
        run_dir: Run directory to load from. When ``None``, the directory
            returned by ``get_latest_run_dir()`` is used.

    Returns:
        A ``(model, config, device)`` tuple: the model after ``eval()`` has
        been called on it, the configuration returned by ``get_config``,
        and the ``torch.device`` that was passed to ``load_model``.

    Raises:
        SystemExit: With status 1 when loading the configuration or the
            model raises ``ValueError`` or ``FileNotFoundError``.
    """
    # Use CUDA when it is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if run_dir is None:
        run_dir = get_latest_run_dir()
    print(f"Using run directory: {run_dir}")

    try:
        config = get_config(run_dir)
        print(f"Loaded configuration from: {run_dir}")

        model_path = get_latest_model_path(run_dir)
        print(f"Loading model from: {model_path}")
        model = load_model(
            config['num_classes'],
            model_path,
            device,
            config['clip_model'],
        )
        model.eval()
        return model, config, device
    except (ValueError, FileNotFoundError) as e:
        print(f"Error loading model: {e}", file=sys.stderr)
        # sys.exit rather than the bare exit(): the latter is injected by the
        # site module and is missing under `python -S` or in frozen builds.
        sys.exit(1)
def predict(video_path, model, config, device):
    """Predict the class of a video using the model.

    Args:
        video_path: Path of the video, forwarded to ``extract_frames``.
        model: Callable applied to the batched frames; its output is passed
            through a softmax over dimension 1.
        config: Configuration forwarded to ``create_transform`` and
            ``extract_frames``; must provide the ``'class_labels'`` key.
        device: Device the extracted frames are moved to before the
            forward pass.

    Returns:
        A ``(predicted_label, class_probabilities)`` tuple, where
        ``predicted_label`` is the entry of ``config['class_labels']`` with
        the highest probability and ``class_probabilities`` maps each label
        to its probability as a ``float``.

    Raises:
        ValueError: If frame extraction reports failure or any other error
            occurs while processing; the original exception is attached as
            ``__cause__``.
    """
    transform = create_transform(config, training=False)

    try:
        frames, success = extract_frames(video_path, config, transform)
        if not success:
            raise ValueError(f"Failed to process video: {video_path}")

        # Move to the target device and add a batch dimension at position 0.
        frames = frames.to(device).unsqueeze(0)

        with torch.no_grad():
            try:
                outputs = model(frames)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
            except Exception as e:
                # Report the input shape to help diagnose the failure, then
                # let the outer handler wrap the error.
                print(f"Error during model forward pass: {e}")
                print(f"Model input shape: {frames.shape}")
                raise

        # Only one video is in the batch, so take the first row.
        probs = probabilities[0].cpu().numpy()
        labels = config['class_labels']
        # Convert the numpy integer to a plain int before indexing.
        predicted_index = int(np.argmax(probs))

        class_probabilities = {
            label: float(prob) for label, prob in zip(labels, probs)
        }
        return labels[predicted_index], class_probabilities
    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error processing video: {e}") from e
if __name__ == "__main__":
    # Command-line entry point: classify one video and print the predicted
    # label followed by the probability of every class.
    import argparse

    parser = argparse.ArgumentParser(description='Run inference on a video file')
    parser.add_argument('--video', type=str, required=True,
                        help='Path to the video file')
    parser.add_argument('--run-dir', type=str,
                        help='Path to specific run directory (optional)')
    args = parser.parse_args()

    # Setup model and config (exits with status 1 if loading fails).
    model, config, device = setup_model(args.run_dir)

    try:
        predicted_label, class_probabilities = predict(
            args.video, model, config, device
        )
    except ValueError as e:
        # Report on stderr and exit non-zero so callers can detect failure.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"\nPredicted label: {predicted_label}")
    print("\nClass probabilities:")
    for label, prob in class_probabilities.items():
        print(f"  {label}: {prob:.4f}")