"""Gradio demo: Stable Diffusion 2.1 (base) steered by a MediaPipe face-landmark ControlNet.

The user either uploads a photo (landmarks are annotated server-side) or uses a
browser canvas ("webcam" mode) whose image is sent as base64 and used directly as
the control image.
"""
import random
import gradio as gr
import torch
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import base64
from io import BytesIO

from mediapipe_face_common import generate_annotation
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
)

# Download the ControlNet and Stable Diffusion 2.1 (base) weights from the Hugging Face Hub.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only practical on GPU; CPU inference needs float32.
dtype = torch.float16 if device.type == "cuda" else torch.float32

controlnet = ControlNetModel.from_pretrained(
    "CrucibleAI/ControlNetMediaPipeFace", torch_dtype=dtype, variant="fp16"
)
model = StableDiffusionControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=dtype
)
if device.type == "cuda":
    # Model CPU offload places each sub-model on the GPU only while it runs and
    # manages device placement itself, so the pipeline is not moved with .to() first.
    model.enable_model_cpu_offload()
else:
    # enable_model_cpu_offload() requires a CUDA device; on CPU-only hosts keep
    # the whole pipeline on the CPU instead.
    model = model.to(device)

# NOTE(review): toggle() injects this HTML when "webcam" is selected, and get_js_image
# looks for an element with id "canvas-root". With an empty string no such element is
# rendered, so webcam mode can never produce an image - confirm the intended markup.
canvas_html = ""

# Fetches the face-canvas web component and registers it as an ES module on page load.
load_js = """
async () => {
  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/face-canvas.js"
  fetch(url)
    .then(res => res.text())
    .then(text => {
      const script = document.createElement('script');
      script.type = "module"
      script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
      document.head.appendChild(script);
    });
}
"""

# Runs in the browser before `process`: replaces the last input (live_conditioning)
# with the canvas element's `_data`, or null when the canvas is not on the page.
get_js_image = """
async (input_image, prompt, a_prompt, n_prompt, max_faces, min_confidence, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta, image_file_live_opt, live_conditioning) => {
  const canvasEl = document.getElementById("canvas-root");
  const imageData = canvasEl ? canvasEl._data : null;
  return [input_image, prompt, a_prompt, n_prompt, max_faces, min_confidence, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta, image_file_live_opt, imageData];
}
"""


def pad_image(input_image):
    """Pad an image to a square whose side is a multiple of 64.

    Width and height are first rounded up to the next multiple of 64 (minimum 128)
    by repeating the right/bottom edge pixels. If the result is not square it is
    then centred on a black square canvas.

    Args:
        input_image: PIL image. Callers in this file convert it to RGB first.

    Returns:
        A square PIL image (the edge-padded image itself when it is already square).
    """
    width, height = input_image.size
    target_w, target_h = np.maximum(2, np.ceil(np.array([width, height]) / 64).astype(int)) * 64
    pixels = np.array(input_image)
    # Pad only the two spatial axes; a trailing channel axis (if any) is left alone.
    pad_widths = ((0, target_h - height), (0, target_w - width)) + ((0, 0),) * (pixels.ndim - 2)
    im_padded = Image.fromarray(np.pad(pixels, pad_widths, mode='edge'))

    w, h = im_padded.size
    if w == h:
        return im_padded
    side = max(w, h)
    square = Image.new(im_padded.mode, (side, side), 0)  # 0 -> black in every mode
    square.paste(im_padded, ((side - w) // 2, (side - h) // 2))
    return square


def _conditioning_image(input_image, image_file_live_opt, live_conditioning, max_faces, min_confidence):
    """Build the 512x512 control image for the selected input source."""
    if image_file_live_opt == 'file':
        # Resize before annotation so that we can keep our line-widths consistent with the training data.
        image = pad_image(input_image.convert('RGB')).resize((512, 512))
        annotation = generate_annotation(np.array(image), int(max_faces), float(min_confidence))
        return Image.fromarray(annotation)
    # Webcam: the image sent by the browser canvas is used directly as the control
    # image (no server-side annotation). Accept a data URL ("data:...;base64,<payload>")
    # or a bare base64 payload.
    _, _, payload = live_conditioning['image'].rpartition(',')
    image_data = base64.b64decode(payload)
    return Image.open(BytesIO(image_data)).convert('RGB').resize((512, 512))


def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces: int, min_confidence: float,
            num_samples, ddim_steps, guess_mode, strength: float, scale, seed: int, eta,
            image_file_live_opt="file", live_conditioning=None):
    """Generate images conditioned on a facial-landmark control image.

    Args:
        input_image: Uploaded PIL image (used when image_file_live_opt == "file").
        prompt, a_prompt: Main and "added" prompt, joined with a space.
        n_prompt: Negative prompt.
        max_faces, min_confidence: Forwarded to generate_annotation (file mode only).
        num_samples: Images generated per prompt.
        ddim_steps: Number of inference steps.
        guess_mode: Accepted from the UI; currently unused (see note below).
        strength: ControlNet conditioning scale.
        scale: Classifier-free guidance scale.
        seed: RNG seed; -1 picks a random seed.
        eta: Forwarded to the pipeline's `eta` argument.
        image_file_live_opt: "file" or "webcam".
        live_conditioning: Dict from the browser canvas with an "image" key holding
            base64 data, or None when no canvas data was available.

    Returns:
        A list of PIL images: the control image followed by the generated samples.

    Raises:
        gr.Error: if no image was supplied for the selected source, the source is
            unknown, or annotation/generation fails (the cause is chained).
    """
    # The JS hook sends null when the canvas element is missing; treat it as empty.
    live_conditioning = live_conditioning or {}
    if image_file_live_opt == 'file':
        has_image = input_image is not None
    elif image_file_live_opt == 'webcam':
        has_image = 'image' in live_conditioning
    else:
        raise gr.Error(f"Unknown image source: {image_file_live_opt!r}")
    if not has_image:
        raise gr.Error("Please provide an image")

    try:
        visualization = _conditioning_image(
            input_image, image_file_live_opt, live_conditioning, max_faces, min_confidence)

        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, 2147483647)
        generator = torch.Generator(device).manual_seed(seed)

        # NOTE(review): guess_mode comes from the UI checkbox but is never forwarded to
        # the pipeline, so the checkbox has no effect. Newer diffusers releases accept a
        # `guess_mode` argument here - confirm the installed version before wiring it up.
        output = model(
            prompt=prompt + ' ' + a_prompt,
            negative_prompt=n_prompt,
            image=visualization,
            generator=generator,
            num_images_per_prompt=int(num_samples),
            num_inference_steps=int(ddim_steps),
            controlnet_conditioning_scale=float(strength),
            guidance_scale=scale,
            eta=eta,
        )
    except Exception as e:
        raise gr.Error(str(e)) from e
    return [visualization] + output.images


