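# NOTE(review): the three lines below look like page-status text captured
# together with this listing; they are not part of the program.
# Spaces:
# Build error
# Build error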
import os

import cv2
import numpy as np
import torch
from skimage import io
from skimage.transform import resize
from torch.utils.data import Dataset

from saicinpainting.evaluation.evaluator import InpaintingEvaluator
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
class SimpleImageDataset(Dataset):
    """Serve every file of a directory as a resized, channels-first float tensor.

    Instance fields:
        root_dir: Directory the images are read from.
        files: Sorted names of the regular files found in ``root_dir``.
        image_size: ``(height, width)`` every image is resized to.
    """

    def __init__(self, root_dir, image_size=(400, 600)):
        """Index the regular files in ``root_dir``.

        Args:
            root_dir: Directory containing the images.
            image_size: Target ``(height, width)`` handed to
                ``skimage.transform.resize`` as the output shape.
        """
        self.root_dir = root_dir
        # Skip sub-directories: ``io.imread`` is given a file path, so a
        # directory entry would make ``__getitem__`` fail for that index.
        self.files = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )
        self.image_size = image_size

    def __getitem__(self, index):
        """Load image ``index``, resize it and return it as a ``(C, H, W)`` tensor."""
        img_name = os.path.join(self.root_dir, self.files[index])
        image = io.imread(img_name)
        image = resize(image, self.image_size, anti_aliasing=True)
        image = torch.FloatTensor(image)
        if image.dim() == 2:
            # A grayscale file has no channel axis, which would make the
            # permute below raise. Replicate the single plane into three
            # channels so the result is laid out like a colour image.
            image = image.unsqueeze(-1).repeat(1, 1, 3)
        # Channels-last (H, W, C) -> channels-first (C, H, W).
        return image.permute(2, 0, 1)

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.files)
def create_rectangle_mask(height, width):
    """Build a mask of ones with a centred rectangle of zeros.

    The zeroed rectangle spans rows ``height // 4`` to
    ``height - height // 4 - 1`` and columns ``width // 4`` to
    ``width - width // 4 - 1``, both ends inclusive.

    Args:
        height: Number of mask rows.
        width: Number of mask columns.

    Returns:
        A ``float64`` NumPy array of shape ``(height, width)``.
    """
    mask = np.ones((height, width))
    top, left = height // 4, width // 4
    # Slice ends are exclusive, so ``height - top`` / ``width - left`` cover
    # the same inclusive pixel range as a filled rectangle drawn between
    # (left, top) and (width - left - 1, height - top - 1).
    mask[top:height - top, left:width - left] = 0
    return mask
class Model:
    """Baseline inpainter: fill the hole with the mean of the kept pixels."""

    def __call__(self, img_batch, mask_batch):
        """Replace masked-out pixels with each image's per-channel mean.

        Args:
            img_batch: Tensor indexed as ``(batch, channel, height, width)``.
            mask_batch: Tensor indexed as ``(batch, height, width)``. Pixels
                weighted 1 are kept; pixels weighted 0 are replaced.

        Returns:
            Tensor with the shape of ``img_batch``: kept pixels are copied
            from the input, the others hold the mean of the kept pixels of
            the same image and channel.
        """
        # Add the channel axis once so the mask broadcasts over channels.
        mask = mask_batch[:, None, :, :]
        known_area = mask_batch.sum(dim=(1, 2))
        # With an all-zero mask the numerator is zero too; dividing by 1
        # yields a fill value of 0 instead of NaN (0 / 0).
        known_area = known_area.masked_fill(known_area == 0, 1)
        mean = (img_batch * mask).sum(dim=(2, 3)) / known_area[:, None]
        inpainted = mean[:, :, None, None] * (1 - mask) + img_batch * mask
        return inpainted
class SimpleImageSquareMaskDataset(Dataset):
    """Pair every image of a wrapped dataset with one fixed rectangular mask.

    The mask is built once from ``dataset.image_size`` and each sample also
    carries the output of ``Model`` for that image and mask.
    """

    def __init__(self, dataset):
        """Wrap ``dataset`` and precompute the shared mask.

        Args:
            dataset: Indexable image dataset exposing an ``image_size``
                attribute that unpacks into ``(height, width)``.
        """
        self.dataset = dataset
        self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
        self.model = Model()

    def __getitem__(self, index):
        """Return a dict with keys ``image``, ``mask`` and ``inpainted``."""
        img = self.dataset[index]
        # Hand out a copy so callers cannot modify the shared mask in place.
        mask = self.mask.clone()
        # The model works on batches, so a leading axis of size 1 is added to
        # both inputs; that axis is still present on ``inpainted``.
        inpainted = self.model(img[None, ...], mask[None, ...])
        return {'image': img, 'mask': mask, 'inpainted': inpainted}

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)
# Images are read from the ``imgs`` directory, resolved against the current
# working directory.
dataset = SimpleImageDataset('imgs')
mask_dataset = SimpleImageSquareMaskDataset(dataset)
model = Model()

# Scores handed to the evaluator, keyed by the name they are registered under.
metrics = {
    'ssim': SSIMScore(),
    'lpips': LPIPSScore(),
    'fid': FIDScore(),
}

evaluator = InpaintingEvaluator(
    mask_dataset, scores=metrics, batch_size=3, area_grouping=True
)

results = evaluator.evaluate(model)
print(results)