Spaces:
Sleeping
Sleeping
File size: 2,283 Bytes
e546fea 958511f e546fea 958511f e546fea 958511f 3e75999 2d64873 958511f e546fea 3e75999 2d64873 3107813 2d64873 3107813 2d64873 e546fea c368dca 3107813 f61fa93 3107813 958511f f61fa93 3107813 958511f 20a2fe0 958511f f61fa93 e546fea 3107813 e546fea 3107813 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
torch.set_float32_matmul_precision(["high", "highest"][0])
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
def fn(image):
im = load_img(image, output_type="pil")
im = im.convert("RGB")
origin = im.copy()
image = process(im)
return (image, origin)
@spaces.GPU
def process(image):
image_size = image.size
input_images = transform_image(image).unsqueeze(0).to("cuda")
# Prediction
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return image
def process_file(f):
name_path = f.rsplit(".",1)[0]+".png"
im = load_img(f, output_type="pil")
im = im.convert("RGB")
transparent = process(im)
transparent.save(name_path)
return name_path
slider1 = ImageSlider(label="birefnet", type="pil")
slider2 = ImageSlider(label="birefnet", type="pil")
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image",type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")
chameleon = load_img("butterfly.jpg", output_type="pil")
url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
tab1 = gr.Interface(
fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image"
)
tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text")
tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["butterfly.jpg"], api_name="png")
demo = gr.TabbedInterface(
[tab1, tab2,tab3], ["image", "text","png"], title="birefnet for background removal"
)
if __name__ == "__main__":
demo.launch(show_error=True) |