File size: 2,283 Bytes
e546fea
 
 
 
 
 
 
 
958511f
e546fea
958511f
 
 
 
 
 
 
 
 
 
 
e546fea
 
958511f
 
3e75999
2d64873
 
 
 
 
 
958511f
e546fea
 
 
 
 
3e75999
 
2d64873
3107813
2d64873
3107813
2d64873
 
 
 
 
e546fea
 
c368dca
3107813
 
f61fa93
3107813
 
958511f
f61fa93
 
3107813
958511f
20a2fe0
958511f
 
 
f61fa93
e546fea
3107813
 
 
 
e546fea
 
3107813
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os

import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

# Select the "high" float32 matmul precision setting for torch.
torch.set_float32_matmul_precision("high")

# Segmentation model loaded by name from the Hub and moved to the GPU.
# trust_remote_code=True lets the model repository supply its own Python code.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied to every input image before it is fed to the model:
# resize to 1024x1024, convert to a tensor, then normalize per channel.
# NOTE(review): mean/std look like the usual ImageNet statistics - confirm.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def fn(image):
    """Load an image and return it with and without its background.

    Args:
        image: Anything ``load_img`` accepts as an image source.

    Returns:
        A ``(processed, original)`` tuple: the RGB image after ``process``
        has added an alpha channel, and an untouched RGB copy of the input.
    """
    rgb = load_img(image, output_type="pil").convert("RGB")
    # Copy first: process() modifies the image it is given in place.
    origin = rgb.copy()
    return (process(rgb), origin)

@spaces.GPU
def process(image):
    """Predict a foreground mask for *image* and attach it as the alpha channel.

    Args:
        image: PIL image to segment. It is modified in place.

    Returns:
        The same *image* object, now carrying the predicted mask as alpha.
    """
    original_size = image.size
    batch = transform_image(image).unsqueeze(0).to("cuda")

    # Inference only: no gradients needed. The last model output is used.
    with torch.no_grad():
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    # First (only) item of the batch, converted to a PIL image and scaled
    # back to the size of the input so it lines up with the original pixels.
    alpha = transforms.ToPILImage()(probabilities[0].squeeze())
    image.putalpha(alpha.resize(original_size))
    return image
  
def process_file(f):
    """Remove the background from the image file at *f* and save it as a PNG.

    Args:
        f: Path of the input image file.

    Returns:
        Path of the written PNG: *f* with its file extension replaced by
        ``.png`` (or with ``.png`` appended when *f* has no extension).
        When *f* already ends in lowercase ``.png`` this is *f* itself, so
        the input file is overwritten.
    """
    # splitext only strips an extension from the final path component, so a
    # dot in a directory name (".../v1.2/upload") no longer truncates the
    # path the way ``f.rsplit(".", 1)`` did.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

# Outputs for the two preview tabs; each receives the (processed, original)
# pair returned by fn.
slider1 = ImageSlider(label="birefnet", type="pil")
slider2 = ImageSlider(label="birefnet", type="pil")

# Inputs: an uploaded image, an uploaded image passed on as a file path,
# and a text box for an image URL. png_file is the download output.
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")


# Example inputs shown in the tabs.
# NOTE(review): the variable is named `chameleon` but loads butterfly.jpg.
chameleon = load_img("butterfly.jpg", output_type="pil")

url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# Tab 1: uploaded image -> slider.
tab1 = gr.Interface(
    fn=fn,
    inputs=image,
    outputs=slider1,
    examples=[chameleon],
    api_name="image",
)

# Tab 2: image URL -> slider.
tab2 = gr.Interface(
    fn=fn,
    inputs=text,
    outputs=slider2,
    examples=[url],
    api_name="text",
)

# Tab 3: image file path -> PNG file written by process_file.
tab3 = gr.Interface(
    fn=process_file,
    inputs=image2,
    outputs=png_file,
    examples=["butterfly.jpg"],
    api_name="png",
)


demo = gr.TabbedInterface(
    [tab1, tab2, tab3],
    ["image", "text", "png"],
    title="birefnet for background removal",
)

if __name__ == "__main__":
    demo.launch(show_error=True)