| | |
| | """final1.2.ipynb |
| | |
| | Automatically generated by Colab. |
| | |
| | Original file is located at |
| | https://colab.research.google.com/drive/1v6-6x7lqt6gr9VIauNVHIwjvIkewk8eT |
| | """ |
| |
|
| |
|
| |
|
| | """## FINAL 1.2""" |
| |
|
| |
|
| |
|
| | pip install torchmetrics lpips |
| |
|
| | |
| | import torch |
| | from torch import nn |
| | from torchvision.transforms import ToPILImage, ToTensor |
| | from torchvision.utils import make_grid |
| | from torchvision.io import write_video |
| |
|
| | |
| | from pathlib import Path |
| | from PIL import Image |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import random |
| | import json |
| | from IPython.display import Video |
| |
|
| | |
# Shared torchvision converters: tensor -> PIL image, and PIL image -> tensor.
tensor_to_image, image_to_tensor = ToPILImage(), ToTensor()
| |
|
def get_img_dict(img_dir):
    """Group the ``.png`` / ``.tiff`` entries of *img_dir* by filename prefix.

    The prefix is the part of the name before the first underscore
    (``rgba_00000.png`` -> ``'rgba'``); a name without an underscore is its own
    key. Entries are visited in sorted path order, so every group's list is
    sorted and keys appear in the order their first entry was seen.

    Args:
        img_dir: ``pathlib.Path`` of the directory to scan (not recursive).

    Returns:
        dict mapping prefix -> list of matching ``Path`` objects.
    """
    grouped = {}
    image_paths = sorted(
        entry for entry in img_dir.iterdir() if entry.name.endswith(('.png', '.tiff'))
    )
    for path in image_paths:
        prefix = path.name.partition('_')[0]
        grouped.setdefault(prefix, []).append(path)
    return grouped
| |
|
| |
|
def get_sample_dict(sample_dir):
    """Index a sample directory by camera and then by object.

    Every entry of *sample_dir* whose name contains ``'camera'`` becomes a key
    (its name) mapping to a dict with:

    * ``'scene'`` -> ``get_img_dict(camera_dir)``
    * ``<obj dir name>`` -> ``get_img_dict(obj_dir)`` for every entry of the
      camera directory whose name contains ``'obj_'``.

    Camera and object entries are visited in sorted order.
    """
    sample_dict = {}
    for cam_dir in sorted(entry for entry in sample_dir.iterdir() if 'camera' in entry.name):
        cam_dict = {'scene': get_img_dict(cam_dir)}
        cam_dict.update(
            (obj_dir.name, get_img_dict(obj_dir))
            for obj_dir in sorted(entry for entry in cam_dir.iterdir() if 'obj_' in entry.name)
        )
        sample_dict[cam_dir.name] = cam_dict
    return sample_dict
| |
|
# Download the dataset repo's auxiliary files (train/test object descriptors,
# example video, README, notice PDF, .gitattributes) with wget.
# NOTE: `!cmd` is IPython shell-escape syntax, so this cell only runs inside
# Jupyter/Colab, not as a plain Python script.
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/test_obj_descriptors.json

!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/train_obj_descriptors.json
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/ex_vis.mp4
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/README.md
!wget "https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/Notice%201%20-%20Unlimited_datasets.pdf"
!wget https://huggingface.co/datasets/Amar-S/MOVi-MC-AC/resolve/main/.gitattributes
| | |
import os
import random
import shutil

from huggingface_hub import HfApi, hf_hub_download

# Fraction of each split's archives to download, and a fixed seed so the same
# subset is fetched on every run (an unseeded RNG picked a new subset each time).
SUBSET_FRACTION = 0.005
SUBSET_SEED = 0

api = HfApi()
repo_id = "Amar-S/MOVi-MC-AC"

files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")

train_files = [f for f in files if f.startswith("train/") and not f.endswith(".json")]
test_files = [f for f in files if f.startswith("test/") and not f.endswith(".json")]
print(f"Found {len(train_files)} train files and {len(test_files)} test files.")


def _subset_size(n_files):
    """SUBSET_FRACTION of *n_files*, but at least one file whenever any exist."""
    return min(n_files, max(1, int(n_files * SUBSET_FRACTION)))


rng = random.Random(SUBSET_SEED)
subset_train = rng.sample(train_files, _subset_size(len(train_files)))
subset_test = rng.sample(test_files, _subset_size(len(test_files)))

# Download into the HF cache, then copy each file flat into /content/data/<split>/.
for split, subset in (("train", subset_train), ("test", subset_test)):
    dest_dir = f"/content/data/{split}"
    os.makedirs(dest_dir, exist_ok=True)
    for file in subset:
        out_path = hf_hub_download(repo_id=repo_id, repo_type="dataset", filename=file)
        shutil.copyfile(out_path, os.path.join(dest_dir, os.path.basename(file)))
