# Task1_1 / 2_2_2_2_2.py
# hiren05's picture
# Upload 2_2_2_2_2.py
# a07ef96 verified
# -*- coding: utf-8 -*-
"""2.2.2.2.2.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1igY4MKIJJTPHgEkdLFI_T5H6sLUoTaLr
"""
#heat map video and metrics
"""## CODE"""
# Notebook-only dependency install. "pip install ..." is shell syntax, not
# Python, so run it in a Colab/Jupyter cell instead:
#   !pip install torchmetrics lpips
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance
import lpips
import os
import random
import shutil
from huggingface_hub import HfApi, hf_hub_download
import tarfile
import json
import cv2
from tqdm import tqdm
def download_sequential_data(repo_id="Amar-S/MOVi-MC-AC", sample_ratio=0.01, base_dir="/content/data",
                             splits=("test",)):
    """Download a random subset of scene archives and extract them in place.

    Whole ``.tar.gz`` archives are sampled (never individual files), so every
    downloaded scene keeps its complete frame sequence.

    Args:
        repo_id: Hugging Face dataset repository to pull archives from.
        sample_ratio: Fraction of each split's archives to download. At least
            one archive is taken when a split has any, never more than exist.
        base_dir: Local root; archives for split ``s`` go to ``{base_dir}/{s}``.
        splits: Repository splits to fetch (folder prefixes such as ``"test"``
            or ``"train"``). Defaults to the test split only.
    """
    api = HfApi()
    # train/ and test/ are always created (and scanned for archives), plus any
    # extra split requested; dict.fromkeys keeps order and drops duplicates.
    split_dirs = list(dict.fromkeys(("train", "test", *splits)))
    for split in split_dirs:
        os.makedirs(f"{base_dir}/{split}", exist_ok=True)

    files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    for split in splits:
        archives = [f for f in files if f.startswith(f"{split}/") and f.endswith(".tar.gz")]
        # random.sample raises ValueError when asked for more items than exist,
        # so clamp the "at least one" request to the available count.
        n_pick = min(len(archives), max(1, int(len(archives) * sample_ratio)))
        if n_pick == 0:
            print(f"No {split} archives found in {repo_id}.")
            continue
        for file in random.sample(archives, n_pick):
            print(f"Downloading {file}...")
            out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
            dest_path = f"{base_dir}/{split}/{os.path.basename(file)}"
            shutil.copyfile(out_path, dest_path)

    for split in split_dirs:
        extract_archives(f"{base_dir}/{split}")
    print("Download and extraction complete!")
def extract_archives(directory):
    """Extract every ``*.tar.gz`` archive found directly in ``directory``.

    Each archive is unpacked into ``directory`` itself and then deleted. If
    extraction raises, the exception propagates and the archive is kept.

    The archives are downloaded from a remote dataset. So, when the running
    Python supports tar extraction filters, the ``"data"`` filter is applied.
    It rejects members with absolute paths, ``..`` components, or links that
    escape the destination.
    """
    for file in os.listdir(directory):
        if not file.endswith(".tar.gz"):
            continue
        filepath = os.path.join(directory, file)
        print(f"Extracting {filepath}...")
        with tarfile.open(filepath, 'r:gz') as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=directory, filter="data")
            else:  # Python without extraction-filter support
                tar.extractall(path=directory)
        os.remove(filepath)
# Fetch a sampled subset of dataset archives and extract them (runs at import time).
download_sequential_data()
#extract_archives('/content/data/train')
# Safety pass: download_sequential_data() already extracts and deletes the
# test archives, so this normally finds nothing left to extract.
extract_archives('/content/data/test')
def extract_archives(directory):
    """Extract every ``*.tar.gz`` archive found directly in ``directory``.

    This redefines the earlier ``extract_archives`` in this file; calls made
    after this point use this version. Each archive is unpacked into
    ``directory`` and then deleted. If extraction raises, the archive is kept.

    Archives come from a remote dataset, so the ``"data"`` tar extraction
    filter is applied when the running Python supports it. The filter rejects
    members with absolute paths, ``..`` components, or links that escape the
    destination.
    """
    for file in os.listdir(directory):
        if not file.endswith(".tar.gz"):
            continue
        filepath = os.path.join(directory, file)
        print(f"Extracting {filepath}...")
        with tarfile.open(filepath, 'r:gz') as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=directory, filter="data")
            else:  # Python without extraction-filter support
                tar.extractall(path=directory)
        os.remove(filepath)
