File size: 2,322 Bytes
2fe55e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from pathlib import Path

import pytest
import torch

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.tests.test_model import get_config, assert_equal
from lama_cleaner.schema import HDStrategy

# Absolute, symlink-resolved directory containing this test module; the image
# fixtures used by the tests below are looked up relative to it.
current_dir = Path(__file__).parent.absolute().resolve()

# "result" sub-directory next to this module, created eagerly at import time so
# it is guaranteed to exist before any test runs.
save_dir = current_dir.joinpath('result')
save_dir.mkdir(parents=True, exist_ok=True)

# Use the GPU whenever torch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


@pytest.mark.parametrize("disable_nsfw", [True, False])
@pytest.mark.parametrize("cpu_offload", [False, True])
def test_instruct_pix2pix(disable_nsfw, cpu_offload):
    """Run the InstructPix2Pix "snowing" edit for every NSFW / CPU-offload combination.

    The model is built through ``ModelManager`` on the module-level ``device``.
    The result is checked by ``assert_equal`` under a file name that encodes
    the device and both parameters. That name is presumably used as the
    output/reference image name — confirm in ``assert_equal``.

    Args:
        disable_nsfw: forwarded to ``ModelManager(disable_nsfw=...)``.
        cpu_offload: forwarded to ``ModelManager(cpu_offload=...)``.
    """
    # 50 denoising steps on CUDA; a single step otherwise (presumably to keep
    # CPU runs short).
    if device == 'cuda':
        steps = 50
    else:
        steps = 1

    model = ModelManager(
        name="instruct_pix2pix",
        device=torch.device(device),
        hf_access_token="",
        sd_run_local=False,
        disable_nsfw=disable_nsfw,
        sd_cpu_textencoder=False,
        cpu_offload=cpu_offload,
    )

    cfg = get_config(
        strategy=HDStrategy.ORIGINAL,
        prompt='What if it were snowing?',
        p2p_steps=steps,
        sd_scale=1.1,
    )

    suffix = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}"
    assert_equal(
        model,
        cfg,
        f"instruct_pix2pix_{suffix}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
        fx=1.3,
    )


@pytest.mark.parametrize("disable_nsfw", [False])
@pytest.mark.parametrize("cpu_offload", [False])
def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
    """Run the InstructPix2Pix "snowing" edit with the NSFW checker and CPU offload off.

    ``sd_scale`` is not passed to ``get_config`` and ``fx`` is not passed to
    ``assert_equal``, so both stay at whatever defaults those helpers define.
    The output file name is fixed to ``instruct_pix2pix_snow.png``.

    NOTE(review): the file name does not encode ``device``. Runs on CPU and
    on CUDA therefore target the same file — confirm this is intended.

    Args:
        disable_nsfw: forwarded to ``ModelManager``; pinned to ``False``.
        cpu_offload: forwarded to ``ModelManager``; pinned to ``False``.
    """
    # 50 denoising steps on CUDA; a single step otherwise (presumably to keep
    # CPU runs short).
    sd_steps = 50 if device == 'cuda' else 1
    model = ModelManager(name="instruct_pix2pix",
                         device=torch.device(device),
                         hf_access_token="",
                         sd_run_local=False,
                         disable_nsfw=disable_nsfw,
                         sd_cpu_textencoder=False,
                         cpu_offload=cpu_offload)
    cfg = get_config(strategy=HDStrategy.ORIGINAL, prompt='What if it were snowing?', p2p_steps=sd_steps)

    # Plain string literal: the previous f-string had no placeholders (F541).
    name = "snow"

    assert_equal(
        model,
        cfg,
        f"instruct_pix2pix_{name}.png",
        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
    )