| |
|
import os
import tarfile


def _extract_archives(directory):
    """Extract every ``*.tar.gz`` found directly in *directory* into *directory*."""
    for name in sorted(os.listdir(directory)):
        if not name.endswith(".tar.gz"):
            continue
        with tarfile.open(os.path.join(directory, name), "r:gz") as archive:
            # The archives come from a remote dataset: when available, the 'data'
            # filter rejects absolute paths, '..' traversal and special files.
            if hasattr(tarfile, "data_filter"):
                archive.extractall(directory, filter="data")
            else:
                archive.extractall(directory)


train_dir = "data/train"
_extract_archives(train_dir)

test_dir = "data/test"
_extract_archives(test_dir)
| |
|
| |
|
| |
|
import os
from pathlib import Path

# Free disk space: remove every *.tar.gz left under /content/data after extraction.
root = Path('/content/data')
deleted = 0
for archive in sorted(root.rglob('*.tar.gz')):
    try:
        archive.unlink()
    except OSError as e:
        # Best effort: report the failure and continue with the remaining archives.
        print(f"Error deleting {archive}: {e}")
        continue
    print(f"Deleted {archive}")
    deleted += 1
print(f"Total deleted: {deleted}")
| |
|
| | pip install torchmetrics lpips |
| |
|
| | import matplotlib.pyplot as plt |
| | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
| | import lpips |
| | import matplotlib.pyplot as plt |
| | import torch |
| |
|
def visualize_results(model, dataloader, device, num_samples=8):
    """Plot up to ``num_samples`` predictions, one six-panel figure per sample.

    Panels: scene RGB, amodal mask, modal mask, ground-truth amodal RGB,
    predicted amodal RGB (both RGBs multiplied by the amodal mask so the
    background is blacked out), and the absolute error averaged over channels.

    The model input is ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``.
    """
    model.eval()
    samples_shown = 0

    with torch.no_grad():
        for batch in dataloader:
            if samples_shown >= num_samples:
                break

            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            pred = model(torch.cat([rgb, modal_mask, amodal_mask], dim=1))
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(rgb.shape[0]):
                if samples_shown >= num_samples:
                    break

                diff = torch.abs(pred_masked[i] - gt_masked[i]).mean(dim=0)
                # (image, title, colormap); cmap is ignored for RGB panels.
                panels = [
                    (rgb[i].cpu().permute(1, 2, 0), 'Scene RGB', None),
                    (amodal_mask[i, 0].cpu(), 'Amodal Mask', 'gray'),
                    (modal_mask[i, 0].cpu(), 'Modal Mask', 'gray'),
                    (gt_masked[i].cpu().permute(1, 2, 0), 'GT Amodal RGB', None),
                    (pred_masked[i].cpu().permute(1, 2, 0), 'Predicted Amodal RGB', None),
                    (diff.cpu(), 'Prediction Error', 'hot'),
                ]

                fig, axes = plt.subplots(1, len(panels), figsize=(24, 4))
                for ax, (image, title, cmap) in zip(axes, panels):
                    im = ax.imshow(image, cmap=cmap)
                    ax.set_title(title)
                    ax.axis('off')
                # `im` is the last (error) panel's image.
                plt.colorbar(im, ax=axes[-1])

                plt.tight_layout()
                plt.show()
                # Release the figure: non-inline backends keep every figure alive
                # otherwise, and matplotlib warns once more than 20 are open.
                plt.close(fig)

                samples_shown += 1
