Spaces:
Building
on
A10G
Building
on
A10G
File size: 7,955 Bytes
eb9ca51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import datetime
import numpy as np
import os
from PIL import Image
import pytest
from pytest import fixture
from typing import Tuple, List
from cv2 import imread, cvtColor, COLOR_BGR2RGB
from skimage.metrics import structural_similarity as ssim
"""
This test suite compares images in 2 directories by file name
The directories are specified by the command line arguments --baseline_dir and --test_dir
"""
# ssim: Structural Similarity Index
# Returns a tuple of (ssim, diff_image)
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
score, diff = ssim(img0, img1, channel_axis=-1, full=True)
# rescale the difference image to 0-255 range
diff = (diff * 255).astype("uint8")
return score, diff
# Metrics must return a tuple of (score, diff_image)
METRICS = {"ssim": ssim_score}
METRICS_PASS_THRESHOLD = {"ssim": 0.95}
class TestCompareImageMetrics:
@fixture(scope="class")
def test_file_names(self, args_pytest):
test_dir = args_pytest['test_dir']
fnames = self.gather_file_basenames(test_dir)
yield fnames
del fnames
@fixture(scope="class", autouse=True)
def teardown(self, args_pytest):
yield
# Runs after all tests are complete
# Aggregate output files into a grid of images
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
img_output_dir = args_pytest['img_output_dir']
metrics_file = args_pytest['metrics_file']
grid_dir = os.path.join(img_output_dir, "grid")
os.makedirs(grid_dir, exist_ok=True)
for metric_dir in METRICS.keys():
metric_path = os.path.join(img_output_dir, metric_dir)
for file in os.listdir(metric_path):
if file.endswith(".png"):
score = self.lookup_score_from_fname(file, metrics_file)
image_file_list = []
image_file_list.append([
os.path.join(baseline_dir, file),
os.path.join(test_dir, file),
os.path.join(metric_path, file)
])
# Create grid
image_list = [[Image.open(file) for file in files] for files in image_file_list]
grid = self.image_grid(image_list)
grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
# Tests run for each baseline file name
@fixture()
def fname(self, baseline_fname):
yield baseline_fname
del baseline_fname
def test_directories_not_empty(self, args_pytest):
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"
def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
# Check that all files in baseline_dir have a file in test_dir with matching metadata
baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
file_match = self.find_file_match(baseline_file_path, file_paths)
assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"
# For a baseline image file, finds the corresponding file name in test_dir and
# compares the images using the metrics in METRICS
@pytest.mark.parametrize("metric", METRICS.keys())
def test_pipeline_compare(
self,
args_pytest,
fname,
test_file_names,
metric,
):
baseline_dir = args_pytest['baseline_dir']
test_dir = args_pytest['test_dir']
metrics_output_file = args_pytest['metrics_file']
img_output_dir = args_pytest['img_output_dir']
baseline_file_path = os.path.join(baseline_dir, fname)
# Find file match
file_paths = [os.path.join(test_dir, f) for f in test_file_names]
test_file = self.find_file_match(baseline_file_path, file_paths)
# Run metrics
sample_baseline = self.read_img(baseline_file_path)
sample_secondary = self.read_img(test_file)
score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
metric_status = score > METRICS_PASS_THRESHOLD[metric]
# Save metric values
with open(metrics_output_file, 'a') as f:
run_info = os.path.splitext(fname)[0]
metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")
# Save metric image
metric_img_dir = os.path.join(img_output_dir, metric)
os.makedirs(metric_img_dir, exist_ok=True)
output_filename = f'{fname}'
Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))
assert score > METRICS_PASS_THRESHOLD[metric]
def read_img(self, filename: str) -> np.ndarray:
cvImg = imread(filename)
cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
return cvImg
def image_grid(self, img_list: list[list[Image.Image]]):
# imgs is a 2D list of images
# Assumes the input images are a rectangular grid of equal sized images
rows = len(img_list)
cols = len(img_list[0])
w, h = img_list[0][0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
for i, row in enumerate(img_list):
for j, img in enumerate(row):
grid.paste(img, box=(j*w, i*h))
return grid
def lookup_score_from_fname(self,
fname: str,
metrics_output_file: str
) -> float:
fname_basestr = os.path.splitext(fname)[0]
with open(metrics_output_file, 'r') as f:
for line in f:
if fname_basestr in line:
score = float(line.split('|')[5])
return score
raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")
def gather_file_basenames(self, directory: str):
files = []
for file in os.listdir(directory):
if file.endswith(".png"):
files.append(file)
return files
def read_file_prompt(self, fname:str) -> str:
# Read prompt from image file metadata
img = Image.open(fname)
img.load()
return img.info['prompt']
def find_file_match(self, baseline_file: str, file_paths: List[str]):
# Find a file in file_paths with matching metadata to baseline_file
baseline_prompt = self.read_file_prompt(baseline_file)
# Do not match empty prompts
if baseline_prompt is None or baseline_prompt == "":
return None
# Find file match
# Reorder test_file_names so that the file with matching name is first
# This is an optimization because matching file names are more likely
# to have matching metadata if they were generated with the same script
basename = os.path.basename(baseline_file)
file_path_basenames = [os.path.basename(f) for f in file_paths]
if basename in file_path_basenames:
match_index = file_path_basenames.index(basename)
file_paths.insert(0, file_paths.pop(match_index))
for f in file_paths:
test_file_prompt = self.read_file_prompt(f)
if baseline_prompt == test_file_prompt:
return f |