Aryan Wadhawan commited on
Commit
6030ffa
·
1 Parent(s): bab3d34
Files changed (1) hide show
  1. app.py +34 -20
app.py CHANGED
@@ -42,11 +42,11 @@ def generate_mask(image):
42
  # Preprocess the image
43
  image = np.array(image.convert("RGB"))
44
  img = preprocess(image)
45
-
46
  input_size = [1024, 1024]
47
  im_shp = image.shape[0:2]
48
  im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
49
-
50
  # Replace F.upsample with F.interpolate
51
  im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
52
  image = torch.divide(im_tensor, 255.0)
@@ -66,32 +66,46 @@ def generate_mask(image):
66
 
67
  return output_mask
68
 
69
- # Define the final predict method to overlay the mask
70
  def predict(image):
71
  # Generate the mask
72
  mask = generate_mask(image)
73
-
74
  # Convert the image to RGBA (to support transparency)
75
- image = image.convert("RGBA")
76
-
77
- # Convert the mask into a binary mask where 255 is kept and 0 is transparent
78
- mask = Image.fromarray(mask).resize(image.size).convert("L") # Convert to grayscale (L mode)
79
-
80
- # Create a new image with transparency (RGBA)
81
  transparent_image = Image.new("RGBA", image.size)
82
-
83
- # Use the mask as transparency: paste the original image where the mask is white
84
- transparent_image.paste(image, mask=mask)
85
-
86
- return transparent_image
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- # Create the Gradio interface with custom output size for the display only (not affecting the saved image)
89
  iface = gr.Interface(
90
- fn=predict,
91
- inputs=gr.Image(type="pil"),
92
- outputs=gr.Image(type="pil", label="Edited Image", image_mode="RGBA", format="png"), # RGBA ensures PNG with transparency
 
 
 
93
  title="Background Removal with U2NET",
94
- description="Upload an image and remove the background"
95
  )
96
 
97
  if __name__ == "__main__":
 
42
  # Preprocess the image
43
  image = np.array(image.convert("RGB"))
44
  img = preprocess(image)
45
+
46
  input_size = [1024, 1024]
47
  im_shp = image.shape[0:2]
48
  im_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)
49
+
50
  # Replace F.upsample with F.interpolate
51
  im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
52
  image = torch.divide(im_tensor, 255.0)
 
66
 
67
  return output_mask
68
 
 
69
  def predict(image):
70
  # Generate the mask
71
  mask = generate_mask(image)
72
+
73
  # Convert the image to RGBA (to support transparency)
74
+ image_rgba = image.convert("RGBA")
75
+
76
+ # Create a binary mask from the generated mask and resize it to the image size
77
+ mask = Image.fromarray(mask).resize(image.size).convert("L") # Convert to grayscale
78
+
79
+ # Create a new image with transparency (RGBA) for the output with transparent background
80
  transparent_image = Image.new("RGBA", image.size)
81
+ transparent_image.paste(image_rgba, mask=mask)
82
+
83
+ # Create foreground and background masks
84
+ red_foreground = Image.new("RGBA", image.size, (255, 0, 0, 128)) # Red foreground with 50% opacity
85
+ blue_background = Image.new("RGBA", image.size, (0, 0, 255, 128)) # Blue background with 50% opacity
86
+
87
+ # Create an empty overlay image
88
+ overlay_image = Image.new("RGBA", image.size)
89
+
90
+ # Overlay the red and blue masks based on the mask
91
+ overlay_image.paste(blue_background, (0, 0)) # Fill the entire overlay with blue
92
+ overlay_image.paste(red_foreground, (0, 0), mask=mask) # Paste red where mask is white
93
+
94
+ # Combine the original image with the overlay at 50% opacity
95
+ combined_image = Image.blend(image_rgba, overlay_image, alpha=0.5)
96
+
97
+ return transparent_image, combined_image
98
 
99
+ # Create the Gradio interface with two outputs
100
  iface = gr.Interface(
101
+ fn=predict,
102
+ inputs=gr.Image(type="pil"),
103
+ outputs=[
104
+ gr.Image(type="pil", label="Transparent Background", image_mode="RGBA", format="png"), # Transparent output
105
+ gr.Image(type="pil", label="Overlay with Colors", image_mode="RGBA", format="png"), # Colored overlay output
106
+ ],
107
  title="Background Removal with U2NET",
108
+ description="Upload an image to remove the background and visualize it with an overlay."
109
  )
110
 
111
  if __name__ == "__main__":