Sutirtha commited on
Commit
892cf9d
·
verified ·
1 Parent(s): a68f3d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -47
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from PIL import Image, ImageEnhance
3
  import numpy as np
4
  import cv2
5
  from lang_sam import LangSAM
@@ -8,9 +8,14 @@ from color_matcher.normalizer import Normalizer
8
  import torch
9
 
10
  # Load the LangSAM model
11
- model = LangSAM() # Use the default model or specify custom checkpoint: LangSAM("<model_type>", "<path/to/checkpoint>")
 
 
 
 
 
 
12
 
13
- # Function to apply color matching based on reference image
14
  def apply_color_matching(source_img_np, ref_img_np):
15
  # Initialize ColorMatcher
16
  cm = ColorMatcher()
@@ -23,60 +28,129 @@ def apply_color_matching(source_img_np, ref_img_np):
23
 
24
  return img_res
25
 
26
- # Function to extract sky and apply color matching using a reference image
27
- def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
28
- # Use LangSAM to predict the mask for the sky
29
- masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
30
-
31
- # Ensure masks is converted from tensor to NumPy
32
- masks_np = masks[0].cpu().numpy() # Convert the tensor to NumPy array
33
-
34
- # Convert the mask to a binary format and create a mask image
35
- sky_mask = (masks_np > 0).astype(np.uint8) * 255 # Ensure it's a binary mask
36
 
37
- # Convert PIL image to numpy array for processing
38
- img_np = np.array(image_pil)
 
 
39
 
40
- # Convert sky mask to 3-channel format to blend with the original image
41
- sky_mask_3ch = cv2.merge([sky_mask, sky_mask, sky_mask])
42
 
43
- # Extract the sky region
44
- sky_region = cv2.bitwise_and(img_np, sky_mask_3ch)
 
45
 
46
- # Convert the reference image to a numpy array
47
- ref_img_np = np.array(reference_image_pil)
 
 
48
 
49
- # Apply color matching using the reference image to the extracted sky region
50
- sky_region_color_matched = apply_color_matching(sky_region, ref_img_np)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Combine the color-matched sky region back into the original image
53
- result_img_np = np.where(sky_mask_3ch > 0, sky_region_color_matched, img_np)
 
 
 
 
 
 
 
 
54
 
55
- # Convert the result back to PIL Image for final output
56
- result_img_pil = Image.fromarray(result_img_np)
57
 
58
- return result_img_pil
59
-
60
- # Gradio Interface
61
- def gradio_interface():
62
- # Gradio function to be called on input
63
- def process_image(source_img, ref_img):
64
- # Extract sky and apply color matching using reference image
65
- result_img_pil = extract_and_color_match_sky(source_img, ref_img)
66
- return result_img_pil
67
-
68
- # Define Gradio input components
69
- inputs = [
70
- gr.Image(type="pil", label="Source Image"),
71
- gr.Image(type="pil", label="Reference Image") # Second input for reference image
72
- ]
73
-
74
- # Define Gradio output component
75
- outputs = gr.Image(type="pil", label="Resulting Image")
76
 
77
- # Launch Gradio app
78
- gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="Sky Extraction and Color Matching").launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # Run the Gradio Interface
81
  if __name__ == "__main__":
82
  gradio_interface()
 
1
  import gradio as gr
2
+ from PIL import Image
3
  import numpy as np
4
  import cv2
5
  from lang_sam import LangSAM
 
8
  import torch
9
 
10
  # Load the LangSAM model
11
+ model = LangSAM() # Use the default model or specify custom checkpoint if necessary
12
+
13
+ def extract_mask(image_pil, text_prompt):
14
+ masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
15
+ masks_np = masks[0].cpu().numpy()
16
+ mask = (masks_np > 0).astype(np.uint8) * 255 # Binary mask
17
+ return mask
18
 
 
19
  def apply_color_matching(source_img_np, ref_img_np):
20
  # Initialize ColorMatcher
21
  cm = ColorMatcher()
 
28
 
29
  return img_res
30
 
31
+ def process_image(current_image_pil, prompt, replacement_image_pil, color_ref_image_pil, image_history):
32
+ # Check if current_image_pil is None
33
+ if current_image_pil is None:
34
+ return None, "No current image to edit.", image_history, None
 
 
 
 
 
 
35
 
36
+ # Save current image to history for undo
37
+ if image_history is None:
38
+ image_history = []
39
+ image_history.append(current_image_pil.copy())
40
 
41
+ # Extract mask
42
+ mask = extract_mask(current_image_pil, prompt)
43
 
44
+ # Check if mask is valid
45
+ if mask.sum() == 0:
46
+ return current_image_pil, f"No mask detected for prompt: {prompt}", image_history, current_image_pil
47
 
48
+ # Proceed with replacement or color matching
49
+ current_image_np = np.array(current_image_pil)
50
+ mask_3ch = cv2.merge([mask, mask, mask])
51
+ result_image_np = current_image_np.copy()
52
 