#extract_archives('/content/data/train')
# Re-run extraction with the redefined extract_archives; a no-op unless
# archives are still present in /content/data/test.
extract_archives('/content/data/test')
class VideoAmodalDataset(Dataset):
    """Fixed-length clips of one object's amodal appearance from MOVi-style scenes.

    Files read by this class::

        root_dir/split/scene_*/camera_*/rgba_XXXXX.png                 scene RGB(A)
        root_dir/split/scene_*/camera_*/segmentation_XXXXX.png         scene id map
        root_dir/split/scene_*/camera_*/obj_<id>/rgba_XXXXX.png        object-only RGB(A)
        root_dir/split/scene_*/camera_*/obj_<id>/segmentation_XXXXX.png  amodal mask

    Each item is a dict with:

    * ``rgb_sequence`` and ``amodal_rgb_sequence``: shape (seq_len, 3, H, W).
    * ``modal_masks`` and ``amodal_masks``: shape (seq_len, 1, H, W).
    * ``scene``, ``camera`` and ``object_id``.

    ``(H, W)`` is ``img_size``.
    """

    def __init__(self, root_dir, split='train', seq_len=8, img_size=(256,256),
                 max_scenes=4, samples_per_scene=3, max_samples=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.seq_len = seq_len
        self.img_size = img_size
        self.max_scenes = max_scenes
        self.samples_per_scene = samples_per_scene
        self.samples = self._build_sample_index(max_samples)
        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])
        # Masks use nearest-neighbour resizing so they stay binary; the default
        # bilinear resize would leave fractional values along object edges.
        self.mask_transform = transforms.Compose([
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self, max_samples):
        """Return clip descriptors: non-overlapping seq_len windows per (camera, object).

        Up to ``samples_per_scene`` objects are drawn at random per camera, from
        the first ``max_scenes`` scenes. Indexing stops early once
        ``max_samples`` clips are collected (when ``max_samples`` is truthy).
        """
        samples = []
        scene_paths = sorted((self.root_dir / self.split).glob('scene_*'))[:self.max_scenes]
        for scene_path in scene_paths:
            for camera_path in sorted(scene_path.glob('camera_*')):
                # Frame ids depend only on the camera folder: list them once per
                # camera instead of re-globbing for every selected object.
                frame_ids = [int(p.stem.split('_')[1])
                             for p in sorted(camera_path.glob('rgba_*.png'))]
                obj_paths = sorted(camera_path.glob('obj_*'))
                selected_objs = random.sample(obj_paths, min(self.samples_per_scene, len(obj_paths)))
                for obj_path in selected_objs:
                    obj_id = int(obj_path.name.split('_')[1])
                    for i in range(0, len(frame_ids) - self.seq_len + 1, self.seq_len):
                        samples.append({
                            'scene': scene_path.name,
                            'camera': camera_path.name,
                            'obj_folder': obj_path.name,
                            'frame_ids': frame_ids[i:i + self.seq_len],
                            'obj_id': obj_id,
                        })
                        if max_samples and len(samples) >= max_samples:
                            return samples
        return samples

    def __len__(self):
        return len(self.samples)

    def _empty_item(self, sample):
        """Zero-filled placeholder returned when any frame of a clip fails to load."""
        h, w = self.img_size[0], self.img_size[1]
        empty_rgb = torch.zeros(self.seq_len, 3, h, w)
        empty_mask = torch.zeros(self.seq_len, 1, h, w)
        return {
            'rgb_sequence': empty_rgb,
            'modal_masks': empty_mask,
            'amodal_masks': empty_mask.clone(),
            'amodal_rgb_sequence': empty_rgb.clone(),
            'scene': sample['scene'],
            'camera': sample['camera'],
            'object_id': sample['obj_id'],
        }

    def __getitem__(self, idx):
        sample = self.samples[idx]
        base_path = self.root_dir / self.split / sample['scene'] / sample['camera']
        obj_path = base_path / sample['obj_folder']

        rgb_frames, modal_mask_frames = [], []
        amodal_mask_frames, amodal_rgb_frames = [], []
        for fid in sample['frame_ids']:
            fid_str = f"{fid:05d}"
            try:
                # Context managers close each file handle once its pixels are read.
                with Image.open(base_path / f'rgba_{fid_str}.png') as img:
                    rgb = self.transform(img.convert('RGB'))
                # Modal (visible) mask: pixels whose scene segmentation id equals the object id.
                with Image.open(base_path / f'segmentation_{fid_str}.png') as img:
                    seg_map = np.array(img)
                modal_mask_np = (seg_map == sample['obj_id']).astype(np.uint8) * 255
                modal_mask = self.mask_transform(Image.fromarray(modal_mask_np, mode='L'))
                # NOTE(review): assumes the per-object segmentation PNG is a 0/255
                # mask; if it stores raw ids, ToTensor yields small non-binary values.
                with Image.open(obj_path / f'segmentation_{fid_str}.png') as img:
                    amodal_mask = self.mask_transform(img.convert('L'))
                # Target: the object's complete (unoccluded) RGB appearance.
                with Image.open(obj_path / f'rgba_{fid_str}.png') as img:
                    amodal_rgb = self.transform(img.convert('RGB'))
            except Exception as e:
                print(f"Error loading frame {fid_str} of {obj_path}: {e}")
                return self._empty_item(sample)
            rgb_frames.append(rgb)
            modal_mask_frames.append(modal_mask)
            amodal_mask_frames.append(amodal_mask)
            amodal_rgb_frames.append(amodal_rgb)

        return {
            'rgb_sequence': torch.stack(rgb_frames),                # Scene RGB
            'modal_masks': torch.stack(modal_mask_frames),          # Visible parts
            'amodal_masks': torch.stack(amodal_mask_frames),        # Complete shape
            'amodal_rgb_sequence': torch.stack(amodal_rgb_frames),  # Target object RGB
            'scene': sample['scene'],
            'camera': sample['camera'],
            'object_id': sample['obj_id'],
        }
import wandb
# Authenticate with Weights & Biases at import time; train_video_amodal_with_metrics()
# logs its runs through wandb.init / wandb.log.
wandb.login()
# Add these imports to your existing imports
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch.nn.functional as F
from scipy import linalg
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from torchvision.models import inception_v3
from torchvision.transforms import Resize, Normalize
import lpips
# Add this class for computing metrics
class VideoAmodalMetrics:
    """Image/video quality metrics for amodal completion outputs.

    Methods accept either video tensors shaped (B, C, N, H, W) or image
    tensors shaped (B, C, H, W), with values expected in [0, 1]. Masks, when
    given, are indexed the same way, with a channel dimension of 1 (or C).
    """

    def __init__(self, device='cuda'):
        self.device = device
        # Perceptual distance network (AlexNet backbone).
        self.lpips_model = lpips.LPIPS(net='alex').to(device)
        # ImageNet Inception-v3, used as the feature extractor for FID.
        self.inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
        self.inception_model.eval()
        # Inception expects 299x299, ImageNet-normalised input.
        self.inception_transform = torch.nn.Sequential(
            Resize((299, 299)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        )

    @staticmethod
    def _iter_frames(pred, target, mask=None):
        """Yield (pred, target, mask) slices of shape (C, H, W), one per frame.

        5-D inputs are walked over batch then frames; 4-D inputs over batch.
        The mask slice is None when ``mask`` is None.
        """
        if pred.dim() == 5:
            for i in range(pred.shape[0]):
                for j in range(pred.shape[2]):
                    yield (pred[i, :, j], target[i, :, j],
                           None if mask is None else mask[i, :, j])
        else:
            for i in range(pred.shape[0]):
                yield pred[i], target[i], None if mask is None else mask[i]

    def calculate_psnr(self, pred, target, mask=None):
        """Return PSNR in dB for a data range of 1.0.

        Without a mask, PSNR comes from the MSE over the whole tensor. With a
        mask, PSNR is computed per frame from the MSE over masked pixels only
        (mask values act as weights), then averaged over frames. Frames with an
        empty mask are skipped, and 0.0 is returned if every frame is skipped.
        A zero error gives ``inf``.
        """
        if mask is None:
            mse = F.mse_loss(pred, target)
            return (20 * torch.log10(1.0 / torch.sqrt(mse))).item()

        psnr_values = []
        for p, t, m in self._iter_frames(pred, target, mask):
            weights = m.expand_as(p).to(p.dtype)
            total = weights.sum()
            if total <= 0:
                continue
            # Normalise by the masked area. Averaging over the whole frame (with
            # zeros outside the mask) would dilute the MSE and inflate PSNR.
            mse = (((p - t) ** 2) * weights).sum() / total
            psnr_values.append((10 * torch.log10(1.0 / mse)).item())
        return float(np.mean(psnr_values)) if psnr_values else 0.0

    def calculate_ssim(self, pred, target, mask=None):
        """Return the mean SSIM over frames (data range 1.0).

        The mask only causes frames with an empty mask to be skipped; SSIM is
        computed over the full frame. Returns 0.0 when no frame is scored.
        """
        ssim_values = []
        for p, t, m in self._iter_frames(pred, target, mask):
            if m is not None and m[0].sum() == 0:
                continue
            p_np = p.detach().permute(1, 2, 0).cpu().numpy()
            t_np = t.detach().permute(1, 2, 0).cpu().numpy()
            ssim_values.append(ssim(t_np, p_np, data_range=1.0, channel_axis=2))
        return np.mean(ssim_values) if ssim_values else 0.0

    def calculate_lpips(self, pred, target, mask=None):
        """Return the mean LPIPS distance over all frames (``mask`` is accepted but unused)."""
        if pred.dim() == 5:
            # Fold frames into the batch: (B, C, N, H, W) -> (B*N, C, H, W).
            pred = pred.permute(0, 2, 1, 3, 4).flatten(0, 1)
            target = target.permute(0, 2, 1, 3, 4).flatten(0, 1)
        if pred.shape[0] == 0:
            return 0.0
        with torch.no_grad():
            # LPIPS expects inputs scaled to [-1, 1].
            dist = self.lpips_model(pred * 2.0 - 1.0, target * 2.0 - 1.0)
        return float(dist.mean().item())

    def calculate_iou(self, pred_mask, target_mask, threshold=0.5):
        """Return the IoU of two masks after binarising both at ``threshold``."""
        pred_binary = (pred_mask > threshold).float()
        target_binary = (target_mask > threshold).float()
        intersection = (pred_binary * target_binary).sum()
        union = pred_binary.sum() + target_binary.sum() - intersection
        return (intersection / (union + 1e-8)).item()

    def get_inception_features(self, images):
        """Run Inception-v3 on (N, 3, H, W) images in [0, 1]; return an (N, D) numpy array.

        NOTE(review): the torchvision model keeps its classification head, so
        these are 1000-way logits, not the 2048-d pool features of standard FID.
        Values are therefore not comparable with published FID numbers.
        """
        with torch.no_grad():
            features = self.inception_model(self.inception_transform(images))
        return features.cpu().numpy()

    def calculate_fid(self, pred, target):
        """Return the Fréchet distance between Inception activations of pred and target.

        5-D inputs are flattened to one image per frame. With few frames the
        covariance estimates are rank-deficient, so treat the value as a rough
        signal only.
        """
        if pred.dim() == 5:
            pred = pred.permute(0, 2, 1, 3, 4).flatten(0, 1)
            target = target.permute(0, 2, 1, 3, 4).flatten(0, 1)
        pred_features = self.get_inception_features(pred)
        target_features = self.get_inception_features(target)

        mu1, sigma1 = pred_features.mean(axis=0), np.cov(pred_features, rowvar=False)
        mu2, sigma2 = target_features.mean(axis=0), np.cov(target_features, rowvar=False)
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real  # discard numerical imaginary residue
        return float(diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean))

    def calculate_all_metrics(self, pred, target, amodal_mask=None):
        """Return a dict with psnr, ssim, lpips, fid and (when a mask is given) iou."""
        # Detach and upcast first. Outputs produced under autocast may be
        # float16 and still attached to the graph; that breaks .numpy() and
        # the float32 LPIPS / Inception networks.
        pred = pred.detach().float()
        target = target.detach().float()
        if amodal_mask is not None:
            amodal_mask = amodal_mask.detach().float()

        metrics = {
            'psnr': self.calculate_psnr(pred, target, amodal_mask),
            'ssim': self.calculate_ssim(pred, target, amodal_mask),
            'lpips': self.calculate_lpips(pred, target, amodal_mask),
        }
        try:
            metrics['fid'] = self.calculate_fid(pred, target)
        except Exception:
            # FID is best-effort (e.g. sqrtm can fail on tiny batches).
            metrics['fid'] = 0.0

        if amodal_mask is not None:
            # NOTE(review): the "predicted mask" is the prediction's grayscale
            # intensity thresholded at 0.5, so dark objects score low IoU.
            pred_intensity = pred.mean(dim=1, keepdim=True)
            # Use one mask channel so that a channel-expanded mask does not
            # triple-count the intersection against the 1-channel prediction.
            metrics['iou'] = self.calculate_iou(pred_intensity, amodal_mask[:, :1])
        return metrics
# Add this function to create error heatmaps
def create_error_heatmap(pred, target, mask=None):
    """Return the per-pixel absolute error averaged over colour channels, as numpy.

    Accepts (C, H, W) or (B, C, H, W) tensors. The channel axis (third from
    the end) is averaged away, giving (H, W) or (B, H, W). If ``mask`` is
    given, it is squeezed and multiplied in, zeroing the error outside it.
    """
    # Channels sit at dim -3 in both layouts; reducing dim 0 would instead
    # average over the batch axis of a (1, C, H, W) input.
    error = torch.abs(pred - target).mean(dim=-3)
    if mask is not None:
        error = error * mask.squeeze()
    return error.detach().cpu().numpy()
# Enhanced training function with metrics and wandb
def _validate_epoch(model, val_loader, criterion, metrics_calculator, device):
    """Run one pass over ``val_loader``; return (per-batch total losses, metric lists).

    Metric lists are keyed ``val_psnr``, ``val_ssim``, ``val_lpips``,
    ``val_fid`` and ``val_iou``.
    """
    model.eval()
    val_losses = []
    val_metrics = {f'val_{name}': [] for name in ('psnr', 'ssim', 'lpips', 'fid', 'iou')}
    with torch.no_grad():
        for batch in val_loader:
            inputs = prepare_model_input(batch).to(device)
            targets = prepare_model_target(batch).to(device)
            modal_masks = batch['modal_masks'].to(device)
            amodal_masks = batch['amodal_masks'].to(device)

            outputs = model(inputs)
            _, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
            val_losses.append(loss_dict['total_loss'])

            # Dataset masks are (B, N, 1, H, W); the metrics index (B, 1, N, H, W).
            batch_metrics = metrics_calculator.calculate_all_metrics(
                outputs.float(), targets.float(), amodal_masks.permute(0, 2, 1, 3, 4)
            )
            for key, value in batch_metrics.items():
                if f'val_{key}' in val_metrics:
                    val_metrics[f'val_{key}'].append(value)
    return val_losses, val_metrics


def train_video_amodal_with_metrics():
    """Train Video3DUNet with periodic quality metrics, wandb logging and per-epoch checkpoints.

    Hyper-parameters live in the ``wandb.config`` below. Each epoch does four
    things:

    * gradient-accumulated (mixed-precision on CUDA) training;
    * a validation pass;
    * logging of the averaged losses and metrics to wandb;
    * saving ``epoch_{epoch}.pth`` in the working directory.

    Raises:
        RuntimeError: if ``data/train`` yields no samples.
    """
    wandb.init(
        project="video-amodal-completion",
        config={
            'batch_size': 2,
            'seq_len': 6,
            'img_size': (256, 256),
            'num_epochs': 30,
            'learning_rate': 5e-5,
            'max_scenes': 2,
            'samples_per_scene': 2,
            'num_workers': 2,
            'grad_accum_steps': 4
        }
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()
    config = wandb.config
    use_amp = device.type == 'cuda'

    train_dataset = VideoAmodalDataset(
        root_dir='data',
        split='train',
        seq_len=config.seq_len,
        img_size=config.img_size,
        max_scenes=config.max_scenes,
        samples_per_scene=config.samples_per_scene,
        max_samples=100
    )
    val_dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=config.seq_len,
        img_size=config.img_size,
        max_scenes=1,
        samples_per_scene=1,
        max_samples=10
    )
    if len(train_dataset) == 0:
        wandb.finish()
        raise RuntimeError("No training samples found under data/train; "
                           "download and extract the train split first.")

    metrics_calculator = VideoAmodalMetrics(device)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True
    )
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1)

    model = Video3DUNet(
        in_channels=5,
        out_channels=3,
        sequence_length=config.seq_len
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
    criterion = VideoAmodalCompletionLoss()
    # Loss scaling stops small float16 gradients underflowing to zero under
    # autocast; with enabled=False every scaler call is a pass-through.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    num_batches = len(train_loader)

    for epoch in range(config.num_epochs):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        epoch_losses = []
        epoch_metrics = {f'train_{name}': [] for name in ('psnr', 'ssim', 'lpips', 'fid', 'iou')}

        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            inputs = prepare_model_input(batch).to(device, non_blocking=True)
            targets = prepare_model_target(batch).to(device, non_blocking=True)
            modal_masks = batch['modal_masks'].to(device, non_blocking=True)
            amodal_masks = batch['amodal_masks'].to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(inputs)
                loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                loss = loss / config.grad_accum_steps
            scaler.scale(loss).backward()

            if i % 10 == 0:
                with torch.no_grad():
                    batch_metrics = metrics_calculator.calculate_all_metrics(
                        outputs.detach().float(), targets.float(),
                        amodal_masks.permute(0, 2, 1, 3, 4)
                    )
                for key, value in batch_metrics.items():
                    if f'train_{key}' in epoch_metrics:
                        epoch_metrics[f'train_{key}'].append(value)

            # Step every grad_accum_steps batches and on the final batch, so a
            # partial accumulation never carries over into the next epoch.
            if (i + 1) % config.grad_accum_steps == 0 or (i + 1) == num_batches:
                scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()

            epoch_losses.append(loss_dict['total_loss'])

            if i % 20 == 0:
                log_dict = {
                    'batch': epoch * num_batches + i,
                    'train_loss': loss_dict['total_loss'],
                    'train_visible_loss': loss_dict['visible_loss'],
                    'train_occluded_loss': loss_dict['occluded_loss'],
                    'train_background_loss': loss_dict['background_loss'],
                    'train_boundary_loss': loss_dict['boundary_loss']
                }
                # Attach the most recent metric values, if any were computed yet.
                log_dict.update({key: values[-1] for key, values in epoch_metrics.items() if values})
                wandb.log(log_dict)
                print(f"Batch {i}, Loss: {loss_dict['total_loss']:.4f}")
                print(f" Visible: {loss_dict['visible_loss']:.4f}, "
                      f"Occluded: {loss_dict['occluded_loss']:.4f}, "
                      f"Background: {loss_dict['background_loss']:.4f}")

        val_losses, val_metrics = _validate_epoch(
            model, val_loader, criterion, metrics_calculator, device
        )

        avg_train_loss = np.mean(epoch_losses)
        avg_val_loss = np.mean(val_losses) if val_losses else float('nan')
        epoch_log = {
            'epoch': epoch,
            'avg_train_loss': avg_train_loss,
            'avg_val_loss': avg_val_loss
        }
        for key, values in {**epoch_metrics, **val_metrics}.items():
            if values:
                epoch_log[f'avg_{key}'] = np.mean(values)
        wandb.log(epoch_log)

        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        for key, values in val_metrics.items():
            if values:
                print(f" {key}: {np.mean(values):.4f}")

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'metrics': {key: np.mean(values) for key, values in val_metrics.items() if values}
        }, f"epoch_{epoch}.pth")

    wandb.finish()