| |
|
| |
|
| |
|
| | |
def evaluate_metrics(model, dataloader, device):
    """Compute masked mean-squared errors inside object regions.

    For every batch the model is fed ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``
    and its output is compared with ``batch['amodal_rgb']``. Both are first
    multiplied by the amodal mask, so background pixels never count. Squared
    errors are accumulated over the whole dataloader for three regions:

    * ``total_mse``    -- the amodal mask,
    * ``occluded_mse`` -- ``occluded_mask * amodal_mask``,
    * ``visible_mse``  -- ``modal_mask * amodal_mask``.

    Each value is the summed squared error divided by the region's mask weight
    summed over every colour channel. This makes it a per-element MSE, matching
    ``F.mse_loss``'s mean. A region that never occurs yields 0.

    Returns:
        dict with keys 'total_mse', 'occluded_mse', 'visible_mse'.
    """
    import torch.nn.functional as F  # local so the function does not depend on import order

    model.eval()
    total_mse = 0.0
    occluded_mse = 0.0
    visible_mse = 0.0
    total_weight = 0.0
    occluded_weight = 0.0
    visible_weight = 0.0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            # Expanding the (possibly single-channel) mask to the prediction's shape
            # counts every colour channel in the denominator, like the numerator does.
            object_w = amodal_mask.expand_as(pred_masked).sum()
            if object_w > 0:
                mse = F.mse_loss(pred_masked, gt_masked, reduction='sum')
                total_mse += mse.item()
                total_weight += object_w.item()

            occluded_region = occluded_mask * amodal_mask
            occ_w = occluded_region.expand_as(pred_masked).sum()
            if occ_w > 0:
                occ_mse = F.mse_loss(pred_masked * occluded_region,
                                     gt_masked * occluded_region, reduction='sum')
                occluded_mse += occ_mse.item()
                occluded_weight += occ_w.item()

            visible_region = modal_mask * amodal_mask
            vis_w = visible_region.expand_as(pred_masked).sum()
            if vis_w > 0:
                vis_mse = F.mse_loss(pred_masked * visible_region,
                                     gt_masked * visible_region, reduction='sum')
                visible_mse += vis_mse.item()
                visible_weight += vis_w.item()

    return {
        'total_mse': total_mse / total_weight if total_weight > 0 else 0,
        'occluded_mse': occluded_mse / occluded_weight if occluded_weight > 0 else 0,
        'visible_mse': visible_mse / visible_weight if visible_weight > 0 else 0,
    }
| |
|
| |
|
| |
|
def calculate_metrics(model, dataloader, device):
    """Compute mean per-sample PSNR, SSIM, LPIPS and an IoU score on *dataloader*.

    The model input is ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``.
    Predictions and targets are both multiplied by the amodal mask before
    comparison. Images are assumed to lie in [0, 1]:

    * PSNR and SSIM are told so via ``data_range=1.0``.
    * LPIPS rescales them to its expected [-1, 1] with ``normalize=True``.

    Samples smaller than 64x64 are skipped.

    Returns:
        dict with keys "psnr", "ssim", "lpips", "miou" (all 0 if no sample was scored).
    """
    model.eval()
    # Without an explicit data_range torchmetrics infers the range from the data.
    psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    lpips_loss = lpips.LPIPS(net='alex').to(device)

    total_psnr, total_ssim, total_lpips = 0, 0, 0
    total_iou = 0
    count = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(pred.shape[0]):
                pred_i = pred_masked[i].unsqueeze(0)
                gt_i = gt_masked[i].unsqueeze(0)

                # Skip images under 64x64 (presumably too small for the SSIM
                # window / LPIPS backbone).
                if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
                    continue

                total_psnr += psnr(pred_i, gt_i).item()
                total_ssim += ssim(pred_i, gt_i).item()
                # LPIPS expects [-1, 1] inputs; normalize=True maps [0, 1] -> [-1, 1].
                total_lpips += lpips_loss(pred_i, gt_i, normalize=True).item()

                # NOTE(review): this thresholds the predicted *RGB* values at 0.5 and
                # compares them with the amodal mask (broadcast over the colour
                # channels). It is not a true mask IoU, because the model predicts no
                # mask. Kept unchanged so results stay comparable with earlier runs.
                pred_fg = pred[i] > 0.5
                intersection = (amodal_mask[i] * pred_fg).sum()
                union = ((amodal_mask[i] + pred_fg) > 0).sum()
                if union > 0:
                    iou = intersection.float() / union.float()
                    total_iou += iou.item()

                count += 1

    if count == 0:
        return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}

    return {
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
        "miou": total_iou / count
    }
