"""2.2.2.2.2.ipynb |
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/1igY4MKIJJTPHgEkdLFI_T5H6sLUoTaLr |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
"""## CODE""" |
|
|
|
|
|
# Install the extra dependencies used below (Colab shell command).
!pip install torchmetrics lpips

import os
import json
import random
import shutil
import tarfile

import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance
import lpips

from huggingface_hub import HfApi, hf_hub_download

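# Optional: seed the RNGs so the random scene/object sampling and the archive
# subsetting below are repeatable across runs (a convenience sketch, not part
# of the original flow).
def seed_everything(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

seed_everything(0)
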
def download_sequential_data(repo_id="Amar-S/MOVi-MC-AC", sample_ratio=0.01, base_dir="/content/data"):
    """Download a small random subset of the dataset while preserving whole
    video sequences (each .tar.gz archive holds one complete scene).

    Note: only the test split is sampled here; the train directory is created
    but left empty.
    """
    api = HfApi()

    os.makedirs(f"{base_dir}/train", exist_ok=True)
    os.makedirs(f"{base_dir}/test", exist_ok=True)

    files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    test_files = [f for f in files if f.startswith("test/") and f.endswith(".tar.gz")]

    # Sample a fraction of the scene archives (at least one).
    subset_test = random.sample(test_files, max(1, int(len(test_files) * sample_ratio)))

    for file in subset_test:
        print(f"Downloading {file}...")
        out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
        dest_path = f"{base_dir}/test/{os.path.basename(file)}"
        shutil.copyfile(out_path, dest_path)

    extract_archives(f"{base_dir}/train")
    extract_archives(f"{base_dir}/test")

    print("Download and extraction complete!")

def extract_archives(directory):
    """Extract every .tar.gz file in a directory, deleting each archive afterwards."""
    for file in os.listdir(directory):
        if file.endswith(".tar.gz"):
            filepath = os.path.join(directory, file)
            print(f"Extracting {filepath}...")
            with tarfile.open(filepath, 'r:gz') as tar:
                tar.extractall(path=directory)
            os.remove(filepath)

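# Optional dry run (a sketch): count the available test archives before
# committing to a download. This reuses the same HfApi listing call as above.
_all_files = HfApi().list_repo_files(repo_id="Amar-S/MOVi-MC-AC", repo_type="dataset")
_n_test = sum(f.startswith("test/") and f.endswith(".tar.gz") for f in _all_files)
print(f"{_n_test} test archives available; sample_ratio=0.01 would fetch {max(1, int(_n_test * 0.01))}")
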
download_sequential_data()

class VideoAmodalDataset(Dataset):
    """Sliding-window dataset over MOVi-MC-AC scenes: each sample is a
    seq_len-frame clip with scene RGB, modal mask, amodal mask, and
    ground-truth amodal RGB for one object."""

    def __init__(self, root_dir, split='train', seq_len=8, img_size=(256, 256),
                 max_scenes=4, samples_per_scene=3, max_samples=None):
        self.root_dir = Path(root_dir)
        self.split = split
        self.seq_len = seq_len
        self.img_size = img_size
        self.max_scenes = max_scenes
        self.samples_per_scene = samples_per_scene

        self.samples = self._build_sample_index(max_samples)

        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self, max_samples):
        samples = []
        scene_paths = sorted((self.root_dir / self.split).glob('scene_*'))[:self.max_scenes]

        for scene_path in scene_paths:
            camera_paths = sorted(scene_path.glob('camera_*'))

            for camera_path in camera_paths:
                obj_paths = sorted(camera_path.glob('obj_*'))
                selected_objs = random.sample(obj_paths, min(self.samples_per_scene, len(obj_paths)))

                for obj_path in selected_objs:
                    rgba_files = sorted(camera_path.glob('rgba_*.png'))
                    frame_ids = [int(p.stem.split('_')[1]) for p in rgba_files]

                    # Non-overlapping windows of seq_len consecutive frames.
                    for i in range(0, len(frame_ids) - self.seq_len + 1, self.seq_len):
                        samples.append({
                            'scene': scene_path.name,
                            'camera': camera_path.name,
                            'obj_folder': obj_path.name,
                            'frame_ids': frame_ids[i:i+self.seq_len],
                            'obj_id': int(obj_path.name.split('_')[1])
                        })

                        if max_samples and len(samples) >= max_samples:
                            return samples

        return samples

    def __getitem__(self, idx):
        sample = self.samples[idx]
        base_path = self.root_dir / self.split / sample['scene'] / sample['camera']
        obj_path = base_path / sample['obj_folder']

        rgb_frames = []
        modal_mask_frames = []
        amodal_mask_frames = []
        amodal_rgb_frames = []

        for fid in sample['frame_ids']:
            fid_str = f"{fid:05d}"

            try:
                # Full scene RGB frame.
                rgb = Image.open(base_path / f'rgba_{fid_str}.png').convert('RGB')
                rgb = self.transform(rgb)

                # Modal (visible) mask: pixels of the scene segmentation map
                # that belong to this object.
                seg_map = np.array(Image.open(base_path / f'segmentation_{fid_str}.png'))
                modal_mask_np = (seg_map == sample['obj_id']).astype(np.uint8) * 255
                modal_mask = Image.fromarray(modal_mask_np, mode='L')
                modal_mask = self.transform(modal_mask)

                # Amodal (full-extent) mask from the per-object folder.
                amodal_mask = Image.open(obj_path / f'segmentation_{fid_str}.png').convert('L')
                amodal_mask = self.transform(amodal_mask)

                # Ground-truth unoccluded RGB of the object.
                amodal_rgb = Image.open(obj_path / f'rgba_{fid_str}.png').convert('RGB')
                amodal_rgb = self.transform(amodal_rgb)

                rgb_frames.append(rgb)
                modal_mask_frames.append(modal_mask)
                amodal_mask_frames.append(amodal_mask)
                amodal_rgb_frames.append(amodal_rgb)

            except Exception as e:
                print(f"Error loading {base_path}/rgba_{fid_str}.png: {e}")

                # Fall back to an all-zero clip so the batch shape stays valid.
                empty_rgb = torch.zeros(3, self.img_size[0], self.img_size[1])
                empty_mask = torch.zeros(1, self.img_size[0], self.img_size[1])

                return {
                    'rgb_sequence': empty_rgb.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
                    'modal_masks': empty_mask.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
                    'amodal_masks': empty_mask.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
                    'amodal_rgb_sequence': empty_rgb.unsqueeze(0).repeat(self.seq_len, 1, 1, 1),
                    'scene': sample['scene'],
                    'camera': sample['camera'],
                    'object_id': sample['obj_id']
                }

        return {
            'rgb_sequence': torch.stack(rgb_frames),
            'modal_masks': torch.stack(modal_mask_frames),
            'amodal_masks': torch.stack(amodal_mask_frames),
            'amodal_rgb_sequence': torch.stack(amodal_rgb_frames),
            'scene': sample['scene'],
            'camera': sample['camera'],
            'object_id': sample['obj_id']
        }

    def __len__(self):
        return len(self.samples)
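# Quick sanity check of the dataset pipeline (a minimal sketch; assumes the
# archives downloaded above were extracted under /content/data/test).
_sanity_ds = VideoAmodalDataset(root_dir='/content/data', split='test',
                                seq_len=4, max_scenes=1, samples_per_scene=1,
                                max_samples=1)
if len(_sanity_ds) > 0:
    _s = _sanity_ds[0]
    print('rgb_sequence:       ', tuple(_s['rgb_sequence'].shape))         # (T, 3, H, W)
    print('modal_masks:        ', tuple(_s['modal_masks'].shape))          # (T, 1, H, W)
    print('amodal_masks:       ', tuple(_s['amodal_masks'].shape))         # (T, 1, H, W)
    print('amodal_rgb_sequence:', tuple(_s['amodal_rgb_sequence'].shape))  # (T, 3, H, W)
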
import wandb

wandb.login()

# Additional imports for the metrics and visualization cells below.
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch.nn.functional as F
from scipy import linalg
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from torchvision.models import inception_v3
from torchvision.transforms import Resize, Normalize
import lpips

class VideoAmodalMetrics:
    """Compute image-quality and mask metrics for video amodal completion.

    Tensors may be 4-D (B, C, H, W) or 5-D (B, C, T, H, W); 5-D inputs are
    evaluated frame by frame.
    """

    def __init__(self, device='cuda'):
        self.device = device

        self.lpips_model = lpips.LPIPS(net='alex').to(device)

        # Inception backbone for FID. Note: the 1000-d logits are used as
        # features here, a lighter-weight approximation of the usual 2048-d
        # pool3 activations.
        self.inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
        self.inception_model.eval()

        self.inception_transform = torch.nn.Sequential(
            Resize((299, 299)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )

    def calculate_psnr(self, pred, target, mask=None):
        """Calculate PSNR between prediction and target (optionally masked)."""
        if mask is not None:
            pred_masked = pred * mask
            target_masked = target * mask

            psnr_values = []
            for i in range(pred.shape[0]):
                if pred.dim() == 5:
                    for j in range(pred.shape[2]):
                        p = pred_masked[i, :, j].permute(1, 2, 0).cpu().numpy()
                        t = target_masked[i, :, j].permute(1, 2, 0).cpu().numpy()
                        m = mask[i, 0, j].cpu().numpy()

                        if m.sum() > 0:  # skip frames where the object is absent
                            psnr_values.append(psnr(t, p, data_range=1.0))
                else:
                    p = pred_masked[i].permute(1, 2, 0).cpu().numpy()
                    t = target_masked[i].permute(1, 2, 0).cpu().numpy()
                    m = mask[i, 0].cpu().numpy()

                    if m.sum() > 0:
                        psnr_values.append(psnr(t, p, data_range=1.0))

            return np.mean(psnr_values) if psnr_values else 0.0

        # Unmasked case: closed-form PSNR from the global MSE.
        mse = F.mse_loss(pred, target)
        psnr_val = 20 * torch.log10(1.0 / torch.sqrt(mse))
        return psnr_val.item()

    def calculate_ssim(self, pred, target, mask=None):
        """Calculate SSIM between prediction and target."""
        ssim_values = []

        for i in range(pred.shape[0]):
            if pred.dim() == 5:
                for j in range(pred.shape[2]):
                    p = pred[i, :, j].permute(1, 2, 0).cpu().numpy()
                    t = target[i, :, j].permute(1, 2, 0).cpu().numpy()

                    if mask is not None:
                        m = mask[i, 0, j].cpu().numpy()
                        if m.sum() == 0:
                            continue

                    ssim_values.append(ssim(t, p, data_range=1.0, channel_axis=2))
            else:
                p = pred[i].permute(1, 2, 0).cpu().numpy()
                t = target[i].permute(1, 2, 0).cpu().numpy()

                if mask is not None:
                    m = mask[i, 0].cpu().numpy()
                    if m.sum() == 0:
                        continue

                ssim_values.append(ssim(t, p, data_range=1.0, channel_axis=2))

        return np.mean(ssim_values) if ssim_values else 0.0

    def calculate_lpips(self, pred, target, mask=None):
        """Calculate LPIPS perceptual distance. The mask argument is accepted
        for interface symmetry but not applied."""
        # LPIPS expects inputs in [-1, 1].
        pred_norm = pred * 2.0 - 1.0
        target_norm = target * 2.0 - 1.0

        lpips_values = []

        if pred.dim() == 5:
            for i in range(pred.shape[0]):
                for j in range(pred.shape[2]):
                    p = pred_norm[i, :, j].unsqueeze(0)
                    t = target_norm[i, :, j].unsqueeze(0)

                    with torch.no_grad():
                        lpips_values.append(self.lpips_model(p, t).item())
        else:
            with torch.no_grad():
                lpips_val = self.lpips_model(pred_norm, target_norm)
            lpips_values.extend(lpips_val.cpu().numpy().flatten().tolist())

        return np.mean(lpips_values) if lpips_values else 0.0

    def calculate_iou(self, pred_mask, target_mask, threshold=0.5):
        """Calculate IoU between binarized masks."""
        pred_binary = (pred_mask > threshold).float()
        target_binary = (target_mask > threshold).float()

        intersection = (pred_binary * target_binary).sum()
        union = pred_binary.sum() + target_binary.sum() - intersection

        iou = intersection / (union + 1e-8)
        return iou.item()

    def get_inception_features(self, images):
        """Extract Inception features (logits) for FID calculation."""
        with torch.no_grad():
            images_preprocessed = self.inception_transform(images)
            features = self.inception_model(images_preprocessed)
            return features.cpu().numpy()

    def calculate_fid(self, pred, target):
        """Calculate the Fréchet Inception Distance:
        FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2*sqrt(S1 @ S2)).
        """
        # Flatten the temporal axis into the batch for 5-D inputs.
        if pred.dim() == 5:
            pred = pred.permute(0, 2, 1, 3, 4).reshape(-1, pred.shape[1], pred.shape[3], pred.shape[4])
            target = target.permute(0, 2, 1, 3, 4).reshape(-1, target.shape[1], target.shape[3], target.shape[4])

        pred_features = self.get_inception_features(pred)
        target_features = self.get_inception_features(target)

        mu1, sigma1 = pred_features.mean(axis=0), np.cov(pred_features, rowvar=False)
        mu2, sigma2 = target_features.mean(axis=0), np.cov(target_features, rowvar=False)

        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real  # drop numerical imaginary residue

        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return fid

    def calculate_all_metrics(self, pred, target, amodal_mask=None):
        """Calculate all metrics at once."""
        metrics = {}

        metrics['psnr'] = self.calculate_psnr(pred, target, amodal_mask)
        metrics['ssim'] = self.calculate_ssim(pred, target, amodal_mask)
        metrics['lpips'] = self.calculate_lpips(pred, target, amodal_mask)

        try:
            metrics['fid'] = self.calculate_fid(pred, target)
        except Exception:
            # FID is unstable for very small batches (singular covariance).
            metrics['fid'] = 0.0

        if amodal_mask is not None:
            # Proxy "predicted mask": mean RGB intensity of the prediction.
            pred_intensity = pred.mean(dim=1, keepdim=True)
            metrics['iou'] = self.calculate_iou(pred_intensity, amodal_mask)

        return metrics
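# Smoke test for the metrics class on random tensors (a sketch; uses the
# (B, C, T, H, W) layout assumed above). FID on two frames is only a
# plumbing check, not a meaningful score.
_dev = 'cuda' if torch.cuda.is_available() else 'cpu'
_mc = VideoAmodalMetrics(device=_dev)
_pred = torch.rand(1, 3, 2, 64, 64, device=_dev)
_tgt = torch.rand(1, 3, 2, 64, 64, device=_dev)
_mask = (torch.rand(1, 1, 2, 64, 64, device=_dev) > 0.5).float()
print(_mc.calculate_all_metrics(_pred, _tgt, _mask))
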
def create_error_heatmap(pred, target, mask=None):
    """Create a per-pixel error heatmap between prediction and target."""
    # Absolute error, averaged over the leading (batch) dimension.
    error = torch.abs(pred - target).mean(dim=0)

    if mask is not None:
        error = error * mask.squeeze()

    return error.cpu().numpy()

def train_video_amodal_with_metrics():
    """Training loop with metric tracking. Assumes Video3DUNet,
    VideoAmodalCompletionLoss, prepare_model_input, and prepare_model_target
    are defined in earlier cells."""
    wandb.init(
        project="video-amodal-completion",
        config={
            'batch_size': 2,
            'seq_len': 6,
            'img_size': (256, 256),
            'num_epochs': 30,
            'learning_rate': 5e-5,
            'max_scenes': 2,
            'samples_per_scene': 2,
            'num_workers': 2,
            'grad_accum_steps': 4
        }
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()

    config = wandb.config

    metrics_calculator = VideoAmodalMetrics(device)

    train_dataset = VideoAmodalDataset(
        root_dir='data',
        split='train',
        seq_len=config.seq_len,
        img_size=config.img_size,
        max_scenes=config.max_scenes,
        samples_per_scene=config.samples_per_scene,
        max_samples=100
    )

    val_dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=config.seq_len,
        img_size=config.img_size,
        max_scenes=1,
        samples_per_scene=1,
        max_samples=10
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1
    )

    # 5 input channels: RGB (3) + modal mask (1) + amodal mask (1).
    model = Video3DUNet(
        in_channels=5,
        out_channels=3,
        sequence_length=config.seq_len
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=1e-4)
    criterion = VideoAmodalCompletionLoss()

    for epoch in range(config.num_epochs):
        model.train()
        epoch_losses = []
        epoch_metrics = {
            'train_psnr': [],
            'train_ssim': [],
            'train_lpips': [],
            'train_fid': [],
            'train_iou': []
        }

        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            inputs = prepare_model_input(batch).to(device, non_blocking=True)
            targets = prepare_model_target(batch).to(device, non_blocking=True)
            modal_masks = batch['modal_masks'].to(device, non_blocking=True)
            amodal_masks = batch['amodal_masks'].to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                loss = loss / config.grad_accum_steps  # scale for gradient accumulation

            loss.backward()

            # Compute metrics on every 10th batch to limit overhead.
            if i % 10 == 0:
                with torch.no_grad():
                    # (B, T, 1, H, W) -> (B, 1, T, H, W) to match the model output.
                    amodal_masks_3d = amodal_masks.permute(0, 2, 1, 3, 4)
                    batch_metrics = metrics_calculator.calculate_all_metrics(
                        outputs, targets, amodal_masks_3d
                    )

                    for key, value in batch_metrics.items():
                        if f'train_{key}' in epoch_metrics:
                            epoch_metrics[f'train_{key}'].append(value)

            # Step the optimizer once every grad_accum_steps batches.
            if (i + 1) % config.grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                torch.cuda.empty_cache()

            epoch_losses.append(loss_dict['total_loss'])

            if i % 20 == 0:
                log_dict = {
                    'batch': epoch * len(train_loader) + i,
                    'train_loss': loss_dict['total_loss'],
                    'train_visible_loss': loss_dict['visible_loss'],
                    'train_occluded_loss': loss_dict['occluded_loss'],
                    'train_background_loss': loss_dict['background_loss'],
                    'train_boundary_loss': loss_dict['boundary_loss']
                }

                for key, values in epoch_metrics.items():
                    if values:
                        log_dict[key] = values[-1]

                wandb.log(log_dict)

                print(f"Batch {i}, Loss: {loss_dict['total_loss']:.4f}")
                print(f"  Visible: {loss_dict['visible_loss']:.4f}, "
                      f"Occluded: {loss_dict['occluded_loss']:.4f}, "
                      f"Background: {loss_dict['background_loss']:.4f}")

        # Validation pass.
        model.eval()
        val_losses = []
        val_metrics = {
            'val_psnr': [],
            'val_ssim': [],
            'val_lpips': [],
            'val_fid': [],
            'val_iou': []
        }

        with torch.no_grad():
            for batch in val_loader:
                inputs = prepare_model_input(batch).to(device)
                targets = prepare_model_target(batch).to(device)
                modal_masks = batch['modal_masks'].to(device)
                amodal_masks = batch['amodal_masks'].to(device)

                outputs = model(inputs)
                loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
                val_losses.append(loss_dict['total_loss'])

                amodal_masks_3d = amodal_masks.permute(0, 2, 1, 3, 4)
                batch_metrics = metrics_calculator.calculate_all_metrics(
                    outputs, targets, amodal_masks_3d
                )

                for key, value in batch_metrics.items():
                    if f'val_{key}' in val_metrics:
                        val_metrics[f'val_{key}'].append(value)

        avg_train_loss = np.mean(epoch_losses)
        avg_val_loss = np.mean(val_losses)

        epoch_log = {
            'epoch': epoch,
            'avg_train_loss': avg_train_loss,
            'avg_val_loss': avg_val_loss
        }

        for key, values in {**epoch_metrics, **val_metrics}.items():
            if values:
                epoch_log[f'avg_{key}'] = np.mean(values)

        wandb.log(epoch_log)

        print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        for key, values in val_metrics.items():
            if values:
                print(f"  {key}: {np.mean(values):.4f}")

        # Checkpoint every epoch.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'metrics': {key: np.mean(values) for key, values in val_metrics.items() if values}
        }, f"epoch_{epoch}.pth")

    wandb.finish()
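# Note on mixed precision: the loop above calls autocast() but backpropagates
# the raw fp16 loss; very small gradients can underflow. A standard GradScaler
# step would look like this sketch (same variable names as above):
#
#     scaler = torch.cuda.amp.GradScaler()
#     with torch.cuda.amp.autocast():
#         outputs = model(inputs)
#         loss, loss_dict = criterion(outputs, targets, modal_masks, amodal_masks)
#     scaler.scale(loss / config.grad_accum_steps).backward()
#     if (i + 1) % config.grad_accum_steps == 0:
#         scaler.unscale_(optimizer)
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         scaler.step(optimizer)
#         scaler.update()
#         optimizer.zero_grad()
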
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=200):
    """Create an animated GIF: scene | prediction | ground truth | error heatmap."""
    from PIL import Image, ImageDraw, ImageFont
    import numpy as np

    frames = []
    all_errors = []

    # First pass: compute every error map so normalization is global.
    for i in range(len(predictions)):
        pred_tensor = predictions[i]
        gt_tensor = gt_amodal_frames[i]
        mask_tensor = amodal_masks[i] if amodal_masks else None

        error = create_error_heatmap(pred_tensor.unsqueeze(0), gt_tensor.unsqueeze(0),
                                     mask_tensor.unsqueeze(0) if mask_tensor is not None else None)
        all_errors.append(error)

    max_error = max(error.max() for error in all_errors)
    min_error = min(error.min() for error in all_errors)

    for i in range(len(predictions)):
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        pred_rgb = (np.clip(predictions[i].permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        error = all_errors[i]

        # Normalize to [0, 1] with the global range.
        if max_error > min_error:
            error_normalized = (error - min_error) / (max_error - min_error)
        else:
            error_normalized = error

        # Reduce to a single 2-D map if a channel dimension survived.
        error_normalized = np.squeeze(error_normalized)
        if error_normalized.ndim == 3:
            error_normalized = error_normalized[0]

        error_colored = cm.jet(error_normalized)
        error_rgb = (error_colored[:, :, :3] * 255).astype(np.uint8)

        combined = np.concatenate([scene_rgb, pred_rgb, gt_rgb, error_rgb], axis=1)

        img_pil = Image.fromarray(combined)
        draw = ImageDraw.Draw(img_pil)

        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

        text = f"Error: {min_error:.3f} - {max_error:.3f}"
        draw.text((combined.shape[1] - 150, 10), text, fill=(255, 255, 255), font=font)

        frames.append(img_pil)

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )

    print(f"GIF with error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")

def load_model_and_generate_video_with_metrics(checkpoint_path, dataset, device,
                                               output_path="amodal_completion.mp4", fps=8):
    """Load a trained model and generate a side-by-side video, computing
    metrics over sliding windows with 50% overlap."""
    import cv2
    from pathlib import Path

    metrics_calculator = VideoAmodalMetrics(device)

    model = Video3DUNet(in_channels=5, out_channels=3, sequence_length=8).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Loaded model from epoch {checkpoint['epoch']} with loss {checkpoint['train_loss']:.4f}")

    sample = dataset[0]
    seq_len = 8  # must match the sequence_length the model was built with
    total_frames = len(sample['rgb_sequence'])

    print(f"Processing {total_frames} frames in windows of {seq_len}")

    all_predictions = []
    all_rgb = []
    all_modal_masks = []
    all_amodal_masks = []
    all_gt_amodal = []
    all_metrics = []

    with torch.no_grad():
        # Slide a seq_len window with stride seq_len // 2 (50% overlap).
        for start_idx in range(0, total_frames - seq_len + 1, seq_len // 2):
            end_idx = min(start_idx + seq_len, total_frames)

            # Build a single-sample batch for this window.
            window_batch = {}
            for key, value in sample.items():
                if isinstance(value, torch.Tensor):
                    if value.dim() == 4:
                        window_batch[key] = value[start_idx:end_idx].unsqueeze(0)
                    else:
                        window_batch[key] = value.unsqueeze(0)
                else:
                    window_batch[key] = [value]

            inputs = prepare_model_input(window_batch).to(device)
            pred = model(inputs)

            # Restrict the prediction to the amodal region.
            amodal_mask = window_batch['amodal_masks'].permute(0, 2, 1, 3, 4).expand_as(pred).to(device)
            pred_masked = pred * amodal_mask

            target = prepare_model_target(window_batch).to(device)
            window_metrics = metrics_calculator.calculate_all_metrics(pred, target, amodal_mask)
            all_metrics.append(window_metrics)

            # (1, C, T, H, W) -> list of (C, H, W) frames.
            pred_frames = pred_masked.squeeze(0).permute(1, 0, 2, 3).cpu()

            if start_idx == 0:
                all_predictions.extend([pred_frames[i] for i in range(len(pred_frames))])
            else:
                # Average predictions on the overlapping half-window.
                overlap_frames = seq_len // 2
                for i in range(overlap_frames):
                    if len(all_predictions) > start_idx + i:
                        all_predictions[start_idx + i] = (all_predictions[start_idx + i] + pred_frames[i]) / 2.0

                for i in range(overlap_frames, len(pred_frames)):
                    if start_idx + i < total_frames:
                        all_predictions.append(pred_frames[i])

            if start_idx == 0:
                all_rgb = [sample['rgb_sequence'][i] for i in range(total_frames)]
                all_modal_masks = [sample['modal_masks'][i] for i in range(total_frames)]
                all_amodal_masks = [sample['amodal_masks'][i] for i in range(total_frames)]
                all_gt_amodal = [sample['amodal_rgb_sequence'][i] for i in range(total_frames)]

    print("\nOverall Metrics:")
    avg_metrics = {}
    for key in all_metrics[0].keys():
        avg_metrics[key] = np.mean([m[key] for m in all_metrics])
        print(f"  {key.upper()}: {avg_metrics[key]:.4f}")

    all_predictions = all_predictions[:total_frames]
    print(f"Generated {len(all_predictions)} prediction frames")

    # Four panels side by side: scene | modal mask | prediction | GT amodal.
    height, width = all_predictions[0].shape[-2:]
    video_width = width * 4
    video_height = height

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (video_width, video_height))

    for i in range(len(all_predictions)):
        scene_rgb = all_rgb[i].permute(1, 2, 0).numpy()
        modal_mask = all_modal_masks[i][0].numpy()
        modal_mask_rgb = np.stack([modal_mask, modal_mask, modal_mask], axis=2)

        pred_rgb = all_predictions[i].permute(1, 2, 0).numpy()
        pred_rgb = np.clip(pred_rgb, 0, 1)

        try:
            gt_amodal = sample['amodal_rgb_sequence'][i].permute(1, 2, 0).numpy()
            amodal_mask_np = all_amodal_masks[i][0].numpy()
            gt_amodal_masked = gt_amodal * amodal_mask_np[:, :, None]
        except Exception:
            gt_amodal_masked = np.zeros_like(pred_rgb)

        combined_frame = np.concatenate([
            scene_rgb,
            modal_mask_rgb,
            pred_rgb,
            gt_amodal_masked
        ], axis=1)

        combined_frame_bgr = cv2.cvtColor((combined_frame * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        out.write(combined_frame_bgr)

        if i % 5 == 0:
            print(f"Processed frame {i+1}/{len(all_predictions)}")

    out.release()
    print(f"Video saved to {output_path}")

    return all_predictions, all_rgb, all_gt_amodal, all_amodal_masks, avg_metrics
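# The generator above blends 50%-overlapping windows by averaging repeated
# frames. A standalone sketch of that blend on toy scalar "frames"
# (_blend_windows is a hypothetical helper, for illustration only):
def _blend_windows(windows, stride):
    out = []
    for w_idx, win in enumerate(windows):
        start = w_idx * stride
        for i, f in enumerate(win):
            if start + i < len(out):
                out[start + i] = (out[start + i] + f) / 2.0  # average on overlap
            else:
                out.append(f)
    return out

# Two windows of length 4 with stride 2: positions 2-3 get averaged.
print(_blend_windows([[1.0, 2.0, 3.0, 4.0], [30.0, 40.0, 5.0, 6.0]], stride=2))
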
def run_enhanced_video_generation():
    """Run video generation with metrics and error visualization."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=1,
        samples_per_scene=1,
        max_samples=1
    )

    checkpoint_path = "video_amodal_model_epoch_4.pth"
    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = load_model_and_generate_video_with_metrics(
        checkpoint_path,
        dataset,
        device,
        output_path="amodal_completion_video_with_metrics.mp4",
        fps=8
    )

    create_gif_with_error_heatmap(
        predictions,
        rgb_frames,
        gt_amodal_frames,
        amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150
    )

    print("Enhanced video generation complete!")
    return metrics

train_video_amodal_with_metrics()

import torch

def run_gif_generation():
    """Simple function to generate GIFs from your trained model."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = VideoAmodalDataset(
        root_dir='data',
        split='test',
        seq_len=24,
        img_size=(256, 256),
        max_scenes=50,
        samples_per_scene=5,
        max_samples=50
    )

    checkpoint_path = "epoch_29.pth"

    predictions, rgb_frames, gt_amodal_frames, amodal_masks, metrics = load_model_and_generate_video_with_metrics(
        checkpoint_path,
        dataset,
        device,
        output_path="amodal_completion_video.mp4",
        fps=6
    )

    create_gif_with_error_heatmap(
        predictions,
        rgb_frames,
        gt_amodal_frames,
        amodal_masks,
        output_path="amodal_completion_with_error.gif",
        duration=150
    )

    print("GIF creation complete!")
    print(f"Metrics: {metrics}")


if __name__ == "__main__":
    run_gif_generation()

import cv2

def draw_amodal_boundary(rgb_image, amodal_mask, color=(255, 0, 255)):
    """Outline the amodal mask on an RGB image (nonzero pixels are treated
    as foreground by cv2.findContours)."""
    contours, _ = cv2.findContours(amodal_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    outlined = rgb_image.copy()
    cv2.drawContours(outlined, contours, -1, color, thickness=2)
    return outlined
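# Usage sketch on synthetic data: a magenta outline around a square mask.
# Binarizing soft masks (e.g. mask > 0.5) before the call avoids spurious
# contours from faint values.
_img = np.zeros((64, 64, 3), dtype=np.uint8)
_msk = np.zeros((64, 64), dtype=np.float32)
_msk[16:48, 16:48] = 1.0
_outlined = draw_amodal_boundary(_img, (_msk > 0.5))
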
def create_gif_with_error_heatmap(predictions, rgb_frames, gt_amodal_frames, amodal_masks,
                                  output_path="amodal_completion_with_error.gif", duration=240):
    """Create an animated GIF with a masked error heatmap and a labeled
    colorbar. Overrides the earlier version of this function."""
    from PIL import Image, ImageDraw, ImageFont
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from matplotlib.colors import Normalize
    import io

    frames = []
    all_errors = []

    # First pass: compute every error map so normalization is global.
    for i in range(len(predictions)):
        pred_tensor = predictions[i]
        gt_tensor = gt_amodal_frames[i]
        mask_tensor = amodal_masks[i] if amodal_masks else None

        error = create_error_heatmap(pred_tensor.unsqueeze(0), gt_tensor.unsqueeze(0),
                                     mask_tensor.unsqueeze(0) if mask_tensor is not None else None)
        all_errors.append(error)

    # Collect errors inside the amodal masks only, so empty background pixels
    # do not skew the normalization range.
    masked_errors = []
    for i, error in enumerate(all_errors):
        if amodal_masks is not None:
            mask = amodal_masks[i][0].numpy()
            masked_error = error * mask
            masked_errors.extend(masked_error[masked_error > 0])
        else:
            masked_errors.extend(error.flatten())

    if masked_errors:
        # Use the 5th-95th percentile range for a robust color scale.
        min_error = np.percentile(masked_errors, 5)
        max_error = np.percentile(masked_errors, 95)
    else:
        min_error = min(error.min() for error in all_errors)
        max_error = max(error.max() for error in all_errors)

    # Guard against a degenerate (constant) error range.
    if max_error - min_error < 1e-6:
        max_error = min_error + 1e-6

    print(f"Error range for visualization: {min_error:.4f} to {max_error:.4f}")

    def create_colorbar():
        """Render a labeled vertical colorbar with matplotlib."""
        fig, ax = plt.subplots(figsize=(1, 4))
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        norm = Normalize(vmin=min_error, vmax=max_error)
        sm = cm.ScalarMappable(norm=norm, cmap='hot')
        sm.set_array([])

        cbar = plt.colorbar(sm, ax=ax, orientation='vertical', fraction=1.0)
        cbar.set_label('Prediction Error', color='white', fontsize=10)
        cbar.ax.tick_params(colors='white', labelsize=8)

        # The colorbar replaces the placeholder axes entirely.
        ax.remove()

        # Render the figure to an in-memory PNG and load it as a PIL image.
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight',
                    facecolor='black', edgecolor='none', dpi=100)
        buf.seek(0)
        colorbar_with_labels = Image.open(buf)
        plt.close()

        return colorbar_with_labels

    colorbar_img = create_colorbar()
    colorbar_width = colorbar_img.width

    for i in range(len(predictions)):
        scene_rgb = (rgb_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        pred_rgb = (np.clip(predictions[i].permute(1, 2, 0).numpy(), 0, 1) * 255).astype(np.uint8)
        gt_rgb = (gt_amodal_frames[i].permute(1, 2, 0).numpy() * 255).astype(np.uint8)

        error = all_errors[i]

        if amodal_masks is not None:
            mask = amodal_masks[i][0].numpy()
            error = error * mask

        # Reduce to a single 2-D map if a channel dimension survived.
        error = np.squeeze(error)
        if error.ndim == 3:
            error = error[0]

        error_normalized = np.clip((error - min_error) / (max_error - min_error), 0, 1)

        error_colored = cm.hot(error_normalized)
        error_rgb = (error_colored[:, :, :3] * 255).astype(np.uint8)

        # Zero out the heatmap outside the amodal region.
        if amodal_masks is not None:
            mask_3d = np.stack([mask, mask, mask], axis=2)
            error_rgb = error_rgb * mask_3d.astype(np.uint8)

        # Outline the amodal extent on the scene frame.
        highlighted_rgb = draw_amodal_boundary(scene_rgb, amodal_masks[i][0].cpu().numpy())

        combined = np.concatenate([highlighted_rgb, pred_rgb, gt_rgb, error_rgb], axis=1)

        img_pil = Image.fromarray(combined)

        colorbar_resized = colorbar_img.resize((colorbar_width, img_pil.height))

        # Paste the panels and colorbar onto a black canvas.
        final_width = img_pil.width + colorbar_width + 10
        final_img = Image.new('RGB', (final_width, img_pil.height), color='black')
        final_img.paste(img_pil, (0, 0))
        final_img.paste(colorbar_resized, (img_pil.width + 10, 0))

        draw = ImageDraw.Draw(final_img)
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

        frame_text = f"Frame {i+1}/{len(predictions)}"
        draw.text((10, 10), frame_text, fill=(0, 0, 0), font=font)

        frames.append(final_img)

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )

    print(f"GIF with error heatmap saved to {output_path}")
    print(f"Error range: {min_error:.4f} to {max_error:.4f}")
    print("Colorbar shows errors from low (black/red) to high (yellow/white)")

def create_error_heatmap(pred, target, mask=None):
    """Create an error heatmap between prediction and target with enhanced
    sensitivity: per-pixel L2 error across channels (overrides the earlier
    mean-absolute-error version)."""
    error = torch.sqrt(torch.sum((pred - target) ** 2, dim=1))

    if mask is not None:
        error = error * mask.squeeze()

    return error.cpu().numpy()
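# Tiny shape check for the L2 error map (a sketch on random tensors):
# a (1, 3, H, W) pair reduces to a (1, H, W) per-pixel error map.
_e = create_error_heatmap(torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8))
print(_e.shape, float(_e.min()), float(_e.max()))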