# File size: 3,985 Bytes
# 2fe55e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from pathlib import Path

import cv2
import pytest
import torch
from PIL import Image

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import HDStrategy
from lama_cleaner.tests.test_model import get_config, get_data

# Absolute, symlink-resolved directory of this test module; the default
# input images used by the tests are looked up relative to it.
current_dir = Path(__file__).parent.absolute().resolve()

# Rendered outputs are written under ./result next to this module.
save_dir = current_dir / "result"
save_dir.mkdir(parents=True, exist_ok=True)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def assert_equal(
    model, config, gt_name,
    fx: float = 1, fy: float = 1,
    img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
    mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    example_p=current_dir / "bunny.jpeg",
):
    """Run a paint-by-example model on a fixture image and save the output.

    Despite its name, this helper compares nothing. The result is written to
    ``save_dir / gt_name`` for manual inspection. The helper fails only if an
    input cannot be loaded or the output cannot be written.

    Args:
        model: Callable invoked as ``model(img, mask, config)``. Its return
            value is written with ``cv2.imwrite``.
        config: Inpainting config. Its ``paint_by_example_example_image``
            attribute is overwritten with the example image as a PIL image.
        gt_name: File name of the output image inside ``save_dir``.
        fx: Horizontal scale factor. It is forwarded to ``get_data`` and also
            used to resize the example image.
        fy: Vertical scale factor, used the same way as ``fx``.
        img_p: Path of the input image, forwarded to ``get_data``.
        mask_p: Path of the mask image, forwarded to ``get_data``.
        example_p: Path of the example (reference) image.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``example_p``.
        OSError: If OpenCV reports that the result image was not written.
    """
    img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)

    # cv2.imread signals a missing or unreadable file by returning None
    # instead of raising. Fail loudly here rather than inside cvtColor.
    example_image = cv2.imread(str(example_p))
    if example_image is None:
        raise FileNotFoundError(f"Cannot read example image: {example_p}")
    # imread with default flags yields 3-channel BGR, and PIL expects RGB
    # order, so swap the channels before wrapping in a PIL image.
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    example_image = cv2.resize(
        example_image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA
    )

    print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
    config.paint_by_example_example_image = Image.fromarray(example_image)
    res = model(img, mask, config)

    out_path = save_dir / gt_name
    # cv2.imwrite returns False, rather than raising, when it cannot write.
    if not cv2.imwrite(str(out_path), res):
        raise OSError(f"Failed to write result image: {out_path}")


@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example(strategy):
    """Render paint_by_example with ``disable_nsfw=True`` on a rescaled input.

    The input is scaled non-uniformly (fx=1.3, fy=0.9) and run for 30 steps.
    ``assert_equal`` saves the output rather than comparing it to a reference.
    """
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    paint_model = ModelManager(
        name="paint_by_example",
        device=device,
        disable_nsfw=True,
    )
    paint_cfg = get_config(strategy, paint_by_example_steps=30)
    output_name = f"paint_by_example_{strategy.capitalize()}.png"

    assert_equal(
        paint_model,
        paint_cfg,
        output_name,
        img_p=image_path,
        mask_p=mask_path,
        fx=1.3,
        fy=0.9,
    )


@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_disable_nsfw(strategy):
    """Render paint_by_example at the original input scale with ``disable_nsfw=False``.

    NOTE(review): despite the test name, this passes ``disable_nsfw=False``,
    while most other tests here pass ``True``. The intent is presumably to
    exercise the path where the NSFW checker stays active. Confirm and
    consider renaming.
    """
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    paint_model = ModelManager(
        name="paint_by_example",
        device=device,
        disable_nsfw=False,
    )
    paint_cfg = get_config(strategy, paint_by_example_steps=30)
    output_name = f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png"

    assert_equal(
        paint_model,
        paint_cfg,
        output_name,
        img_p=image_path,
        mask_p=mask_path,
    )


@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_sd_scale(strategy):
    """Render paint_by_example with ``sd_scale=0.85`` on a rescaled input.

    The input is scaled with fx=1.3, fy=0.9, and the model is built with
    ``disable_nsfw=True``.
    """
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    paint_model = ModelManager(
        name="paint_by_example",
        device=device,
        disable_nsfw=True,
    )
    paint_cfg = get_config(
        strategy,
        paint_by_example_steps=30,
        sd_scale=0.85,
    )
    output_name = f"paint_by_example_{strategy.capitalize()}_sdscale.png"

    assert_equal(
        paint_model,
        paint_cfg,
        output_name,
        img_p=image_path,
        mask_p=mask_path,
        fx=1.3,
        fy=0.9,
    )


@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_cpu_offload(strategy):
    """Render paint_by_example with ``cpu_offload=True`` on the module-level device.

    The model is built with ``disable_nsfw=False`` and ``sd_scale=0.85``, and
    the input is used at its original scale.
    """
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    paint_model = ModelManager(
        name="paint_by_example",
        device=device,
        cpu_offload=True,
        disable_nsfw=False,
    )
    paint_cfg = get_config(
        strategy,
        paint_by_example_steps=30,
        sd_scale=0.85,
    )
    output_name = f"paint_by_example_{strategy.capitalize()}_cpu_offload.png"

    assert_equal(
        paint_model,
        paint_cfg,
        output_name,
        img_p=image_path,
        mask_p=mask_path,
    )


@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_cpu_offload_cpu_device(strategy):
    """Render paint_by_example with ``cpu_offload=True`` forced onto the CPU device.

    This runs a single step, presumably to keep the CPU run short. The input
    is scaled with fx=1.3, fy=0.9, and the model uses ``disable_nsfw=True``.
    """
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"

    cpu = torch.device("cpu")
    paint_model = ModelManager(
        name="paint_by_example",
        device=cpu,
        cpu_offload=True,
        disable_nsfw=True,
    )
    paint_cfg = get_config(
        strategy,
        paint_by_example_steps=1,
        sd_scale=0.85,
    )
    output_name = (
        f"paint_by_example_{strategy.capitalize()}_cpu_offload_cpu_device.png"
    )

    assert_equal(
        paint_model,
        paint_cfg,
        output_name,
        img_p=image_path,
        mask_p=mask_path,
        fx=1.3,
        fy=0.9,
    )