| |
|
| | pip install torchmetrics lpips |
| |
|
| | import matplotlib.pyplot as plt |
| | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
| | import lpips |
| | import matplotlib.pyplot as plt |
| | import torch |
| |
|
def visualize_results(model, dataloader, device, num_samples=8):
    """Plot up to ``num_samples`` predictions, one six-panel figure per sample.

    Panels: scene RGB, amodal mask, modal mask, ground-truth amodal RGB,
    predicted amodal RGB (both RGBs multiplied by the amodal mask so the
    background is blacked out), and the absolute error averaged over channels.

    The model input is ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``.
    """
    model.eval()
    samples_shown = 0

    with torch.no_grad():
        for batch in dataloader:
            if samples_shown >= num_samples:
                break

            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            pred = model(torch.cat([rgb, modal_mask, amodal_mask], dim=1))
            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(rgb.shape[0]):
                if samples_shown >= num_samples:
                    break

                diff = torch.abs(pred_masked[i] - gt_masked[i]).mean(dim=0)
                # (image, title, colormap); cmap is ignored for RGB panels.
                panels = [
                    (rgb[i].cpu().permute(1, 2, 0), 'Scene RGB', None),
                    (amodal_mask[i, 0].cpu(), 'Amodal Mask', 'gray'),
                    (modal_mask[i, 0].cpu(), 'Modal Mask', 'gray'),
                    (gt_masked[i].cpu().permute(1, 2, 0), 'GT Amodal RGB', None),
                    (pred_masked[i].cpu().permute(1, 2, 0), 'Predicted Amodal RGB', None),
                    (diff.cpu(), 'Prediction Error', 'hot'),
                ]

                fig, axes = plt.subplots(1, len(panels), figsize=(24, 4))
                for ax, (image, title, cmap) in zip(axes, panels):
                    im = ax.imshow(image, cmap=cmap)
                    ax.set_title(title)
                    ax.axis('off')
                # `im` is the last (error) panel's image.
                plt.colorbar(im, ax=axes[-1])

                plt.tight_layout()
                plt.show()
                # Release the figure: non-inline backends keep every figure alive
                # otherwise, and matplotlib warns once more than 20 are open.
                plt.close(fig)

                samples_shown += 1
| |
|
| |
|
def evaluate_metrics(model, dataloader, device):
    """Compute masked mean-squared errors inside object regions.

    For every batch the model is fed ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``
    and its output is compared with ``batch['amodal_rgb']``. Both are first
    multiplied by the amodal mask, so background pixels never count. Squared
    errors are accumulated over the whole dataloader for three regions:

    * ``total_mse``    -- the amodal mask,
    * ``occluded_mse`` -- ``occluded_mask * amodal_mask``,
    * ``visible_mse``  -- ``modal_mask * amodal_mask``.

    Each value is the summed squared error divided by the region's mask weight
    summed over every colour channel. This makes it a per-element MSE, matching
    ``F.mse_loss``'s mean. A region that never occurs yields 0.

    Returns:
        dict with keys 'total_mse', 'occluded_mse', 'visible_mse'.
    """
    import torch.nn.functional as F  # local so the function does not depend on import order

    model.eval()
    total_mse = 0.0
    occluded_mse = 0.0
    visible_mse = 0.0
    total_weight = 0.0
    occluded_weight = 0.0
    visible_weight = 0.0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            # Expanding the (possibly single-channel) mask to the prediction's shape
            # counts every colour channel in the denominator, like the numerator does.
            object_w = amodal_mask.expand_as(pred_masked).sum()
            if object_w > 0:
                mse = F.mse_loss(pred_masked, gt_masked, reduction='sum')
                total_mse += mse.item()
                total_weight += object_w.item()

            occluded_region = occluded_mask * amodal_mask
            occ_w = occluded_region.expand_as(pred_masked).sum()
            if occ_w > 0:
                occ_mse = F.mse_loss(pred_masked * occluded_region,
                                     gt_masked * occluded_region, reduction='sum')
                occluded_mse += occ_mse.item()
                occluded_weight += occ_w.item()

            visible_region = modal_mask * amodal_mask
            vis_w = visible_region.expand_as(pred_masked).sum()
            if vis_w > 0:
                vis_mse = F.mse_loss(pred_masked * visible_region,
                                     gt_masked * visible_region, reduction='sum')
                visible_mse += vis_mse.item()
                visible_weight += vis_w.item()

    return {
        'total_mse': total_mse / total_weight if total_weight > 0 else 0,
        'occluded_mse': occluded_mse / occluded_weight if occluded_weight > 0 else 0,
        'visible_mse': visible_mse / visible_weight if visible_weight > 0 else 0,
    }
