File size: 2,398 Bytes
3030eb7
 
eb7a233
3030eb7
 
 
 
acf5692
eb7a233
 
 
3030eb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb7a233
 
 
 
3030eb7
 
eb7a233
 
 
3030eb7
 
 
 
 
eb7a233
 
8cc2603
 
eb7a233
 
 
 
 
3030eb7
eb7a233
 
 
 
 
 
 
3030eb7
eb7a233
 
 
9ea4ecb
eb7a233
 
 
 
 
9ea4ecb
eb7a233
 
 
 
 
 
 
 
3030eb7
 
eb7a233
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import io
import os
import tempfile

# Allow reduced-precision internal math for float32 matmuls ("high" mode) in
# exchange for speed; "highest" would keep full float32 precision.
torch.set_float32_matmul_precision("high")

# BiRefNet is loaded from the Hub with its custom modeling code
# (trust_remote_code=True) and placed on the GPU once, at import time.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
).to("cuda")

# Preprocessing applied to every input image: resize to a fixed 1024x1024,
# convert to a float tensor, then normalize each RGB channel with the standard
# ImageNet mean / std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

@spaces.GPU
def fn(image):
    """Remove the background of *image* with BiRefNet.

    Args:
        image: Any input ``load_img`` accepts (whatever the ``gr.Image``
            component delivers), or ``None``.

    Returns:
        A ``(cutout, original)`` tuple of PIL images. ``cutout`` is the input
        converted to RGB with the predicted foreground mask attached as its
        alpha channel. ``original`` is an untouched RGB copy of the input.
        Returns ``(None, None)`` when no image was supplied.
    """
    if image is None:
        return None, None

    # Normalize the incoming value to an RGB PIL image once; the original code
    # round-tripped the already-loaded PIL image through load_img a second time.
    im = load_img(image, output_type="pil").convert("RGB")
    image_size = im.size
    # Keep a pristine copy: putalpha below modifies `im` in place.
    origin = im.copy()

    # Preprocess and add a batch dimension (a batch of one) on the GPU.
    input_images = transform_image(im).unsqueeze(0).to("cuda")

    # Prediction. The last element of the model output is used as the final
    # mask; the sigmoid maps it to [0, 1] (presumably the raw output is logits).
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Turn the float mask into a grayscale PIL image, scale it back to the
    # input's size, and use it as the cutout's alpha channel.
    mask = transforms.ToPILImage()(pred).resize(image_size)
    im.putalpha(mask)
    return im, origin

def save_image(image):
    """Write *image* as PNG to a new temporary ``.png`` file and return its path.

    The temporary file is created with ``delete=False``, so the returned path
    remains valid after this call; nothing here removes it later.
    A ``None`` image is passed straight through as ``None``.
    """
    if image is not None:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            image.save(handle, format="PNG")
            path = handle.name
        return path
    return None

def process_and_download(input_image):
    """Remove the background from *input_image* and save both images as PNGs.

    Returns:
        ``([cutout_path, original_path], cutout_path)``. The two-element list
        feeds the before/after slider and the single path feeds the download
        component. Returns ``(None, None)`` when ``fn`` produced no result.
    """
    cutout, source = fn(input_image)
    if cutout is None:
        return None, None
    # Saved in the same order as before: cutout first, then the original.
    cutout_path, source_path = map(save_image, (cutout, source))
    return [cutout_path, source_path], cutout_path

# UI components. The Korean labels were mojibake (UTF-8 read as Windows-874)
# and are restored here; English translations are given in the comments.
image = gr.Image(label="이미지 업로드")  # "Image upload"
slider = ImageSlider(label="배경 제거 결과", type="filepath")  # "Background removal result"
png_output = gr.File(label="PNG 다운로드")  # "PNG download"

# Example images expected next to this script: 예제1.png .. 예제3.png
# ("예제" = "example"). The garbled names could not match those files.
examples = [
    os.path.join(os.path.dirname(__file__), f"예제{i}.png")
    for i in (1, 2, 3)
]

demo = gr.Interface(
    process_and_download,
    inputs=image,
    outputs=[slider, png_output],
    examples=examples,
    # "Background removal"
    title="배경 제거",
    # "Upload an image to remove its background using the BiRefNet model.
    #  The result can be downloaded as a PNG file."
    description="이미지를 업로드하면 BiRefNet 모델을 사용하여 배경을 제거합니다. 결과를 PNG 파일로 다운로드할 수 있습니다.",
)

if __name__ == "__main__":
    demo.launch()