# lama / bin / evaluator_example.py
# AK391
# files
# d380b77
import os
import cv2
import numpy as np
import torch
from skimage import io
from skimage.transform import resize
from torch.utils.data import Dataset
from saicinpainting.evaluation.evaluator import InpaintingEvaluator
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore
class SimpleImageDataset(Dataset):
    """Dataset that serves every entry of a directory as a resized float tensor.

    The directory listing is taken once, at construction time, and kept in
    sorted order so that indices are stable between runs.
    """

    def __init__(self, root_dir, image_size=(400, 600)):
        """Record the directory, its sorted listing and the target size.

        Args:
            root_dir: Directory whose entries are read as images.
            image_size: Output shape handed to ``skimage.transform.resize``.
        """
        self.root_dir = root_dir
        self.files = sorted(os.listdir(root_dir))
        self.image_size = image_size

    def __len__(self):
        """Return the number of directory entries found at construction."""
        return len(self.files)

    def __getitem__(self, index):
        """Load entry ``index``, resize it and return it with the last axis first."""
        path = os.path.join(self.root_dir, self.files[index])
        resized = resize(io.imread(path), self.image_size, anti_aliasing=True)
        # permute(2, 0, 1) needs a 3-D array; it moves the trailing axis to
        # the front (presumably HWC -> CHW for colour images).
        return torch.FloatTensor(resized).permute(2, 0, 1)
def create_rectangle_mask(height, width):
    """Build a ``(height, width)`` array of ones with a zeroed central rectangle.

    The rectangle starts a quarter of the width/height (integer division) in
    from the top-left corner and ends the same distance in from the
    bottom-right corner.

    Args:
        height: Number of rows in the mask.
        width: Number of columns in the mask.

    Returns:
        The NumPy array after the rectangle has been drawn into it.
    """
    left, top = width // 4, height // 4
    right, bottom = width - left - 1, height - top - 1

    mask = np.ones((height, width))
    # OpenCV takes points as (x, y) and draws directly into ``mask``; the
    # return value of ``rectangle`` is therefore not needed.
    cv2.rectangle(mask, (left, top), (right, bottom), (0, 0, 0), thickness=cv2.FILLED)
    return mask
class Model():
    """Trivial baseline "inpainter" that fills pixels with a per-channel mean.

    Pixels where the mask is 1 are kept as they are; pixels where the mask is
    0 are replaced by the mean of the kept pixels of the same sample and
    channel.
    """

    def __call__(self, img_batch, mask_batch):
        """Return ``img_batch`` with the mask-0 region replaced by the mean.

        Args:
            img_batch: 4-D tensor indexed as (batch, channel, row, column).
            mask_batch: 3-D tensor indexed as (batch, row, column); acts as a
                per-pixel weight, 1 for kept pixels and 0 for filled ones.

        Returns:
            Tensor with the same shape as ``img_batch``.
        """
        keep = mask_batch[:, None, :, :]
        kept_sum = (img_batch * keep).sum(dim=(2, 3))
        kept_area = mask_batch.sum(dim=(1, 2))[:, None]

        # A sample whose mask sums to zero would give 0 / 0 = NaN and poison
        # the whole output. The numerator is zero there too, so dividing by 1
        # yields a zero fill; any non-zero area is left untouched.
        safe_area = torch.where(kept_area == 0, torch.ones_like(kept_area), kept_area)
        mean = kept_sum / safe_area

        return mean[:, :, None, None] * (1 - keep) + img_batch * keep
class SimpleImageSquareMaskDataset(Dataset):
    """Wrap an image dataset, adding a fixed rectangle mask and a baseline fill.

    Every item is a dict with keys ``image``, ``mask`` and ``inpainted``. The
    same mask (built once from ``dataset.image_size``) is used for all items.
    """

    def __init__(self, dataset):
        """Store the wrapped dataset and precompute the shared mask.

        Args:
            dataset: Indexable dataset exposing ``image_size`` and returning
                one image tensor per index.
        """
        self.dataset = dataset
        self.mask = torch.FloatTensor(create_rectangle_mask(*self.dataset.image_size))
        self.model = Model()

    def __getitem__(self, index):
        """Return the image at ``index`` with its mask and baseline inpainting."""
        img = self.dataset[index]
        # Clone so that callers mutating the returned mask cannot corrupt the
        # shared one.
        mask = self.mask.clone()
        # The model works on batches: add a batch axis for the call and strip
        # it again, so ``inpainted`` has the same shape as ``image`` and
        # batches cleanly (no spurious singleton axis after collation).
        inpainted = self.model(img[None, ...], mask[None, ...])[0]
        return dict(image=img, mask=mask, inpainted=inpainted)

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)
# Example run: score the mean-fill baseline on the images in the relative
# directory 'imgs', using a fixed central rectangle mask.
dataset = SimpleImageDataset('imgs')
mask_dataset = SimpleImageSquareMaskDataset(dataset)
model = Model()

# Metric objects keyed by the name under which they are passed as `scores`.
metrics = dict(
    ssim=SSIMScore(),
    lpips=LPIPSScore(),
    fid=FIDScore(),
)

evaluator = InpaintingEvaluator(
    mask_dataset,
    scores=metrics,
    batch_size=3,
    area_grouping=True,
)

results = evaluator.evaluate(model)
print(results)