| |
|
| |
|
| |
|
def calculate_metrics(model, dataloader, device):
    """Compute mean per-sample PSNR, SSIM, LPIPS and an IoU score on *dataloader*.

    The model input is ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``.
    Predictions and targets are both multiplied by the amodal mask before
    comparison. Images are assumed to lie in [0, 1]:

    * PSNR and SSIM are told so via ``data_range=1.0``.
    * LPIPS rescales them to its expected [-1, 1] with ``normalize=True``.

    Samples smaller than 64x64 are skipped.

    Returns:
        dict with keys "psnr", "ssim", "lpips", "miou" (all 0 if no sample was scored).
    """
    model.eval()
    # Without an explicit data_range torchmetrics infers the range from the data.
    psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    lpips_loss = lpips.LPIPS(net='alex').to(device)

    total_psnr, total_ssim, total_lpips = 0, 0, 0
    total_iou = 0
    count = 0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)
            pred = model(input_tensor)

            pred_masked = pred * amodal_mask
            gt_masked = gt_amodal_rgb * amodal_mask

            for i in range(pred.shape[0]):
                pred_i = pred_masked[i].unsqueeze(0)
                gt_i = gt_masked[i].unsqueeze(0)

                # Skip images under 64x64 (presumably too small for the SSIM
                # window / LPIPS backbone).
                if pred_i.shape[-1] < 64 or pred_i.shape[-2] < 64:
                    continue

                total_psnr += psnr(pred_i, gt_i).item()
                total_ssim += ssim(pred_i, gt_i).item()
                # LPIPS expects [-1, 1] inputs; normalize=True maps [0, 1] -> [-1, 1].
                total_lpips += lpips_loss(pred_i, gt_i, normalize=True).item()

                # NOTE(review): this thresholds the predicted *RGB* values at 0.5 and
                # compares them with the amodal mask (broadcast over the colour
                # channels). It is not a true mask IoU, because the model predicts no
                # mask. Kept unchanged so results stay comparable with earlier runs.
                pred_fg = pred[i] > 0.5
                intersection = (amodal_mask[i] * pred_fg).sum()
                union = ((amodal_mask[i] + pred_fg) > 0).sum()
                if union > 0:
                    iou = intersection.float() / union.float()
                    total_iou += iou.item()

                count += 1

    if count == 0:
        return {"psnr": 0, "ssim": 0, "lpips": 0, "miou": 0}

    return {
        "psnr": total_psnr / count,
        "ssim": total_ssim / count,
        "lpips": total_lpips / count,
        "miou": total_iou / count
    }
