Spaces:
Runtime error
Runtime error
File size: 2,398 Bytes
3030eb7 eb7a233 3030eb7 acf5692 eb7a233 3030eb7 eb7a233 3030eb7 eb7a233 3030eb7 eb7a233 8cc2603 eb7a233 3030eb7 eb7a233 3030eb7 eb7a233 9ea4ecb eb7a233 9ea4ecb eb7a233 3030eb7 eb7a233 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import io
import os
import tempfile
# Let PyTorch use faster, reduced-precision internal math for float32 matrix
# multiplies (the "high" setting of set_float32_matmul_precision).
torch.set_float32_matmul_precision("high")

# Load BiRefNet from the Hugging Face Hub. trust_remote_code=True lets
# transformers execute the modeling code stored in that Hub repo.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Model input pipeline: fixed 1024x1024 resize, PIL image -> float tensor,
# then per-channel normalization with the standard ImageNet mean/std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
@spaces.GPU
def fn(image):
    """Remove the background from *image* using BiRefNet.

    Args:
        image: Any input accepted by ``load_img`` (e.g. the value produced by
            a ``gr.Image`` component), or ``None``.

    Returns:
        A pair ``(cutout, original)``. ``cutout`` is the input converted to
        RGB with the predicted foreground mask attached as its alpha channel
        (so an RGBA PIL image). ``original`` is an untouched RGB copy of the
        input. Returns ``(None, None)`` when *image* is ``None``.
    """
    if image is None:
        return None, None

    im = load_img(image, output_type="pil").convert("RGB")
    # Keep a pristine RGB copy: `im` itself gets an alpha channel below.
    origin = im.copy()

    # `im` is already a PIL image, so it goes straight into the transform;
    # no second load_img round-trip is needed.
    input_images = transform_image(im).unsqueeze(0).to("cuda")

    # Prediction. [-1] takes the model's last output (presumably the final
    # refined mask); sigmoid maps logits to [0, 1].
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Resize the predicted mask back to the input's size and use it as alpha.
    mask = transforms.ToPILImage()(pred).resize(im.size)
    im.putalpha(mask)
    return im, origin
def save_image(image):
    """Write *image* as PNG to a fresh temporary ``.png`` file.

    Returns the path of the written file, or ``None`` when *image* is
    ``None``. The file is created with ``delete=False`` so it survives this
    call; nothing in this function removes it afterwards.
    """
    if image is not None:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            image.save(handle, format="PNG")
        return handle.name
    return None
def process_and_download(input_image):
    """Run background removal and turn the results into file paths for the UI.

    Returns a pair. The first element is ``[cutout_path, original_path]``
    for the image slider. The second is ``cutout_path`` for the PNG download
    component. Returns ``(None, None)`` when ``fn`` yields no result (for
    example, when no image was provided).
    """
    cutout, source = fn(input_image)
    if cutout is None:
        return None, None
    cutout_path = save_image(cutout)
    return [cutout_path, save_image(source)], cutout_path
# UI components. The Korean labels were reconstructed from mis-encoded text;
# the garbled bytes had split two string literals across lines (a
# SyntaxError). NOTE(review): confirm the exact wording.
image = gr.Image(label="이미지 업로드")  # "Upload image"
slider = ImageSlider(label="배경 제거 결과", type="filepath")  # "Background removal result"
png_output = gr.File(label="PNG 다운로드")  # "PNG download"

# Example images expected next to this script. Only files that actually exist
# are offered, so a missing or misnamed file is skipped rather than handed to
# Gradio.
# NOTE(review): names reconstructed as "예시 N.png" ("Example N") from
# mis-encoded text; verify against the files in the repo.
_EXAMPLE_DIR = os.path.dirname(os.path.abspath(__file__))
examples = [
    path
    for path in (
        os.path.join(_EXAMPLE_DIR, name)
        for name in ("예시 1.png", "예시 2.png", "예시 3.png")
    )
    if os.path.isfile(path)
]

demo = gr.Interface(
    process_and_download,
    inputs=image,
    outputs=[slider, png_output],
    # Pass None rather than an empty list when no example file was found.
    examples=examples or None,
    title="배경 제거",  # "Background removal"
    # "Upload an image and the BiRefNet model removes its background.
    #  You can download the result as a PNG file."
    description=(
        "이미지를 업로드하면 BiRefNet 모델을 사용하여 배경을 제거합니다. "
        "결과를 PNG 파일로 다운로드할 수 있습니다."
    ),
)

if __name__ == "__main__":
    demo.launch()