eee515 / app.py
kvinod15's picture
Update app.py
ae8d774 verified
raw
history blame
6.17 kB
import io
import numpy as np
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
import gradio as gr
from transformers import AutoModelForImageSegmentation, pipeline
# ----------------------------
# Global Setup and Model Loading
# ----------------------------
# Run inference on the GPU when torch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the segmentation model (RMBG-2.0).
# NOTE: trust_remote_code=True runs the modelling code shipped in the model
# repository, so only keep it for repositories you trust.
segmentation_model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True
)
segmentation_model.to(device)
segmentation_model.eval()  # inference only; no training-mode layers

# Transformation for segmentation: resize to the fixed 512x512 model input,
# convert to a tensor and normalize each channel with the given mean / std
# (the usual ImageNet statistics).
image_size = (512, 512)
segmentation_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the depth estimation pipeline (Depth-Anything). It is placed on the same
# device as the segmentation model; without an explicit device the pipeline
# would not follow the `device` chosen above.
depth_pipeline = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=device,
)
# ----------------------------
# Processing Functions
# ----------------------------
def segment_and_blur_background(input_image: Image.Image, blur_strength: int = 15, threshold: float = 0.5) -> Image.Image:
    """
    Keep the segmented subject sharp and Gaussian-blur everything else.

    The image is run through the module-level RMBG-2.0 segmentation model, the
    resulting per-pixel scores are cut at ``threshold`` to obtain a hard mask,
    and that mask selects between the original image and a blurred copy of it.

    Args:
        input_image: Image to process; it is converted to RGB first.
        blur_strength: Radius handed to ``ImageFilter.GaussianBlur`` for the
            background copy.
        threshold: Cut-off applied to the sigmoid of the model output; pixels
            scoring above it are kept from the original image.

    Returns:
        An RGB image with the same size as the input.
    """
    rgb = input_image.convert("RGB")

    # Only the tensor fed to the model is resized; `rgb` keeps its full size.
    batch = segmentation_transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        model_outputs = segmentation_model(batch)
        scores = model_outputs[-1].sigmoid().cpu()
    subject_score = scores[0].squeeze()

    # Hard 0/1 mask at the requested sensitivity, turned into an 8-bit image.
    hard_mask = (subject_score > threshold).float()
    to_pil = transforms.ToPILImage()
    mask = to_pil(hard_mask).convert("L")
    mask = mask.point(lambda value: 255 if value > 128 else 0)
    # Scale the mask back up to the original resolution (bilinear resampling
    # leaves intermediate values along the mask edges).
    mask = mask.resize(rgb.size, resample=Image.BILINEAR)

    background = rgb.filter(ImageFilter.GaussianBlur(blur_strength))
    # Where the mask is white the original pixel is used, elsewhere the blurred one.
    return Image.composite(rgb, background, mask)
def depth_based_lens_blur(input_image: Image.Image, max_blur: float = 2, num_bands: int = 40, invert_depth: bool = False) -> Image.Image:
    """
    Apply a blur whose radius varies with the estimated depth of each pixel.

    The depth map from the module-level Depth-Anything pipeline is normalized
    to [0, 1] and cut into ``num_bands`` equal-width bands. Each band is filled
    from a copy of the image blurred with radius ``(1 - band_midpoint) * max_blur``,
    so pixels with a higher normalized depth value receive less blur.

    Args:
        input_image: Image to process; it is converted to RGB first and used
            at its original size.
        max_blur: Upper bound of the Gaussian blur radius (approached by the
            band with the lowest normalized depth values).
        num_bands: Number of depth bands; more bands give a smoother
            transition at the cost of more blur passes.
        invert_depth: If True, flip the normalized depth map so the blur
            gradient is reversed.

    Returns:
        An RGB image with the same size as the input.
    """
    # Use the original image for depth estimation (no resizing)
    image_original = input_image.convert("RGB")

    results = depth_pipeline(image_original)
    depth_map_image = results['depth']
    # Image.composite needs the mask and both images to have the same size.
    # The pipeline is expected to return a map matching the input; resize
    # defensively if it ever does not.
    if depth_map_image.size != image_original.size:
        depth_map_image = depth_map_image.resize(image_original.size, resample=Image.BILINEAR)

    # Normalize depth to [0, 1]; the epsilon avoids division by zero on a flat map.
    depth_array = np.array(depth_map_image, dtype=np.float32)
    d_min, d_max = depth_array.min(), depth_array.max()
    depth_norm = (depth_array - d_min) / (d_max - d_min + 1e-8)
    if invert_depth:
        depth_norm = 1.0 - depth_norm

    orig_rgba = image_original.convert("RGBA")
    final_image = orig_rgba.copy()
    band_edges = np.linspace(0, 1, num_bands + 1)
    last_band = num_bands - 1
    for i in range(num_bands):
        band_min = band_edges[i]
        band_max = band_edges[i + 1]

        # Bands are half-open [min, max) except the last one, which is closed
        # so that pixels with a normalized depth of exactly 1.0 are not skipped.
        if i == last_band:
            in_band = (depth_norm >= band_min) & (depth_norm <= band_max)
        else:
            in_band = (depth_norm >= band_min) & (depth_norm < band_max)

        # Nothing to paste for an empty band, so skip the full-image blur.
        if not in_band.any():
            continue

        mid = (band_min + band_max) / 2.0
        blur_radius_band = (1 - mid) * max_blur
        blurred_version = orig_rgba.filter(ImageFilter.GaussianBlur(blur_radius_band))

        # A 2-D uint8 array becomes an 8-bit grayscale ("L") mask image.
        band_mask_pil = Image.fromarray(in_band.astype(np.uint8) * 255)
        final_image = Image.composite(blurred_version, final_image, band_mask_pil)
    return final_image.convert("RGB")
