File size: 2,965 Bytes
2fe55e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import hashlib
import os
import time

from lama_cleaner.plugins.anime_seg import AnimeSeg

# Enable PyTorch's MPS fallback switch for this test process.
# NOTE(review): this sits before the `import torch.cuda` below, presumably so
# the flag is already set when torch is loaded — but the lama_cleaner import
# above may pull in torch first; confirm the ordering is effective.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path

import cv2
import pytest
import torch.cuda

from lama_cleaner.plugins import (
    RemoveBG,
    RealESRGANUpscaler,
    GFPGANPlugin,
    RestoreFormerPlugin,
    InteractiveSeg,
)

# Shared test inputs, built once at import time and reused by every test.
current_dir = Path(__file__).parent.absolute().resolve()

# Output images are written next to this file, under "result".
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)

img_p = current_dir / "bunny.jpeg"
# read_bytes() opens and closes the file itself; the previous bare
# open(...).read() left the handle for the garbage collector to close.
img_bytes = img_p.read_bytes()
# cv2.imread decodes to BGR channel order; keep an RGB copy as well because
# some tests below are fed rgb_img and others bgr_img.
bgr_img = cv2.imread(str(img_p))
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)


def _save(img, name):
    """Write ``img`` to ``save_dir / name`` and fail if nothing was written.

    Args:
        img: Image array to hand to ``cv2.imwrite``.
        name: File name (including extension) created inside ``save_dir``.

    Raises:
        AssertionError: If ``cv2.imwrite`` reports that the write failed.
    """
    out_path = save_dir / name
    # cv2.imwrite reports a failed write by returning False instead of
    # raising, so the result must be checked for the calling test to fail.
    assert cv2.imwrite(str(out_path), img), f"failed to write {out_path}"


def test_remove_bg():
    """Run RemoveBG on the sample image and save the converted output."""
    remover = RemoveBG()
    output = remover.forward(bgr_img)
    # Apply the RGBA -> BGRA conversion before handing the image to OpenCV.
    converted = cv2.cvtColor(output, cv2.COLOR_RGBA2BGRA)
    _save(converted, "test_remove_bg.png")


def test_anime_seg():
    """Run AnimeSeg on the anime sample and check it yields a 4-channel image."""
    segmenter = AnimeSeg()
    sample = cv2.imread(str(current_dir / "anime_test.png"))
    output = segmenter.forward(sample)

    # Expect a three-dimensional array whose last axis has four entries.
    shape = output.shape
    assert len(shape) == 3
    assert shape[-1] == 4

    _save(output, "test_anime_seg.png")


@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_upscale(device):
    """Upscale the sample image by 2 and by 4 with RealESRGAN on ``device``.

    The test is reported as skipped, rather than silently passing, when the
    requested accelerator is not available on this machine.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available")

    model = RealESRGANUpscaler("realesr-general-x4v3", device)
    # Same model instance is reused for both scale factors.
    for scale in (2, 4):
        res = model.forward(bgr_img, scale)
        _save(res, f"test_upscale_x{scale}_{device}.png")


@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_gfpgan(device):
    """Run the GFPGAN plugin on the RGB sample image and save the result.

    The test is reported as skipped, rather than silently passing, when the
    requested accelerator is not available on this machine.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available")

    model = GFPGANPlugin(device)
    # The plugin is invoked as a callable; the two trailing arguments are
    # passed as None here.
    res = model(rgb_img, None, None)
    _save(res, f"test_gfpgan_{device}.png")


@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_restoreformer(device):
    """Run the RestoreFormer plugin on the RGB sample image and save the result.

    The test is reported as skipped, rather than silently passing, when the
    requested accelerator is not available on this machine.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available")

    model = RestoreFormerPlugin(device)
    # The plugin is invoked as a callable; the two trailing arguments are
    # passed as None here.
    res = model(rgb_img, None, None)
    _save(res, f"test_restoreformer_{device}.png")


@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_segment_anything(device):
    """Run InteractiveSeg ("vit_l") on the sample image and time a repeat call.

    The first call's mask is saved to disk; a second, identical call is then
    timed and the elapsed time printed. The test is reported as skipped,
    rather than silently passing, when the requested accelerator is not
    available on this machine.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    if device == "mps" and not torch.backends.mps.is_available():
        pytest.skip("MPS is not available")

    img_md5 = hashlib.md5(img_bytes).hexdigest()
    model = InteractiveSeg("vit_l", device)
    # One click entry — presumably [x, y, label]; confirm against InteractiveSeg.
    new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)

    save_name = f"test_segment_anything_{device}.png"
    _save(new_mask, save_name)

    # Repeat the identical request (same image hash) and time it.
    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
    print(f"Time for {save_name}: {time.perf_counter() - start:.2f}s")