| |
|
| |
|
| |
|
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | from torchvision import transforms |
| | from pathlib import Path |
| | from PIL import Image, ImageChops |
| | import numpy as np |
| |
|
class ModalAmodalDataset(Dataset):
    """Per-object, per-frame samples for amodal RGB completion.

    Layout read from ``root_dir/<split>`` (``root_dir/train`` for
    ``split='val'`` with ``use_val_from_train``)::

        <scene>/camera_*/rgba_*.png                      scene RGB frames
        <scene>/camera_*/segmentation_*.png              scene segmentation (pixel == object id)
        <scene>/camera_*/obj_<id>/segmentation_*.png     amodal mask of object <id>
        <scene>/camera_*/obj_<id>/rgba_*.png             amodal RGB of object <id>

    One sample is produced per (scene, camera, object, frame). Directory
    listings are sorted, so the index is deterministic. With
    ``use_val_from_train`` the samples are shuffled with a fixed seed and split
    by ``val_split``. Separately built 'train' and 'val' instances with the same
    arguments are therefore disjoint.
    """

    def __init__(self, root_dir, split, img_size=(128, 128), max_samples=None, val_split=0.2, use_val_from_train=False):
        self.root_dir = Path(root_dir)
        self.img_size = img_size
        self.max_samples = max_samples
        self.val_split = val_split
        self.use_val_from_train = use_val_from_train
        self.split = split

        if split == 'val' and use_val_from_train:
            # Validation samples are carved out of the training directory.
            self.root_dir = self.root_dir / 'train'
        else:
            self.root_dir = self.root_dir / split

        self.samples = self._build_sample_index()

        self.rgb_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])
        # Nearest-neighbour resizing keeps masks binary; the default bilinear
        # filter blurs object edges into fractional mask values.
        self.mask_transform = transforms.Compose([
            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def _build_sample_index(self):
        """Scan ``self.root_dir`` and return the list of sample dicts for this split."""
        import random

        samples = []
        # Every listing is sorted: iterdir() order is filesystem-dependent, and the
        # train/val split below requires separate instances to see the same order.
        for scene_dir in sorted(self.root_dir.iterdir()):
            if not scene_dir.is_dir():
                continue
            for camera_dir in sorted(scene_dir.iterdir()):
                if not camera_dir.name.startswith('camera_') or not camera_dir.is_dir():
                    continue

                rgba_paths = sorted(camera_dir.glob('rgba_*.png'))
                seg_paths = sorted(camera_dir.glob('segmentation_*.png'))

                for obj_dir in sorted(camera_dir.iterdir()):
                    if not obj_dir.name.startswith('obj_'):
                        continue

                    amodal_paths = sorted(obj_dir.glob('segmentation_*.png'))
                    amodal_rgb_paths = sorted(obj_dir.glob('rgba_*.png'))

                    # Frames are paired by sorted position, so skip objects whose
                    # frame counts do not line up with the scene's.
                    if not (len(rgba_paths) == len(seg_paths) == len(amodal_paths) == len(amodal_rgb_paths)):
                        continue

                    for rgba_path, seg_path, amodal_path, amodal_rgb_path in zip(
                        rgba_paths, seg_paths, amodal_paths, amodal_rgb_paths
                    ):
                        samples.append({
                            'rgb_path': rgba_path,
                            'seg_path': seg_path,
                            'amodal_path': amodal_path,
                            'amodal_rgb_path': amodal_rgb_path,
                            'object_id': int(obj_dir.name.split('_')[1]),
                            'scene': scene_dir.name,
                            'camera': camera_dir.name
                        })

        # Private RNGs seeded with 42 pick the same subset/shuffle as seeding the
        # global `random` module would, without resetting the caller's RNG state.
        if self.max_samples is not None and len(samples) > self.max_samples:
            samples = random.Random(42).sample(samples, self.max_samples)
            print(f"Dataset limited to {len(samples)} samples")

        if self.use_val_from_train:
            random.Random(42).shuffle(samples)

            val_size = int(len(samples) * self.val_split)
            if self.split == 'train':
                samples = samples[val_size:]
                print(f"Train split: {len(samples)} samples")
            elif self.split == 'val':
                samples = samples[:val_size]
                print(f"Validation split: {len(samples)} samples")

        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns a dict of tensors: 'rgb', 'modal_mask', 'amodal_mask',
        'occluded_mask' (amodal minus modal, clamped to [0, 1]) and 'amodal_rgb'.
        """
        sample = self.samples[idx]

        rgb = Image.open(sample['rgb_path']).convert('RGB')
        seg_map = np.array(Image.open(sample['seg_path']))
        amodal_mask_img = Image.open(sample['amodal_path']).convert('L')
        amodal_rgb = Image.open(sample['amodal_rgb_path']).convert('RGB')

        # The visible (modal) mask is where the scene segmentation equals this object's id.
        modal_mask_np = (seg_map == sample['object_id']).astype(np.uint8) * 255
        modal_mask_img = Image.fromarray(modal_mask_np, mode='L')

        rgb = self.rgb_transform(rgb)
        modal_mask = self.mask_transform(modal_mask_img)
        amodal_mask = self.mask_transform(amodal_mask_img)
        amodal_rgb = self.rgb_transform(amodal_rgb)

        # Occluded part = amodal extent not covered by the visible mask.
        occluded_mask = amodal_mask - modal_mask
        occluded_mask = torch.clamp(occluded_mask, 0, 1)

        return {
            'rgb': rgb,
            'modal_mask': modal_mask,
            'amodal_mask': amodal_mask,
            'occluded_mask': occluded_mask,
            'amodal_rgb': amodal_rgb,
        }
| |
|
| |
|
class ImprovedUNet(nn.Module):
    """Four-level U-Net with batch norm and channel dropout, ending in a sigmoid.

    Takes ``in_channels`` input maps and returns ``out_channels`` maps in (0, 1).
    Height and width must be divisible by 16, because there are four 2x
    poolings and the skip concatenations require matching sizes.
    """

    def __init__(self, in_channels=5, out_channels=3):
        super().__init__()

        def double_conv(c_in, c_out, p_drop=0.1):
            # conv-BN-ReLU, channel dropout, conv-BN-ReLU. The Sequential indices
            # 0..6 appear in state_dict keys, so this layer order must not change.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p_drop),
                nn.Conv2d(c_out, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True)
            )

        # Encoder: down1..down4, each followed by a 2x max-pool (pool1..pool4).
        widths = [in_channels, 64, 128, 256, 512]
        for level in range(1, 5):
            setattr(self, f"down{level}", double_conv(widths[level - 1], widths[level]))
            setattr(self, f"pool{level}", nn.MaxPool2d(2))

        self.middle = double_conv(512, 1024, p_drop=0.2)

        # Decoder: up{k} halves channels and doubles resolution; up_block{k} fuses
        # the result with the encoder feature map of the same resolution.
        for level, width in enumerate((1024, 512, 256, 128), start=1):
            setattr(self, f"up{level}", nn.ConvTranspose2d(width, width // 2, 2, stride=2))
            setattr(self, f"up_block{level}", double_conv(width, width // 2))

        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        skips = []
        features = x
        for level in range(1, 5):
            features = getattr(self, f"down{level}")(features)
            skips.append(features)
            features = getattr(self, f"pool{level}")(features)

        features = self.middle(features)

        # Deepest skip first: up1 pairs with down4's output, up4 with down1's.
        for level in range(1, 5):
            upsampled = getattr(self, f"up{level}")(features)
            features = getattr(self, f"up_block{level}")(torch.cat([upsampled, skips[-level]], dim=1))

        return torch.sigmoid(self.final(features))
| |
|
class AmodalCompletionLoss(nn.Module):
    """Masked reconstruction loss for amodal RGB completion (background ignored).

    Prediction and target are both multiplied by ``amodal_mask`` before any
    term is computed. Terms:

    * visible loss  -- MSE inside ``modal_mask * amodal_mask``
    * occluded loss -- MSE inside ``occluded_mask * amodal_mask``
    * boundary loss -- MSE on pixels whose 3x3 neighbourhood is partly inside
      the amodal mask (weight 2.0)
    * perceptual loss -- LPIPS between the masked images, added with weight
      ``perceptual_weight``. It is only computed when that weight is non-zero.
      The default 0.0 keeps the objective unchanged: LPIPS was previously
      evaluated on every call but never added to the total.

    The MSE terms use ``F.mse_loss``'s default mean over *all* elements, so each
    term scales with the size of its region.

    ``forward`` returns ``(total, visible, occluded, boundary)``.
    """

    def __init__(self, occluded_weight=5.0, visible_weight=1.0, perceptual_weight=0.0):
        super().__init__()
        self.occluded_weight = occluded_weight
        self.visible_weight = visible_weight
        self.perceptual_weight = perceptual_weight
        self.lpips_model = lpips.LPIPS(net='alex')

    def forward(self, pred, target, modal_mask, occluded_mask, amodal_mask):
        pred_masked = pred * amodal_mask
        target_masked = target * amodal_mask

        visible_region = modal_mask * amodal_mask
        if visible_region.sum() > 0:
            visible_loss = F.mse_loss(pred_masked * visible_region, target_masked * visible_region)
        else:
            visible_loss = pred.new_zeros(())

        occluded_region = occluded_mask * amodal_mask
        if occluded_region.sum() > 0:
            occluded_loss = F.mse_loss(pred_masked * occluded_region, target_masked * occluded_region)
        else:
            occluded_loss = pred.new_zeros(())

        # A 3x3 box sum strictly between 0 and 9 marks pixels on either side of the
        # mask edge. This assumes a single-channel amodal_mask (kernel has 1 input channel).
        kernel = amodal_mask.new_ones(1, 1, 3, 3)
        neighbourhood = F.conv2d(amodal_mask, kernel, padding=1)
        boundary_mask = ((neighbourhood > 0) & (neighbourhood < 9)).float()
        boundary_loss = F.mse_loss(pred_masked * boundary_mask, target_masked * boundary_mask)

        total_loss = (self.visible_weight * visible_loss +
                      self.occluded_weight * occluded_loss +
                      2.0 * boundary_loss)

        # getattr keeps instances created before this attribute existed working.
        perceptual_weight = getattr(self, 'perceptual_weight', 0.0)
        if perceptual_weight:
            self.lpips_model = self.lpips_model.to(pred.device)
            # LPIPS expects [-1, 1] inputs; normalize=True rescales [0, 1] images.
            perceptual_loss = self.lpips_model(pred_masked, target_masked, normalize=True).mean()
            total_loss = total_loss + perceptual_weight * perceptual_loss

        return total_loss, visible_loss, occluded_loss, boundary_loss
| |
|
| |
|
def train_improved(model, dataloader, optimizer, device, num_epochs):
    """Train *model* with ``AmodalCompletionLoss`` for ``num_epochs`` epochs.

    Each batch's input is ``torch.cat([rgb, modal_mask, amodal_mask], dim=1)``
    and the target is ``batch['amodal_rgb']``. The loss terms are printed every
    16 batches, and the average total loss at the end of each epoch.

    Returns:
        list of per-epoch average losses (NaN for an epoch that saw no batches).
        Callers that ignore the return value are unaffected.
    """
    model.train()
    # Move the criterion (which holds an LPIPS network) to the device once.
    criterion = AmodalCompletionLoss().to(device)
    epoch_averages = []

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for i, batch in enumerate(dataloader):
            rgb = batch['rgb'].to(device)
            modal_mask = batch['modal_mask'].to(device)
            amodal_mask = batch['amodal_mask'].to(device)
            occluded_mask = batch['occluded_mask'].to(device)
            gt_amodal_rgb = batch['amodal_rgb'].to(device)

            input_tensor = torch.cat([rgb, modal_mask, amodal_mask], dim=1)

            optimizer.zero_grad()
            pred = model(input_tensor)

            loss, vis_loss, occ_loss, boundary_loss = criterion(
                pred, gt_amodal_rgb, modal_mask, occluded_mask, amodal_mask
            )

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

            if i % 16 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] [{i}/{len(dataloader)}] "
                      f"Total: {loss.item():.4f}, Visible: {vis_loss.item():.4f}, "
                      f"Occluded: {occ_loss.item():.4f}, Boundary: {boundary_loss.item():.4f}")

        # Guard against an empty loader (e.g. drop_last with fewer samples than a batch).
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        epoch_averages.append(avg_loss)
        print(f"Epoch {epoch} Average Loss: {avg_loss:.4f}")

    return epoch_averages
| |
|
| | |
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _print_mse_report(mse_metrics):
        """Print evaluate_metrics() output; the ratio is NaN when the visible MSE is 0."""
        visible = mse_metrics['visible_mse']
        ratio = mse_metrics['occluded_mse'] / visible if visible else float('nan')
        print(f"Overall MSE: {mse_metrics['total_mse']:.6f}")
        print(f"Occluded Region MSE: {mse_metrics['occluded_mse']:.6f}")
        print(f"Visible Region MSE: {visible:.6f}")
        print(f"Occluded/Visible MSE Ratio: {ratio:.2f}")

    # Directory containing the extracted 'train' and 'test' splits.
    data_root = "data"

    # Train and validation are both drawn from data/train (val_split=0.2).
    train_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='train',
        img_size=(128, 128),
        max_samples=1000,
        val_split=0.2,
        use_val_from_train=True
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True
    )

    val_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='val',
        img_size=(128, 128),
        max_samples=1000,
        val_split=0.2,
        use_val_from_train=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    print(f"Training on {len(train_dataset)} samples, {len(train_loader)} batches per epoch")
    print(f"Validation on {len(val_dataset)} samples, {len(val_loader)} batches")

    # Restore previously trained weights; no training is run in this script.
    model = ImprovedUNet().to(device)
    model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=device))
    # Kept for optional fine-tuning, e.g.
    # train_improved(model, train_loader, optimizer, device, num_epochs).
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

    print("\n" + "="*50)
    print("EVALUATION RESULTS")
    print("="*50)

    metrics = evaluate_metrics(model, val_loader, device)
    _print_mse_report(metrics)

    print("\nGenerating visualizations...")
    visualize_results(model, val_loader, device, num_samples=8)

    image_metrics = calculate_metrics(model, val_loader, device)
    print(f"PSNR: {image_metrics['psnr']:.4f}")
    print(f"SSIM: {image_metrics['ssim']:.4f}")
    print(f"LPIPS: {image_metrics['lpips']:.4f}")
    print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")

    # Persist the current weights.
    torch.save(model.state_dict(), 'amodal_completion_model.pth')

    test_dataset = ModalAmodalDataset(
        root_dir=data_root,
        split='test',
        img_size=(128, 128),
        max_samples=2000
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=False,  # evaluate every test sample, including the final partial batch
    )

    print("EVALUATION RESULTS")
    print("="*50)

    metrics = evaluate_metrics(model, test_loader, device)
    _print_mse_report(metrics)

    print("\nGenerating visualizations...")
    visualize_results(model, test_loader, device, num_samples=16)

    # Release the Colab runtime when finished; skipped when not running on Colab.
    try:
        from google.colab import runtime
    except ImportError:
        runtime = None
    if runtime is not None:
        runtime.unassign()
| |
|
# Re-evaluate the saved checkpoint on the validation set.
# NOTE(review): relies on `val_loader` and the metric helpers having been defined
# earlier (the __main__ block / earlier notebook cells).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedUNet()
# Load the weights into the model. A bare torch.load(...) only returns the state
# dict; without load_state_dict a randomly initialised network would be evaluated.
model.load_state_dict(torch.load('amodal_completion_model.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()

print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)

metrics = evaluate_metrics(model, val_loader, device)
print(f"Overall MSE: {metrics['total_mse']:.6f}")
print(f"Occluded Region MSE: {metrics['occluded_mse']:.6f}")
print(f"Visible Region MSE: {metrics['visible_mse']:.6f}")
# evaluate_metrics reports 0 for an empty region; avoid dividing by it.
occluded_visible_ratio = (metrics['occluded_mse'] / metrics['visible_mse']
                          if metrics['visible_mse'] else float('nan'))
print(f"Occluded/Visible MSE Ratio: {occluded_visible_ratio:.2f}")

print("\nGenerating visualizations...")
visualize_results(model, val_loader, device, num_samples=8)

image_metrics = calculate_metrics(model, val_loader, device)
print(f"PSNR: {image_metrics['psnr']:.4f}")
print(f"SSIM: {image_metrics['ssim']:.4f}")
print(f"LPIPS: {image_metrics['lpips']:.4f}")
print(f"mIoU (pred vs GT): {image_metrics['miou']:.4f}")

# NOTE(review): rebinds `model` to a fresh, untrained network, discarding the one
# evaluated above — confirm this is intentional.
model = ImprovedUNet()
model.eval()