def process_image(input_image: "Image.Image | None", effect: str, mask_sensitivity: float, blur_strength: float) -> "Image.Image | None":
    """
    Apply the effect selected in the UI to ``input_image``.

    - "Gaussian Blur Background": uses segmentation with adjustable mask
      sensitivity and blur strength (the strength is truncated to an int).
    - "Depth-based Lens Blur": applies depth-based blur with ``blur_strength``
      as the maximum blur.

    Args:
        input_image: Uploaded image, or None when nothing has been uploaded.
        effect: Name of the effect to apply (one of the two labels above).
        mask_sensitivity: Segmentation threshold; only used by the Gaussian
            background blur.
        blur_strength: Blur radius / maximum blur, depending on the effect.

    Returns:
        The processed image; the unchanged input for an unrecognized effect;
        None when no image was supplied.
    """
    # Without an uploaded image there is nothing to process; returning None
    # avoids an AttributeError inside the effect functions.
    if input_image is None:
        return None

    if effect == "Gaussian Blur Background":
        return segment_and_blur_background(input_image, blur_strength=int(blur_strength), threshold=mask_sensitivity)
    if effect == "Depth-based Lens Blur":
        return depth_based_lens_blur(input_image, max_blur=blur_strength)
    # Unknown effect name: pass the image through untouched.
    return input_image
# ----------------------------
# Gradio Blocks Layout
# ----------------------------
# Introductory text rendered at the top of the page.
_INTRO_MARKDOWN = """
# Interactive Blur Effects Demo
Upload an image and choose an effect below.
For **Gaussian Blur Background**, adjust the mask sensitivity (controls segmentation threshold)
and blur strength (controls Gaussian blur radius).
For **Depth-based Lens Blur**, the blur strength slider sets the maximum blur intensity.
"""

# Effect labels offered in the radio group; process_image compares against these strings.
_EFFECT_CHOICES = ["Gaussian Blur Background", "Depth-based Lens Blur"]

with gr.Blocks(title="Interactive Blur Effects Demo") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: the input image and all controls.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            effect_choice = gr.Radio(
                choices=_EFFECT_CHOICES,
                label="Select Effect",
                value=_EFFECT_CHOICES[0],
            )
            mask_sensitivity_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.01,
                label="Mask Sensitivity (for segmentation)",
            )
            blur_strength_slider = gr.Slider(
                minimum=0,
                maximum=30,
                value=15,
                step=1,
                label="Blur Strength",
            )
            run_button = gr.Button("Apply Effect")

        # Right column: the processed result.
        with gr.Column():
            output_image = gr.Image(type="pil", label="Output Image")

    # The order of `inputs` matches the positional parameters of process_image.
    run_button.click(
        fn=process_image,
        inputs=[
            input_image,
            effect_choice,
            mask_sensitivity_slider,
            blur_strength_slider,
        ],
        outputs=output_image,
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()