"""Gradio demo: segment hair with MediaPipe and repaint it with ControlNet inpainting."""

import spaces
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetInpaintPipeline, ControlNetModel
from PIL import Image
from compel import Compel
from diffusers import EulerDiscreteScheduler

# Short aliases for the MediaPipe image container and its pixel-format enum.
_Image = image_module.Image
_ImageFormat = image_frame.ImageFormat

# Device configuration: everything is set to run on CPU.
device = torch.device("cpu")

# Constants for the mask colors (RGBA).
BG_COLOR = (0, 0, 0, 255)  # black with full opacity
MASK_COLOR = (255, 255, 255, 255)  # white with full opacity

# Longest image side (in pixels) accepted before the input is downscaled.
MAX_IMAGE_SIZE = 1536

# Category-mask values above this threshold are treated as hair.
# NOTE(review): presumably the mask holds integer labels with 0 = background,
# so this is effectively "non-zero" -- confirm against the .tflite model.
MASK_THRESHOLD = 0.2

# Fixed seed so repeated requests with the same inputs are reproducible.
SEED = 42

# Create the options that will be used for ImageSegmenter.
base_options = python.BaseOptions(model_asset_path='emirhan.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options,
                                       output_category_mask=True)

# Initialize the ControlNet inpainting pipeline (float32 for CPU).
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/control_v11p_sd15_inpaint',
    torch_dtype=torch.float32,
).to(device)

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    safety_checker=None,
    controlnet=controlnet,
    torch_dtype=torch.float32,
).to(device)

# Set the K_EULER scheduler.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)


def segment_hair(image):
    """Segment the image and return the result as an RGBA mask.

    Args:
        image: ``uint8`` RGB array of shape ``(H, W, 3)``.

    Returns:
        ``uint8`` array of shape ``(H, W, 4)`` holding ``MASK_COLOR`` where
        the segmenter's category mask exceeds ``MASK_THRESHOLD`` and
        ``BG_COLOR`` everywhere else.
    """
    # The caller hands this same array straight to PIL as RGB, so the
    # conversion must start from RGB; converting from BGR would feed the
    # segmenter an image with red and blue swapped.
    rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba_image[:, :, 3] = 0  # Set alpha channel to empty

    # Create MP Image object from numpy array.
    mp_image = _Image(image_format=_ImageFormat.SRGBA, data=rgba_image)

    # Create the image segmenter and evaluate the mask while it is still open,
    # so no view into segmenter-owned memory outlives the context manager.
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        segmentation_result = segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask.numpy_view()
        # Replicate the single-channel mask across the four RGBA channels.
        condition = np.stack((category_mask,) * 4, axis=-1) > MASK_THRESHOLD

    foreground = np.array(MASK_COLOR, dtype=np.uint8)
    background = np.array(BG_COLOR, dtype=np.uint8)
    return np.where(condition, foreground, background)


def resize_image(image, max_size=MAX_IMAGE_SIZE):
    """Downscale ``image`` so its longest side is at most ``max_size``.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (the same object, not a copy).

    Args:
        image: Array whose first two dimensions are height and width.
        max_size: Maximum allowed length of the longest side, in pixels.

    Returns:
        The resized image, or ``image`` itself when no resize is needed.
    """
    h, w = image.shape[:2]
    longest = max(h, w)
    if longest > max_size:
        scale = max_size / longest
        new_size = (int(w * scale), int(h * scale))  # cv2 expects (width, height)
        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
    return image


def inpaint_hair(image, prompt):
    """Repaint the segmented area of ``image`` according to ``prompt``.

    Args:
        image: ``uint8`` RGB array of shape ``(H, W, 3)``, or ``None`` when
            the user submitted the form without an image.
        prompt: Text describing the desired inpainting result.

    Returns:
        The inpainted image as a numpy array.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image before running the inpainting.")

    # resize_image only touches inputs larger than MAX_IMAGE_SIZE.
    image = resize_image(image)

    # Segment hair to get the mask.
    mask = segment_hair(image)

    # Convert to PIL images for the inpainting pipeline. A 2-D uint8 array
    # becomes a mode "L" image, so no further conversion is needed.
    image_pil = Image.fromarray(image)
    mask_gray = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
    mask_pil = Image.fromarray(mask_gray)

    # Prepare the inpainting condition: the image scaled to [0, 1] with every
    # masked pixel set to -1.0, laid out as (1, C, H, W).
    image_np = np.array(image_pil).astype(np.float32) / 255.0
    mask_np = mask_gray.astype(np.float32) / 255.0
    image_np[mask_np > 0.5] = -1.0  # Set as masked pixel
    inpaint_condition = torch.from_numpy(
        np.expand_dims(image_np, 0).transpose(0, 3, 1, 2)
    ).to(device)

    # Use a dedicated generator so the global torch RNG state is not reseeded
    # on every request.
    generator = torch.Generator(device=device).manual_seed(SEED)
    negative_prompt = "lowres, bad quality, poor quality"

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image_pil,
        mask_image=mask_pil,
        control_image=inpaint_condition,
        num_inference_steps=25,
        guidance_scale=7.5,
        generator=generator,
    ).images[0]

    return np.array(output)


# Gradio interface
iface = gr.Interface(
    fn=inpaint_hair,
    inputs=[
        gr.Image(type="numpy"),
        gr.Textbox(label="Prompt", placeholder="Describe the desired inpainting result..."),
    ],
    outputs=gr.Image(type="numpy"),
    title="Hair Inpainting with ControlNet",
    description="Upload an image, and provide a prompt to inpaint the hair area using ControlNet.",
    examples=[["example.jpeg", "dreadlocks"], ["example2.jpg", "pink hair"]],
)

if __name__ == "__main__":
    iface.launch()