JnanaVenkataSubhash committed on
Commit 2a8387c · verified · 1 Parent(s): ac1b947

Update app.py

Files changed (1)
  1. app.py +90 -49
app.py CHANGED
@@ -1,62 +1,103 @@
  import gradio as gr
  from PIL import Image, ImageFilter
- import torch
- from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation
  import numpy as np
 
- # Load the device (use CPU or GPU)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
- # Initialize the model and processor
- image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
- model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)
 
- # Function to apply background blur based on depth
- def apply_background_blur(image: Image):
-     # Convert the uploaded image to RGB if necessary
-     image = image.convert("RGB")
-
-     # Process the image with DepthPro model
-     inputs = image_processor(images=image, return_tensors="pt").to(device)
-
-     with torch.no_grad():
-         outputs = model(**inputs)
-
-     post_processed_output = image_processor.post_process_depth_estimation(
-         outputs, target_sizes=[(image.height, image.width)],
-     )
-
-     # Get the predicted depth and normalize it
-     depth = post_processed_output[0]["predicted_depth"]
-     depth_np = depth.detach().cpu().numpy().squeeze()
-     depth_normalized = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min())
-
-     # Create a blurred image
-     blurred_image = image.copy()
 
-     # Apply variable Gaussian blur based on depth
-     blur_strength = 20  # You can adjust this for overall blur strength
-     blur_map = (depth_normalized * blur_strength).astype(int)
 
-     for radius in range(1, blur_strength + 1):
          mask = (blur_map == radius)
          if np.any(mask):
-             temp_image = image.copy()
-             temp_image = temp_image.filter(ImageFilter.GaussianBlur(radius))
-             blurred_image = Image.composite(temp_image, blurred_image, Image.fromarray((mask * 255).astype(np.uint8)))
-
      return blurred_image
 
- # Create Gradio interface
- def create_interface():
-     # Gradio interface with image upload input and output for processed image
-     gr.Interface(
-         fn=apply_background_blur,
-         inputs=gr.Image(type="pil", label="Upload Image"),
-         outputs=gr.Image(type="pil", label="Blurred Image"),
-         live=True
-     ).launch()
-
- # Start the app
- if __name__ == "__main__":
-     create_interface()
+ import os
+ import torch
  import gradio as gr
  from PIL import Image, ImageFilter
+ import torchvision.transforms as transforms
+ from transformers import AutoModelForImageSegmentation, DepthProImageProcessorFast, DepthProForDepthEstimation
  import numpy as np
+ import tempfile  # replaces "import io": gr.File serves file paths, not raw bytes
 
+ # Load Models
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+ HF_model_name = 'BiRefNet'
+ birefnet = AutoModelForImageSegmentation.from_pretrained(f'zhengpeng7/{HF_model_name}', trust_remote_code=True).to(device).eval()
+ print('BiRefNet (Segmentation) is ready to use.')
 
+ depth_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
+ depth_model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device).eval()
+ print('DepthPro (Blur) is ready to use.')
+
+ # Combined Image Transform (mean/std are the standard ImageNet statistics)
+ transform_image = transforms.Compose([
+     transforms.Resize((1024, 1024)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
 
+ # Refine Foreground (placeholder; a possible sketch follows the diff)
+ def refine_foreground(image, mask):
+     return image  # Implement your refinement logic here
+
+ # Segmentation Function
+ def segment_image(image):
+     input_image = transform_image(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         # BiRefNet returns a list of side outputs; keep the final prediction
+         pred = birefnet(input_image)[-1].sigmoid().cpu()[0].squeeze()
+     mask = transforms.ToPILImage()(pred).resize(image.size)
+     image_masked = refine_foreground(image.copy(), mask)
+     image_masked.putalpha(mask)
+     return image_masked
 
+ # Blur Function
+ def apply_background_blur(image):
+     inputs = depth_processor(images=image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = depth_model(**inputs)
+     depth = depth_processor.post_process_depth_estimation(
+         outputs, target_sizes=[(image.height, image.width)]
+     )[0]["predicted_depth"].cpu().squeeze()
+     depth_normalized = (depth - depth.min()) / (depth.max() - depth.min())
+     # Quantize depth into integer blur radii (convert to numpy so .astype works)
+     blur_map = (depth_normalized * 20).numpy().astype(int)
+     blurred_image = image.copy()
+     for radius in range(1, 21):
          mask = (blur_map == radius)
          if np.any(mask):
+             blurred_image = Image.composite(image.filter(ImageFilter.GaussianBlur(radius)), blurred_image, Image.fromarray((mask * 255).astype(np.uint8)))
      return blurred_image
 
+ # Process Image Function
+ def process_image(image, action):
+     image = image.convert("RGB")
+     if action == "Segmentation":
+         return segment_image(image)
+     elif action == "Blur":
+         return apply_background_blur(image)
+     elif action == "Both":
+         return segment_image(image), apply_background_blur(image)
+     else:
+         return None
+
+ # Download Function: write PNGs to temp files (gr.File cannot serve raw bytes)
+ def download_image(image):
+     if image is None:
+         return None
+     images = image if isinstance(image, tuple) else (image,)
+     paths = []
+     for i, img in enumerate(images):
+         path = os.path.join(tempfile.gettempdir(), f'output_{i}.png')
+         img.save(path, format='PNG')
+         paths.append(path)
+     return paths
+
+ # Gradio Interface
+ def gradio_interface(image, action):
+     result = process_image(image, action)
+     if action == "Both":
+         # reveal the second image slot only when there are two results
+         return download_image(result), result[0], gr.update(value=result[1], visible=True)
+     # always return three values to match the three output components
+     return download_image(result), result, gr.update(value=None, visible=False)
+
+ interface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[gr.Image(type="pil", label="Upload Image"), gr.Dropdown(["Segmentation", "Blur", "Both"], label="Select Action")],
+     outputs=[
+         gr.File(label="Download Output", file_count="multiple"),
+         gr.Image(label="Output Image 1"),
+         gr.Image(label="Output Image 2", visible=False)
+     ],
+     live=False,
+ )
+
+ interface.launch()
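
The commit leaves refine_foreground as a stub. A minimal sketch of one possible implementation, assuming simple edge feathering; feather_radius is an illustrative parameter, not from the commit, and if this variant is used, segment_image's later putalpha(mask) would overwrite the softened alpha and should be dropped:

def refine_foreground(image, mask, feather_radius=3):
    # Soften the hard mask edge so the cutout blends smoothly,
    # then bake it in as the alpha channel.
    feathered = mask.filter(ImageFilter.GaussianBlur(feather_radius))
    out = image.convert("RGBA")
    out.putalpha(feathered)
    return out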
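
To sanity-check both paths without launching the UI, a quick script-style sketch (sample.jpg and the output names are placeholder paths):

img = Image.open("sample.jpg").convert("RGB")   # placeholder input
segment_image(img).save("segmented.png")        # RGBA cutout
apply_background_blur(img).save("blurred.png")  # depth-weighted blur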