import base64
from io import BytesIO

import gradio as gr
import torch
from controlnet_aux import OpenposeDetector
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import UniPCMultistepScheduler
from PIL import Image

# live conditioning
# NOTE(review): this HTML is what the "webcam" mode displays, and get_js_image
# looks up an element with id "canvas-root" in the page. With an empty string
# no such element is created, so the live payload will always be null.
# Presumably this should contain the canvas markup that pose-gradio.js
# renders into (with id "canvas-root") — confirm.
canvas_html = ""

# Runs once on page load: fetches the pose-canvas script and injects it as an
# ES module via a blob URL.
load_js = """
async () => {
  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/pose-gradio.js"
  fetch(url)
    .then(res => res.text())
    .then(text => {
      const script = document.createElement('script');
      script.type = "module"
      script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
      document.head.appendChild(script);
    });
}
"""

# Runs in the browser before generate_images: passes the first three inputs
# through unchanged and replaces the fourth (live_conditioning) with the
# canvas element's `_data`, or null when the canvas element is absent.
get_js_image = """
async (image_in_img, prompt, image_file_live_opt, live_conditioning) => {
  const canvasEl = document.getElementById("canvas-root");
  const data = canvasEl? canvasEl._data : null;
  return [image_in_img, prompt, image_file_live_opt, data]
}
"""

# Constants
# NOTE(review): not referenced by the pose pipeline in this file; they look
# like Canny edge thresholds left over from another ControlNet demo.
low_threshold = 100
high_threshold = 200

# Models
pose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

# GPU variant (kept for reference): load both models with
# torch_dtype=torch.float16; that variant also passed safety_checker=None.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.get_default_dtype()
)
# `safety_checker` is a pipeline component (a model instance, or None to
# disable it), not an on/off flag. Passing `True` stored a bool in place of the
# checker. Omitting the argument loads the checker shipped with the base repo.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.get_default_dtype(),
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Model CPU offload (the disabled call that follows) loads the individual model
# components onto the GPU on demand, so pipe.to("cuda") is not needed when it
# is enabled.
#pipe.enable_model_cpu_offload()

# xformers
#pipe.enable_xformers_memory_efficient_attention()

# Generator seed: one module-level generator shared by every request, so the
# sequence of outputs is reproducible per process but varies between calls.
generator = torch.manual_seed(0)


def get_pose(image):
    """Run the OpenPose detector on ``image`` and return its result."""
    return pose_model(image)


def _pose_from_live(live_conditioning):
    """Decode ``live_conditioning['image']`` into a 512x512 RGB PIL image."""
    encoded = live_conditioning['image']
    # Presumably a data URL ("data:image/png;base64,<payload>"); taking the
    # text after the first comma also tolerates a bare base64 payload.
    payload = encoded.split(',', 1)[-1]
    image_data = base64.b64decode(payload)
    return Image.open(BytesIO(image_data)).convert('RGB').resize((512, 512))


def generate_images(image, prompt, image_file_live_opt='file', live_conditioning=None):
    """Generate three pose-conditioned images for ``prompt``.

    Args:
        image: Uploaded PIL image, used when ``image_file_live_opt == 'file'``.
        prompt: Text prompt passed to the ControlNet pipeline.
        image_file_live_opt: Input source, ``'file'`` or ``'webcam'``.
        live_conditioning: Canvas payload substituted by ``get_js_image``.
            This is None when the canvas element is absent. Otherwise it is
            presumably a dict whose ``'image'`` entry is a base64 data URL.

    Returns:
        A list with the conditioning pose image first, then the generated images.

    Raises:
        gr.Error: if the selected input is missing or unknown, or if pose
            extraction or generation fails.
    """
    live_image = (
        live_conditioning.get('image') if isinstance(live_conditioning, dict) else None
    )
    # Validate against the *selected* source. The old combined check crashed on
    # a None payload and let a missing file slip through in file mode.
    if image_file_live_opt == 'file':
        if image is None:
            raise gr.Error("Please provide an image")
    elif image_file_live_opt == 'webcam':
        if not live_image:
            raise gr.Error("Please provide an image")
    else:
        raise gr.Error(f"Unknown input source: {image_file_live_opt!r}")

    try:
        if image_file_live_opt == 'file':
            pose = get_pose(image)
        else:
            pose = _pose_from_live(live_conditioning)
        output = pipe(
            prompt,
            pose,
            generator=generator,
            num_images_per_prompt=3,
            num_inference_steps=20,
        )
    except Exception as e:
        # UI boundary: surface any failure to the user as a Gradio error.
        raise gr.Error(str(e)) from e

    return [pose, *output.images]


def toggle(choice):
    """Show the upload widget or the live canvas depending on ``choice``.

    Returns a pair of updates for ``(image_in_img, canvas)``.
    """
    if choice == "webcam":
        return gr.update(visible=False, value=None), gr.update(visible=True, value=canvas_html)
    # "file", and any unexpected value, fall back to the upload widget, so the
    # callback always returns two updates.
    return gr.update(visible=True, value=None), gr.update(visible=False, value=None)


with gr.Blocks() as blocks:
    gr.Markdown("""
    ## Generate controlled outputs with ControlNet and Stable Diffusion
    This Space uses pose estimated lines as the additional conditioning
    [Check out our blog to see how this was done (and train your own controlnet)](https://huggingface.co/blog/train-your-controlnet)
    """)
    with gr.Row():
        # Hidden placeholder input. get_js_image swaps in the canvas data
        # before generate_images runs.
        live_conditioning = gr.JSON(value={}, visible=False)
        with gr.Column():
            image_file_live_opt = gr.Radio(["file", "webcam"], value="file",
                                           label="How would you like to upload your image?")
            image_in_img = gr.Image(source="upload", visible=True, type="pil")
            canvas = gr.HTML(None, elem_id="canvas_html", visible=False)

            image_file_live_opt.change(fn=toggle,
                                       inputs=[image_file_live_opt],
                                       outputs=[image_in_img, canvas],
                                       queue=False)
            prompt = gr.Textbox(
                label="Enter your prompt",
                max_lines=1,
                placeholder="best quality, extremely detailed",
            )
            run_button = gr.Button("Generate")
        with gr.Column():
            gallery = gr.Gallery().style(grid=[2], height="auto")
    run_button.click(fn=generate_images,
                     inputs=[image_in_img, prompt, image_file_live_opt, live_conditioning],
                     outputs=[gallery],
                     _js=get_js_image)
    blocks.load(None, None, None, _js=load_js)

blocks.launch(debug=True)