danielecordano commited on
Commit
5c0ac6e
1 Parent(s): 16bb0cc

Add colour picker

Browse files
Files changed (1) hide show
  1. app.py +17 -31
app.py CHANGED
@@ -1,17 +1,12 @@
1
  import gradio as gr
2
- from gradio_imageslider import ImageSlider
3
  from loadimg import load_img
4
- import spaces
5
  from transformers import AutoModelForImageSegmentation
6
  import torch
7
  from torchvision import transforms
 
8
 
9
  torch.set_float32_matmul_precision(["high", "highest"][0])
10
-
11
- birefnet = AutoModelForImageSegmentation.from_pretrained(
12
- "ZhengPeng7/BiRefNet", trust_remote_code=True
13
- )
14
- birefnet.to("cuda")
15
  transform_image = transforms.Compose(
16
  [
17
  transforms.Resize((1024, 1024)),
@@ -20,44 +15,35 @@ transform_image = transforms.Compose(
20
  ]
21
  )
22
 
23
-
24
- @spaces.GPU
25
- def fn(image):
26
  im = load_img(image, output_type="pil")
27
  im = im.convert("RGB")
28
  image_size = im.size
29
- origin = im.copy()
30
  image = load_img(im)
31
- input_images = transform_image(image).unsqueeze(0).to("cuda")
32
- # Prediction
33
  with torch.no_grad():
34
  preds = birefnet(input_images)[-1].sigmoid().cpu()
35
  pred = preds[0].squeeze()
36
  pred_pil = transforms.ToPILImage()(pred)
37
  mask = pred_pil.resize(image_size)
38
  image.putalpha(mask)
39
- return (image, origin)
40
-
41
-
42
- slider1 = ImageSlider(label="birefnet", type="pil")
43
- slider2 = ImageSlider(label="birefnet", type="pil")
 
 
 
44
  image = gr.Image(label="Upload an image")
45
  text = gr.Textbox(label="Paste an image URL")
46
-
47
-
48
  chameleon = load_img("chameleon.jpg", output_type="pil")
49
-
50
  url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
51
- tab1 = gr.Interface(
52
- fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image"
53
- )
54
-
55
- tab2 = gr.Interface(fn, inputs=text, outputs=slider2, examples=[url], api_name="text")
56
-
57
-
58
- demo = gr.TabbedInterface(
59
- [tab1, tab2], ["image", "text"], title="birefnet for background removal"
60
- )
61
 
62
  if __name__ == "__main__":
63
  demo.launch()
 
1
  import gradio as gr
 
2
  from loadimg import load_img
 
3
  from transformers import AutoModelForImageSegmentation
4
  import torch
5
  from torchvision import transforms
6
+ from PIL import Image, ImageColor
7
 
8
  torch.set_float32_matmul_precision(["high", "highest"][0])
9
+ birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
 
 
 
 
10
  transform_image = transforms.Compose(
11
  [
12
  transforms.Resize((1024, 1024)),
 
15
  ]
16
  )
17
 
18
+ def fn(image, color):
 
 
19
  im = load_img(image, output_type="pil")
20
  im = im.convert("RGB")
21
  image_size = im.size
 
22
  image = load_img(im)
23
+ input_images = transform_image(image).unsqueeze(0)
 
24
  with torch.no_grad():
25
  preds = birefnet(input_images)[-1].sigmoid().cpu()
26
  pred = preds[0].squeeze()
27
  pred_pil = transforms.ToPILImage()(pred)
28
  mask = pred_pil.resize(image_size)
29
  image.putalpha(mask)
30
+ color = ImageColor.getcolor(color, "RGBA")
31
+ new_image = Image.new("RGBA", image.size, color)
32
+ new_image.paste(image, (0, 0), image)
33
+ new_image = new_image.convert("RGB")
34
+ return new_image
35
+
36
+ outimage1 = gr.Image(format="png", label="Result")
37
+ outimage2 = gr.Image(format="png", label="Result")
38
  image = gr.Image(label="Upload an image")
39
  text = gr.Textbox(label="Paste an image URL")
40
+ color1 = gr.ColorPicker(value="#FFFFFF", label="Color")
41
+ color2 = gr.ColorPicker(value="#FFFFFF", label="Color")
42
  chameleon = load_img("chameleon.jpg", output_type="pil")
 
43
  url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
44
+ tab1 = gr.Interface(fn, inputs=[image, color1], outputs=outimage1, examples=[[chameleon, "#FFFFFF"]], api_name="image")
45
+ tab2 = gr.Interface(fn, inputs=[text, color2], outputs=outimage2, examples=[[url, "#FFFFFF"]], api_name="text")
46
+ demo = gr.TabbedInterface([tab1, tab2], ["image", "text"], title="birefnet for background coloring")
 
 
 
 
 
 
 
47
 
48
  if __name__ == "__main__":
49
  demo.launch()