File size: 2,505 Bytes
3030eb7
 
eb7a233
3030eb7
 
 
 
acf5692
eb7a233
 
3030eb7
 
7d3d39b
3030eb7
 
 
 
7d3d39b
3030eb7
 
 
 
 
 
 
 
 
 
eb7a233
 
 
 
3030eb7
 
7d3d39b
eb7a233
3030eb7
 
 
 
 
7d3d39b
 
8cc2603
 
eb7a233
 
 
7d3d39b
eb7a233
3030eb7
eb7a233
 
 
 
 
 
 
3030eb7
7d3d39b
 
 
 
 
 
eb7a233
 
 
9ea4ecb
7d3d39b
eb7a233
 
 
 
7d3d39b
eb7a233
 
 
3030eb7
 
7d3d39b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import os
import tempfile

# Select the "high" float32 matmul precision mode (the other candidate the
# author had listed was "highest").
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation checkpoint. trust_remote_code=True lets
# transformers execute the modelling code shipped with the checkpoint repo.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied to every input image: fixed 1024x1024 resize, tensor
# conversion, then per-channel normalisation.
# NOTE(review): the mean/std values look like the usual ImageNet statistics --
# confirm against the model card.
transform_image = transforms.Compose([
    transforms.Resize(size=(1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

@spaces.GPU
def fn(image):
    """Cut the background out of *image* with the module-level BiRefNet model.

    Args:
        image: Input accepted by ``load_img`` (converted to a PIL image
            here), or ``None``.

    Returns:
        A ``(cutout, original)`` tuple: the RGB-converted input with the
        predicted mask installed as its alpha channel, and a copy of the
        RGB image taken before the alpha channel was added. Returns
        ``(None, None)`` when *image* is ``None``.
    """
    if image is None:
        return None, None

    rgb = load_img(image, output_type="pil").convert("RGB")
    # Snapshot taken before putalpha() below modifies ``rgb``.
    untouched = rgb.copy()

    batch = transform_image(rgb)
    batch = batch.unsqueeze(0)  # add a leading dimension
    batch = batch.to("cuda")

    # Prediction: run the model without gradient tracking and keep only the
    # last output, passed through a sigmoid and moved back to the CPU.
    with torch.no_grad():
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    to_pil = transforms.ToPILImage()
    # Scale the predicted mask back to the input's original dimensions.
    alpha = to_pil(probabilities[0].squeeze()).resize(rgb.size)
    rgb.putalpha(alpha)
    return rgb, untouched

def save_image(image):
    """Write *image* to a new temporary ``.png`` file and return its path.

    Args:
        image: Object exposing ``save(path, format=...)`` (a PIL image in
            this app), or ``None``.

    Returns:
        The path of the written file as a string, or ``None`` when *image*
        is ``None``.
    """
    if image is not None:
        # delete=False: the file has to outlive this function so the caller
        # can hand its path on; nothing here removes it afterwards.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as png_handle:
            png_path = png_handle.name
            image.save(png_path, format="PNG")
        return png_path
    return None

def process_and_download(input_image):
    """Remove the background from *input_image* and save the results as PNGs.

    Args:
        input_image: Value forwarded unchanged to ``fn``.

    Returns:
        A ``([result_path, original_path], result_path)`` tuple: the pair of
        file paths for the before/after slider, followed by the path of the
        background-removed PNG for the download component. Returns
        ``(None, None)`` when ``fn`` produced no result.
    """
    cutout, source = fn(input_image)
    if cutout is None:
        return None, None

    # Cutout first, original second -- the same order in which they are
    # written to disk and listed for the slider.
    cutout_path, source_path = save_image(cutout), save_image(source)
    return [cutout_path, source_path], cutout_path

# Load the example images directly as PIL objects.
example_image1 = Image.open("example_images/example1.png")
example_image2 = Image.open("example_images/example2.png")
example_image3 = Image.open("example_images/example3.png")

# Interface components.
# The Korean UI strings below were stored mis-encoded (UTF-8 bytes decoded with
# a single-byte codepage) and rendered as garbage; they are restored here.
image = gr.Image(label="이미지 업로드")  # "Upload image"
slider = ImageSlider(label="배경 제거 결과", type="filepath")  # "Background removal result"
png_output = gr.File(label="PNG 다운로드")  # "Download PNG"

# Assemble the Gradio interface: one image input, and two outputs (the
# before/after slider and the downloadable PNG) fed by process_and_download.
demo = gr.Interface(
    process_and_download,
    inputs=image,
    outputs=[slider, png_output],
    examples=[example_image1, example_image2, example_image3],
    title="배경 제거",  # "Background removal"
    # "Upload an image and the BiRefNet model removes its background.
    #  The result can be downloaded as a PNG file."
    description="이미지를 업로드하면 BiRefNet 모델을 사용하여 배경을 제거합니다. 결과를 PNG 파일로 다운로드할 수 있습니다."
)

if __name__ == "__main__":
    demo.launch()