# CutItOut / app.py — updated by K00B404 (commit 893f4a7, verified)
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from PIL import Image
from typing import Tuple, Dict, List
import cv2
from pathlib import Path
from briarmbg import MultiTargetBriaRMBG
from briacustom import MultiTargetBriaRMBG, ClothingType, GarmentFeatures
class ImageProcessor:
    """Wraps the MultiTargetBriaRMBG segmentation model for background
    removal and clothing extraction / modification.

    All public image outputs are PIL Images; masks are 8-bit grayscale.
    """

    def __init__(self, model_path: str = "briaai/RMBG-1.4"):
        """Load the pretrained model onto GPU when available, else CPU.

        Args:
            model_path: HuggingFace model id or local path passed to
                ``from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = MultiTargetBriaRMBG.from_pretrained(model_path)
        self.net.to(self.device)
        self.net.eval()  # inference only; disable dropout/batch-norm updates
        self.model_input_size = (1024, 1024)

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Convert an RGB image (numpy array or PIL Image) into a
        normalized ``1x3xHxW`` float tensor on the model's device."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert('RGB')
        image = image.resize(self.model_input_size, Image.LANCZOS)
        im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        # RMBG-1.4 preprocessing: shift [0, 1] pixels to [-0.5, 0.5].
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return im_tensor.to(self.device)

    def postprocess_mask(self, mask: torch.Tensor, target_size: Tuple[int, int]) -> Image.Image:
        """Resize a ``1x1xHxW`` mask tensor to *target_size* and return it
        as an 8-bit grayscale PIL Image.

        Args:
            mask: raw model output, any real-valued range.
            target_size: ``(width, height)`` in PIL convention.
        """
        # F.interpolate expects (height, width): flip PIL's (width, height).
        # (The original passed it unflipped, transposing non-square masks.)
        mask = F.interpolate(mask, size=target_size[::-1], mode='bilinear')
        mask = torch.squeeze(mask, 0)
        # Min-max normalize; guard against a constant mask (zero division).
        lo, hi = mask.min(), mask.max()
        if hi > lo:
            mask = (mask - lo) / (hi - lo)
        else:
            mask = torch.zeros_like(mask)
        mask_np = (mask * 255).detach().cpu().numpy().astype(np.uint8)
        return Image.fromarray(mask_np[0])

    def process_image(self,
                      image: np.ndarray,
                      mode: str = "background_removal",
                      clothing_options: Dict = None) -> Dict[str, Image.Image]:
        """Run the model and assemble the requested output images.

        Args:
            image: input image (numpy RGB array or PIL Image).
            mode: "background_removal", "clothing", or "all".
            clothing_options: optional dict with "color", "pattern",
                and/or "style_transfer" keys; when present (and mode
                includes clothing) a "modified_clothing" output is added.

        Returns:
            Dict mapping output names ("removed_background", "background",
            "clothing", "modified_clothing") to RGBA PIL Images; which
            keys are present depends on *mode* and *clothing_options*.
        """
        orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        orig_size = orig_image.size
        input_tensor = self.preprocess_image(image)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            results = self.net(input_tensor, mode=mode)
        # NOTE(review): assumes `results` is a dict with "foreground",
        # "background", "clothing" tensor values — confirm against
        # MultiTargetBriaRMBG.forward.
        outputs = {}
        if mode == "background_removal" or mode == "all":
            # Foreground mask -> subject pasted onto transparency.
            fg_mask = self.postprocess_mask(results["foreground"], orig_size)
            transparent = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            transparent.paste(orig_image, mask=fg_mask)
            outputs["removed_background"] = transparent
            if mode == "all":
                bg_mask = self.postprocess_mask(results["background"], orig_size)
                background = Image.new("RGBA", orig_size, (0, 0, 0, 0))
                background.paste(orig_image, mask=bg_mask)
                outputs["background"] = background
        if mode == "clothing" or mode == "all":
            clothing_mask = self.postprocess_mask(results["clothing"], orig_size)
            if clothing_options:
                outputs["modified_clothing"] = self.apply_clothing_modifications(
                    orig_image,
                    clothing_mask,
                    clothing_options
                )
            # Always include the unmodified clothing cut-out.
            clothing = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            clothing.paste(orig_image, mask=clothing_mask)
            outputs["clothing"] = clothing
        return outputs

    def apply_clothing_modifications(self,
                                     image: Image.Image,
                                     mask: Image.Image,
                                     options: Dict) -> Image.Image:
        """Apply the requested clothing edits in a fixed order:
        color, then pattern, then style transfer.

        NOTE(review): the three helpers below were referenced by the
        original code but never defined (AttributeError at runtime);
        they are implemented here as simple, deterministic pixel blends.
        """
        if "color" in options:
            image = self.change_clothing_color(image, mask, options["color"])
        if "pattern" in options:
            image = self.apply_pattern(image, mask, options["pattern"])
        if "style_transfer" in options:
            image = self.transfer_clothing_style(image, mask, options["style_transfer"])
        return image

    def change_clothing_color(self,
                              image: Image.Image,
                              mask: Image.Image,
                              color: str) -> Image.Image:
        """Tint the masked (clothing) region toward *color* (a PIL color
        spec, e.g. "#ff0000"), keeping half of the original texture."""
        base = image.convert("RGB")
        solid = Image.new("RGB", base.size, color)
        tinted = Image.blend(base, solid, 0.5)
        # Composite through the mask so only the clothing area changes.
        return Image.composite(tinted, base, mask)

    def apply_pattern(self,
                      image: Image.Image,
                      mask: Image.Image,
                      pattern: str) -> Image.Image:
        """Overlay a procedural pattern ("Stripes" or "Dots") on the
        masked region; unknown patterns (e.g. "Floral") leave the image
        unchanged."""
        base = image.convert("RGB")
        overlay = self._make_pattern(base.size, pattern)
        if overlay is None:
            return base
        patterned = Image.blend(base, overlay, 0.5)
        return Image.composite(patterned, base, mask)

    @staticmethod
    def _make_pattern(size: Tuple[int, int], pattern: str):
        """Return a pattern image of *size* ((width, height)) for the
        given pattern name, or None when the pattern is unsupported."""
        w, h = size
        canvas = np.full((h, w, 3), 255, dtype=np.uint8)
        name = (pattern or "").lower()
        if name == "stripes":
            # 10px dark stripes every 20px.
            for y in range(0, h, 20):
                canvas[y:y + 10, :] = (40, 40, 40)
        elif name == "dots":
            # Dark dots of radius 6 on a 30px grid.
            yy, xx = np.mgrid[0:h, 0:w]
            dots = ((yy % 30 - 15) ** 2 + (xx % 30 - 15) ** 2) < 36
            canvas[dots] = (40, 40, 40)
        else:
            return None  # "Floral" and anything else: not implemented
        return Image.fromarray(canvas)

    def transfer_clothing_style(self,
                                image: Image.Image,
                                mask: Image.Image,
                                style) -> Image.Image:
        """Blend a style reference image into the masked region
        (a naive 50% pixel blend — not a neural style transfer)."""
        base = image.convert("RGB")
        if isinstance(style, np.ndarray):
            style = Image.fromarray(style)
        style = style.convert("RGB").resize(base.size, Image.LANCZOS)
        styled = Image.blend(base, style, 0.5)
        return Image.composite(styled, base, mask)
def create_ui() -> gr.Blocks:
    """Build the Gradio UI: one tab for background removal, one for
    clothing manipulation, plus an optional examples gallery.

    Returns:
        The assembled (unlaunched) gr.Blocks app.
    """
    processor = ImageProcessor()

    with gr.Blocks() as app:
        gr.Markdown("# Advanced Background and Clothing Removal")

        with gr.Tab("Background Removal"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="numpy")
                    remove_bg_btn = gr.Button("Remove Background")
                with gr.Column():
                    output_image = gr.Image(label="Result", type="pil")
            remove_bg_btn.click(
                fn=lambda img: processor.process_image(img)["removed_background"],
                inputs=[input_image],
                outputs=[output_image]
            )

        with gr.Tab("Clothing Manipulation"):
            with gr.Row():
                with gr.Column():
                    cloth_input = gr.Image(label="Input Image", type="numpy")
                    with gr.Accordion("Clothing Options"):
                        color_picker = gr.ColorPicker(label="New Color")
                        pattern_choice = gr.Dropdown(
                            choices=["Stripes", "Dots", "Floral"],
                            label="Pattern"
                        )
                        style_image = gr.Image(label="Style Reference", type="numpy")
                    process_clothing_btn = gr.Button("Process Clothing")
                with gr.Column():
                    cloth_output = gr.Image(label="Modified Clothing", type="pil")

            def process_clothing(image, color, pattern, style):
                """Collect the selected options and run clothing processing."""
                options = {}
                if color:
                    options["color"] = color
                if pattern:
                    options["pattern"] = pattern
                if style is not None:
                    options["style_transfer"] = style
                results = processor.process_image(
                    image,
                    mode="clothing",
                    clothing_options=options
                )
                # "modified_clothing" only exists when at least one option
                # was selected; fall back to the plain clothing cut-out so
                # an empty selection doesn't raise KeyError.
                return results.get("modified_clothing", results["clothing"])

            process_clothing_btn.click(
                fn=process_clothing,
                inputs=[cloth_input, color_picker, pattern_choice, style_image],
                outputs=[cloth_output]
            )

        # Examples section: only wire up files that actually exist.
        examples_dir = Path("./examples")
        examples = [
            [str(examples_dir / f"example_{i}.jpg")]
            for i in range(1, 4)
            if (examples_dir / f"example_{i}.jpg").exists()
        ]
        if examples:
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=lambda img: processor.process_image(img)["removed_background"],
                cache_examples=True
            )

    return app
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at Gradio's usual port.
    demo = create_ui()
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
    }
    demo.launch(**launch_options)