# Enhanced GIF creation with error heatmap
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=200):
    """Write an animated GIF of scene | prediction | ground truth | error heatmap.

    Args:
        predictions: per-frame (3, H, W) CPU tensors in [0, 1]. Their count
            sets the number of GIF frames.
        rgb_frames: per-frame (3, H, W) CPU tensors in [0, 1].
        gt_amodal_frames: per-frame (3, H, W) CPU tensors in [0, 1].
        amodal_masks: per-frame (1, H, W) mask tensors, or None / empty to
            skip masking the error.
        output_path: destination ``.gif`` path.
        duration: per-frame display time passed to PIL (milliseconds).

    The heatmap uses the jet colormap, normalised to the global min/max error
    over all frames so colours are comparable across the animation.

    Raises:
        ValueError: if ``predictions`` is empty.
    """
    from PIL import Image, ImageDraw, ImageFont

    if len(predictions) == 0:
        raise ValueError("create_gif_with_error_heatmap needs at least one frame")
    # Explicit None/length check: truthiness is ambiguous if a tensor is passed.
    has_masks = amodal_masks is not None and len(amodal_masks) > 0

    # Compute every frame's error first so one colour scale covers the whole GIF.
    all_errors = []
    for i, pred in enumerate(predictions):
        mask = amodal_masks[i].unsqueeze(0) if has_masks else None
        all_errors.append(create_error_heatmap(pred.unsqueeze(0),
                                               gt_amodal_frames[i].unsqueeze(0), mask))
    max_error = max(float(e.max()) for e in all_errors)
    min_error = min(float(e.min()) for e in all_errors)

    # Loop-invariant drawing resources.
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None
    label = f"Error: {min_error:.3f} - {max_error:.3f}"

    frames = []
    for i, pred in enumerate(predictions):
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        pred_rgb = (np.clip(pred.permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        error = all_errors[i]
        if max_error > min_error:
            error = (error - min_error) / (max_error - min_error)
        # Collapse to (H, W) before colouring; a 3-D map keeps its first plane.
        error = np.squeeze(error)
        if error.ndim == 3:
            error = error[0]
        error_rgb = (cm.jet(error)[:, :, :3] * 255).astype(np.uint8)

        combined = np.concatenate([scene_rgb, pred_rgb, gt_rgb, error_rgb], axis=1)
        img_pil = Image.fromarray(combined)
        ImageDraw.Draw(img_pil).text((combined.shape[1] - 150, 10), label,
                                     fill=(255, 255, 255), font=font)
        frames.append(img_pil)

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )
    print(f"GIF with error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
# Enhanced video generation with metrics
def _slice_window(sample, start_idx, end_idx):
    """Return a batch-of-one dict holding frames [start_idx, end_idx) of ``sample``.

    4-D tensors (per-frame stacks) are sliced along dim 0 and given a batch
    dimension. Other tensors only get the batch dimension, and non-tensor
    values are wrapped in a one-element list.
    """
    window_batch = {}
    for key, value in sample.items():
        if not isinstance(value, torch.Tensor):
            window_batch[key] = [value]
        elif value.dim() == 4:
            window_batch[key] = value[start_idx:end_idx].unsqueeze(0)
        else:
            window_batch[key] = value.unsqueeze(0)
    return window_batch


def _write_comparison_video(output_path, fps, predictions, rgb_frames, modal_masks,
                            gt_amodal_frames, amodal_masks):
    """Write scene | modal mask | prediction | masked GT frames side by side to an MP4."""
    import cv2

    height, width = predictions[0].shape[-2:]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width * 4, height))
    if not out.isOpened():
        print(f"Warning: could not open video writer for {output_path}; skipping video.")
        return
    try:
        for i, pred in enumerate(predictions):
            scene_rgb = rgb_frames[i].permute(1, 2, 0).numpy()
            modal_mask = modal_masks[i][0].numpy()
            modal_mask_rgb = np.stack([modal_mask, modal_mask, modal_mask], axis=2)
            pred_rgb = np.clip(pred.permute(1, 2, 0).numpy(), 0, 1)
            try:
                gt_amodal = gt_amodal_frames[i].permute(1, 2, 0).numpy()
                gt_amodal_masked = gt_amodal * amodal_masks[i][0].numpy()[:, :, None]
            except (IndexError, RuntimeError, ValueError):
                gt_amodal_masked = np.zeros_like(pred_rgb)

            combined_frame = np.concatenate(
                [scene_rgb, modal_mask_rgb, pred_rgb, gt_amodal_masked], axis=1
            )
            out.write(cv2.cvtColor((combined_frame * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
            if i % 5 == 0:
                print(f"Processed frame {i+1}/{len(predictions)}")
    finally:
        out.release()  # always finalise the file, even if a frame fails
    print(f"Video saved to {output_path}")


def load_model_and_generate_video_with_metrics(checkpoint_path, dataset, device,
                                               output_path="amodal_completion.mp4", fps=8,
                                               seq_len=8):
    """Run a trained Video3DUNet over ``dataset[0]`` in sliding windows and write an MP4.

    The clip is processed in windows of ``seq_len`` frames with a stride of
    ``seq_len // 2``. An extra window aligned to the clip end is added when the
    stride would leave trailing frames uncovered. Predictions are masked by the
    amodal mask, and each frame's prediction is the mean over all windows that
    cover it. Metrics are computed per window and averaged.

    Args:
        checkpoint_path: checkpoint containing ``model_state_dict``, ``epoch``
            and ``train_loss`` entries.
        dataset: dataset whose first item is rendered.
        device: torch device for inference.
        output_path: destination ``.mp4`` path.
        fps: output frame rate.
        seq_len: window length. It must match the ``sequence_length`` the model
            was trained with (training in this file uses ``config.seq_len``).

    Returns:
        (predictions, rgb_frames, gt_amodal_frames, amodal_masks, avg_metrics):
        per-frame CPU tensors plus a dict of window-averaged metrics.

    Raises:
        ValueError: if the sample has fewer than ``seq_len`` frames.
    """
    metrics_calculator = VideoAmodalMetrics(device)

    model = Video3DUNet(in_channels=5, out_channels=3, sequence_length=seq_len).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"Loaded model from epoch {checkpoint['epoch']} with loss {checkpoint['train_loss']:.4f}")

    sample = dataset[0]
    total_frames = len(sample['rgb_sequence'])
    if total_frames < seq_len:
        raise ValueError(f"Sample has {total_frames} frames but the model window needs {seq_len}")
    print(f"Processing {total_frames} frames in windows of {seq_len}")

    stride = max(1, seq_len // 2)
    starts = list(range(0, total_frames - seq_len + 1, stride))
    if starts[-1] + seq_len < total_frames:
        starts.append(total_frames - seq_len)  # cover trailing frames

    # Per-frame running sums and counts give an equal-weight average over every
    # window that covers a frame, whatever the overlap pattern.
    frame_sums = [None] * total_frames
    frame_counts = [0] * total_frames
    all_metrics = []
    with torch.no_grad():
        for start_idx in starts:
            window_batch = _slice_window(sample, start_idx, start_idx + seq_len)
            inputs = prepare_model_input(window_batch).to(device)
            pred = model(inputs)

            amodal_mask = window_batch['amodal_masks'].permute(0, 2, 1, 3, 4).expand_as(pred).to(device)
            pred_masked = pred * amodal_mask
            target = prepare_model_target(window_batch).to(device)
            all_metrics.append(metrics_calculator.calculate_all_metrics(pred, target, amodal_mask))

            pred_frames = pred_masked.squeeze(0).permute(1, 0, 2, 3).cpu()
            for offset, frame in enumerate(pred_frames):
                t = start_idx + offset
                if t >= total_frames:
                    break
                frame_sums[t] = frame if frame_sums[t] is None else frame_sums[t] + frame
                frame_counts[t] += 1

    all_predictions = [s / c for s, c in zip(frame_sums, frame_counts) if c > 0]
    all_rgb = list(sample['rgb_sequence'])
    all_modal_masks = list(sample['modal_masks'])
    all_amodal_masks = list(sample['amodal_masks'])
    all_gt_amodal = list(sample['amodal_rgb_sequence'])

    print("\nOverall Metrics:")
    avg_metrics = {}
    for key in all_metrics[0]:
        avg_metrics[key] = np.mean([m[key] for m in all_metrics])
        print(f" {key.upper()}: {avg_metrics[key]:.4f}")

    print(f"Generated {len(all_predictions)} prediction frames")
    _write_comparison_video(output_path, fps, all_predictions, all_rgb, all_modal_masks,
                            all_gt_amodal, all_amodal_masks)
    return all_predictions, all_rgb, all_gt_amodal, all_amodal_masks, avg_metrics
# Enhanced run function with all new features
def run_enhanced_video_generation(checkpoint_path="video_amodal_model_epoch_4.pth"):
    """Render an MP4 and an error-heatmap GIF for the first test clip; return its metrics.

    Args:
        checkpoint_path: checkpoint to load. NOTE(review): training in this
            file saves ``epoch_{n}.pth``, so the default name only exists if it
            was produced elsewhere. Pass the path explicitly otherwise.

    Returns:
        dict of window-averaged metrics from the video generation step.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=1,
        samples_per_scene=1,
        max_samples=1
    )
    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = load_model_and_generate_video_with_metrics(
        checkpoint_path,
        dataset,
        device,
        output_path="amodal_completion_video_with_metrics.mp4",
        fps=8
    )
    create_gif_with_error_heatmap(
        predictions,
        rgb_frames,
        gt_amodal_frames,
        amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150
    )
    print("Enhanced video generation complete!")
    return metrics
# Launch training at import time (wandb run, 30 epochs, checkpoints saved as epoch_{n}.pth).
train_video_amodal_with_metrics()
# Simple way to run GIF generation from your trained model
import torch
def run_gif_generation(checkpoint_path="epoch_29.pth"):
    """Render an MP4 and an error-heatmap GIF for the first test clip.

    Args:
        checkpoint_path: checkpoint produced by training. ``epoch_29.pth`` is
            the last checkpoint of the default 30-epoch run.

    Returns:
        dict of window-averaged metrics (also printed).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=50,
        samples_per_scene=5,
        max_samples=50
    )
    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = load_model_and_generate_video_with_metrics(
        checkpoint_path,
        dataset,
        device,
        output_path="amodal_completion_video.mp4",
        fps=6
    )
    create_gif_with_error_heatmap(
        predictions,
        rgb_frames,
        gt_amodal_frames,
        amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150
    )
    print("GIF creation complete!")
    print(f"Metrics: {metrics}")
    return metrics
# Script entry point. NOTE: this executes before the later redefinitions of
# create_gif_with_error_heatmap and create_error_heatmap further down the file,
# so the earlier definitions are the ones used by this run.
if __name__ == "__main__":
    run_gif_generation()
import cv2
def draw_amodal_boundary(rgb_image, amodal_mask, color=(255, 0, 255), threshold=0.5):
    """Return a copy of ``rgb_image`` with the outer contour(s) of ``amodal_mask`` drawn on it.

    Args:
        rgb_image: image array to annotate; it is copied, never modified.
        amodal_mask: 2-D mask. Pixels strictly greater than ``threshold``
            count as inside the object.
        color: contour colour, in the image's channel order.
        threshold: binarisation threshold applied to the mask.
    """
    # Threshold rather than astype(uint8): casting a float mask in [0, 1]
    # truncates every fractional edge value to 0.
    binary = (np.asarray(amodal_mask) > threshold).astype(np.uint8)
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; [-2] selects contours in both.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    outlined = rgb_image.copy()
    cv2.drawContours(outlined, contours, -1, color, thickness=2)
    return outlined
# Enhanced GIF creation with proper error heatmap and colorbar
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=240):
    """Write an animated GIF with a labelled error colorbar.

    Each GIF frame shows, left to right:

    * the scene, with the amodal outline drawn on it;
    * the prediction;
    * the ground truth;
    * the error heatmap.

    The heatmap uses the 'hot' colormap over a range fixed for all frames.
    That range is the 5th-95th percentile of the non-zero masked errors, or
    the full min/max when there are none. Pixels outside the amodal mask are
    drawn black.

    Args:
        predictions: per-frame (3, H, W) CPU tensors in [0, 1]. Their count
            sets the number of GIF frames.
        rgb_frames: per-frame (3, H, W) CPU tensors in [0, 1].
        gt_amodal_frames: per-frame (3, H, W) CPU tensors in [0, 1].
        amodal_masks: per-frame (1, H, W) mask tensors, or None / empty to
            skip masking and the outline.
        output_path: destination ``.gif`` path.
        duration: per-frame display time passed to PIL (milliseconds).

    Raises:
        ValueError: if ``predictions`` is empty.
    """
    import io
    from PIL import Image, ImageDraw, ImageFont
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from matplotlib.colors import Normalize

    if len(predictions) == 0:
        raise ValueError("create_gif_with_error_heatmap needs at least one frame")
    has_masks = amodal_masks is not None and len(amodal_masks) > 0
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    cmap = plt.get_cmap('hot')

    # Per-frame (H, W) error maps. The mask is applied inside create_error_heatmap
    # and once more here, matching how both the range and the frames use it.
    frame_errors = []
    for i, pred in enumerate(predictions):
        mask = amodal_masks[i].unsqueeze(0) if has_masks else None
        error = create_error_heatmap(pred.unsqueeze(0), gt_amodal_frames[i].unsqueeze(0), mask)
        if has_masks:
            error = error * amodal_masks[i][0].numpy()
        error = np.squeeze(error)
        if error.ndim == 3:
            error = error[0]
        frame_errors.append(error)

    # Percentile range over the masked (non-zero) errors drops outliers; gather
    # with numpy instead of growing a Python list pixel by pixel.
    if has_masks:
        pool = np.concatenate([e[e > 0].ravel() for e in frame_errors])
    else:
        pool = np.concatenate([e.ravel() for e in frame_errors])
    if pool.size:
        min_error = float(np.percentile(pool, 5))
        max_error = float(np.percentile(pool, 95))
    else:
        min_error = min(float(e.min()) for e in frame_errors)
        max_error = max(float(e.max()) for e in frame_errors)
    if max_error - min_error < 1e-6:
        max_error = min_error + 1e-6
    print(f"Error range for visualization: {min_error:.4f} to {max_error:.4f}")

    def create_colorbar():
        """Render a labelled vertical 'hot' colorbar for [min_error, max_error] as an RGB image."""
        fig, ax = plt.subplots(figsize=(1, 4))
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')
        sm = cm.ScalarMappable(norm=Normalize(vmin=min_error, vmax=max_error), cmap='hot')
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, orientation='vertical', fraction=1.0)
        cbar.set_label('Prediction Error', color='white', fontsize=10)
        cbar.ax.tick_params(colors='white', labelsize=8)
        ax.remove()  # keep only the colorbar axes
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight',
                    facecolor='black', edgecolor='none', dpi=100)
        plt.close(fig)
        buf.seek(0)
        with Image.open(buf) as img:
            return img.convert('RGB')

    colorbar_img = create_colorbar()
    colorbar_width = colorbar_img.width
    colorbar_resized = None  # resized lazily, once per distinct frame height

    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    frames = []
    for i, pred in enumerate(predictions):
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        pred_rgb = (np.clip(pred.permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        normalized = np.clip((frame_errors[i] - min_error) / (max_error - min_error), 0, 1)
        error_rgb = (cmap(normalized)[:, :, :3] * 255).astype(np.uint8)
        if has_masks:
            mask = amodal_masks[i][0].cpu().numpy()
            # Black out everything outside the amodal mask.
            error_rgb = error_rgb * np.stack([mask, mask, mask], axis=2).astype(np.uint8)
            scene_rgb = draw_amodal_boundary(scene_rgb, mask)

        img_pil = Image.fromarray(np.concatenate([scene_rgb, pred_rgb, gt_rgb, error_rgb], axis=1))
        if colorbar_resized is None or colorbar_resized.height != img_pil.height:
            colorbar_resized = colorbar_img.resize((colorbar_width, img_pil.height))

        final_img = Image.new('RGB', (img_pil.width + colorbar_width + 10, img_pil.height), color='black')
        final_img.paste(img_pil, (0, 0))
        final_img.paste(colorbar_resized, (img_pil.width + 10, 0))  # 10 px gap
        ImageDraw.Draw(final_img).text((10, 10), f"Frame {i+1}/{len(predictions)}",
                                       fill=(0, 0, 0), font=font)
        frames.append(final_img)

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )
    print(f"GIF with proper error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
    print("Colorbar shows errors from low (black/red) to high (yellow/white)")
# Also update the error heatmap calculation to be more sensitive
def create_error_heatmap(pred, target, mask=None):
    """Return the per-pixel L2 error across colour channels, as numpy.

    Accepts (C, H, W) or (B, C, H, W) tensors. The channel axis (third from
    the end) is reduced with a Euclidean norm, giving (H, W) or (B, H, W).
    If ``mask`` is given, it is squeezed and multiplied in, zeroing the error
    outside it.
    """
    # Reduce over dim -3 so both layouts work; a fixed dim=1 would collapse the
    # height axis of an unbatched (C, H, W) input.
    error = torch.sqrt(torch.sum((pred - target) ** 2, dim=-3))
    # Alternative with different characteristics: L1 error
    # error = torch.abs(pred - target).mean(dim=-3)
    if mask is not None:
        error = error * mask.squeeze()
    return error.detach().cpu().numpy()