def toggle(choice):
    """Switch between file upload and webcam canvas, clearing both widgets' values."""
    if choice == "webcam":
        return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
    return gr.update(visible=True, value=None), gr.update(visible=False, value=None)


# Prompts and advanced settings shared by every bundled example.
_EXAMPLE_A_PROMPT = "best quality, extremely detailed"
_EXAMPLE_N_PROMPT = (
    "cartoon, disfigured, bad art, deformed, poorly drawn, extra limbs, weird colors, "
    "blurry, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, "
    "cropped, worst quality, low quality"
)
# max_faces, min_confidence, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta
_EXAMPLE_SETTINGS = [10, 0.4, 3, 20, False, 1.0, 9.0, -1, 0.0]


def _example(path, prompt):
    """Build one example row matching the order of the `ips` inputs."""
    return [path, prompt, _EXAMPLE_A_PROMPT, _EXAMPLE_N_PROMPT, *_EXAMPLE_SETTINGS]


block = gr.Blocks().queue()
with block:
    # Hidden JSON component; get_js_image replaces its value with the canvas data on Run.
    live_conditioning = gr.JSON(value={}, visible=False)
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with a Facial Pose")
    with gr.Row():
        with gr.Column():
            image_file_live_opt = gr.Radio(
                ["file", "webcam"], value="file",
                label="How would you like to upload your image?")
            input_image = gr.Image(source="upload", visible=True, type="pil")
            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)

            image_file_live_opt.change(fn=toggle,
                                       inputs=[image_file_live_opt],
                                       outputs=[input_image, canvas],
                                       queue=False)

            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(
                    label="Images", minimum=1, maximum=4, value=1, step=1)
                max_faces = gr.Slider(
                    label="Max Faces", minimum=1, maximum=10, value=5, step=1)
                min_confidence = gr.Slider(
                    label="Min Confidence", minimum=0.01, maximum=1.0, value=0.5, step=0.01)
                strength = gr.Slider(
                    label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                ddim_steps = gr.Slider(
                    label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(
                    label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(
                    label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(
                    label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, '
                          'fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(
                label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')

    ips = [input_image, prompt, a_prompt, n_prompt, max_faces, min_confidence,
           num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process,
                     inputs=ips + [image_file_live_opt, live_conditioning],
                     outputs=[result_gallery],
                     _js=get_js_image)

    # load js for live conditioning
    block.load(None, None, None, _js=load_js)

    gr.Examples(fn=process,
                examples=[
                    _example("./examples/two2.jpeg",
                             "Highly detailed photograph of two clowns"),
                    _example("./examples/two.jpeg",
                             "a photo of two silly men"),
                    _example("./examples/pedro-512.jpg",
                             "Highly detailed photograph of young woman smiling, with palm trees in the background"),
                    _example("./examples/image1.jpg",
                             "Highly detailed photograph of a scary clown"),
                    _example("./examples/image0.jpg",
                             "Highly detailed photograph of Madonna"),
                ],
                inputs=ips,
                outputs=[result_gallery],
                cache_examples=True)

block.launch(server_name='0.0.0.0')