53
+ # If replacement image is provided
54
+ if replacement_image_pil is not None:
55
+ # Resize replacement image to fit the mask area
56
+ # Get bounding box of the mask
57
+ y_indices, x_indices = np.where(mask > 0)
58
+ if y_indices.size == 0 or x_indices.size == 0:
59
+ # No mask detected
60
+ return current_image_pil, f"No mask detected for prompt: {prompt}", image_history, current_image_pil
61
+ y_min, y_max = y_indices.min(), y_indices.max()
62
+ x_min, x_max = x_indices.min(), x_indices.max()
63
+
64
+ # Extract the region of interest
65
+ mask_height = y_max - y_min + 1
66
+ mask_width = x_max - x_min + 1
67
+
68
+ # Resize replacement image to fit mask area
69
+ replacement_image_resized = replacement_image_pil.resize((mask_width, mask_height))
70
+ replacement_image_np = np.array(replacement_image_resized)
71
+
72
+ # Create a mask for the ROI
73
+ mask_roi = mask[y_min:y_max+1, x_min:x_max+1]
74
+ mask_roi_3ch = cv2.merge([mask_roi, mask_roi, mask_roi])
75
+
76
+ # Replace the masked area with the replacement image
77
+ result_image_np[y_min:y_max+1, x_min:x_max+1] = np.where(mask_roi_3ch > 0, replacement_image_np, result_image_np[y_min:y_max+1, x_min:x_max+1])
78
 
79
+ # If color reference image is provided
80
+ if color_ref_image_pil is not None:
81
+ # Extract the masked area
82
+ masked_region = cv2.bitwise_and(result_image_np, mask_3ch)
83
+ # Convert color reference image to numpy
84
+ color_ref_image_np = np.array(color_ref_image_pil)
85
+ # Apply color matching
86
+ color_matched_region = apply_color_matching(masked_region, color_ref_image_np)
87
+ # Combine the color matched region back into the result image
88
+ result_image_np = np.where(mask_3ch > 0, color_matched_region, result_image_np)
89
 
90
+ # Convert result back to PIL Image
91
+ result_image_pil = Image.fromarray(result_image_np)
92
 
93
+ # Update current_image_pil
94
+ current_image_pil = result_image_pil
95
+
96
+ return current_image_pil, f"Applied changes for prompt: {prompt}", image_history, current_image_pil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ def undo(image_history):
99
+ if image_history and len(image_history) > 1:
100
+ # Pop the last image
101
+ image_history.pop()
102
+ # Return the previous image
103
+ current_image_pil = image_history[-1]
104
+ return current_image_pil, image_history, current_image_pil
105
+ elif image_history and len(image_history) == 1:
106
+ current_image_pil = image_history[0]
107
+ return current_image_pil, image_history, current_image_pil
108
+ else:
109
+ # Cannot undo
110
+ return None, [], None
111
 
112
+ def gradio_interface():
113
+ with gr.Blocks() as demo:
114
+ # Define the state variables
115
+ image_history = gr.State([])
116
+ current_image_pil = gr.State(None)
117
+
118
+ gr.Markdown("## Continuous Image Editing with LangSAM")
119
+
120
+ with gr.Row():
121
+ with gr.Column():
122
+ initial_image = gr.Image(type="pil", label="Upload Image")
123
+ prompt = gr.Textbox(lines=1, placeholder="Enter prompt for object detection", label="Prompt")
124
+ replacement_image = gr.Image(type="pil", label="Replacement Image (optional)", optional=True)
125
+ color_ref_image = gr.Image(type="pil", label="Color Reference Image (optional)", optional=True)
126
+ apply_button = gr.Button("Apply Changes")
127
+ undo_button = gr.Button("Undo")
128
+ with gr.Column():
129
+ current_image_display = gr.Image(type="pil", label="Edited Image", interactive=False)
130
+ status = gr.Textbox(lines=2, interactive=False, label="Status")
131
+
132
+ def initialize_image(initial_image_pil):
133
+ # Initialize image history with the initial image
134
+ if initial_image_pil is not None:
135
+ image_history = [initial_image_pil]
136
+ current_image_pil = initial_image_pil
137
+ return current_image_pil, image_history, initial_image_pil
138
+ else:
139
+ return None, [], None
140
+
141
+ # When the initial image is uploaded, initialize the image history
142
+ initial_image.upload(fn=initialize_image, inputs=initial_image, outputs=[current_image_pil, image_history, current_image_display])
143
+
144
+ # Apply button click
145
+ apply_button.click(fn=process_image,
146
+ inputs=[current_image_pil, prompt, replacement_image, color_ref_image, image_history],
147
+ outputs=[current_image_pil, status, image_history, current_image_display])
148
+
149
+ # Undo button click
150
+ undo_button.click(fn=undo, inputs=image_history, outputs=[current_image_pil, image_history, current_image_display])
151
+
152
+ demo.launch(share=True)
153
+
154
  # Run the Gradio Interface
155
  if __name__ == "__main__":
156
  gradio_interface()