File size: 1,656 Bytes
cbbdd92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import cv2
import pytest
from PIL import Image

from iopaint.model_manager import ModelManager
from iopaint.schema import HDStrategy
from iopaint.tests.utils import (
    current_dir,
    get_config,
    get_data,
    save_dir,
    check_device,
)

model_name = "Fantasy-Studio/Paint-by-Example"


def assert_equal(
    model,
    config,
    save_name: str,
    fx: float = 1,
    fy: float = 1,
    img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
    mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    example_p=current_dir / "bunny.jpeg",
):
    """Run ``model`` with an example image attached to ``config`` and save the output.

    Despite its name, this helper performs no assertion on the result: it
    runs inference once and writes the returned image to
    ``save_dir / save_name``.

    Args:
        model: Callable invoked as ``model(img, mask, config)``.
        config: Request config; its ``paint_by_example_example_image``
            attribute is overwritten (mutated in place) with the example image.
        save_name: File name, relative to ``save_dir``, for the result.
        fx: Horizontal scale factor forwarded to ``get_data`` and applied to
            the example image via ``cv2.resize``.
        fy: Vertical scale factor, used the same way as ``fx``.
        img_p: Image path forwarded to ``get_data``.
        mask_p: Mask path forwarded to ``get_data``.
        example_p: Path of the example (reference) image read with OpenCV.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``example_p``.
        OSError: If OpenCV fails to write the result file.
    """
    img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)

    # cv2.imread returns None (instead of raising) for a missing/unreadable
    # file; fail with a clear message rather than a cryptic cvtColor error.
    example_image = cv2.imread(str(example_p), cv2.IMREAD_COLOR)
    if example_image is None:
        raise FileNotFoundError(f"Could not read example image: {example_p}")
    # IMREAD_COLOR yields a 3-channel BGR array, so BGR2RGB (not BGRA2RGB)
    # is the matching conversion to the RGB order PIL expects.
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    example_image = cv2.resize(
        example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA
    )

    print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
    config.paint_by_example_example_image = Image.fromarray(example_image)
    res = model(img, mask, config)
    out_path = save_dir / save_name
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(str(out_path), res):
        raise OSError(f"Failed to write result image: {out_path}")


@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
def test_paint_by_example(device):
    """Smoke-test Paint-by-Example inference on ``device``.

    The step count comes from ``check_device(device)`` and is passed as
    ``sd_steps`` to the config. The model runs through ``assert_equal`` with
    scale factors fx=1.3 / fy=0.9. The output is saved as
    ``paint_by_example_device_<device>.png``.
    """
    steps = check_device(device)
    manager = ModelManager(name=model_name, device=device, disable_nsfw=True)
    request_cfg = get_config(strategy=HDStrategy.ORIGINAL, sd_steps=steps)

    output_name = f"paint_by_example_device_{device}.png"
    assert_equal(
        manager,
        request_cfg,
        output_name,
        fx=1.3,
        fy=0.9,
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )