# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from diffusers.utils import load_image
import jax.numpy as jnp
import jax
import cv2
from PIL import Image
import numpy as np
import gradio as gr
import functools
import os

if gr.__version__ != "3.28.3":
    # NOTE: gradio is already imported at this point, so a reinstalled
    # version can only take effect after the process restarts.
    os.system("pip uninstall -y gradio")
    os.system("pip install gradio==3.28.3")

description = """
Our project is to use diffusion models to change the texture of our robotic arm simulation.
To do so, we first get our simulated images. Then we process these images to get Canny edge maps.
Finally, we can get brand new images by using ControlNet.
Therefore, we are able to change our simulation texture while keeping the image composition.

Our objective for the sprint is to perform data augmentation using ControlNet. We are looking for a model that can augment an image quickly.
To do so, we trained many ControlNets from scratch with different datasets:
* [Coyo-700M](https://github.com/kakaobrain/coyo-dataset)
* [Bridge](https://sites.google.com/view/bridgedata)

One way to accelerate the inference of diffusion models is simply to generate small images, so we decided to work with low-resolution images.
After downloading the datasets, we processed them by resizing images to a 128 resolution.
The smallest side of the image (width or height) is resized to 128 and the other side is resized keeping the initial ratio.
After that, we retrieve the Canny edge map of the images. We performed this preprocessing for every dataset we used during the sprint.
We trained four different ControlNets. For each one of them, we processed the datasets differently.
You can find the description of the processing in the readme file attached to the model repo: [Our ControlNet repo](https://huggingface.co/Baptlem/baptlem-controlnet)

For now, we benchmarked our model on a node of 4 Titan RTX 24GB GPUs. We were able to generate a batch of 4 images in an average time of 1.3 seconds!
We also have access to nodes composed of 8 A100 80GB GPUs. The benchmark on one of these nodes will come soon.
"""

traj_description = """
We generated a trajectory of our simulated environment. We will then use it with our different models.
We made these videos on our Titan RTX node.
The prompt we use for every video is "A robotic arm with a gripper and a small cube on a table, super realistic, industrial background"
"""

# ControlNet variants available in the model repo; shared by the dropdown
# and the trajectory video rows so the two stay in sync.
MODEL_CHOICES = ["coyo-500k", "bridge-2M", "coyo1M-bridge2M", "coyo28-bridge4"]


def create_key(seed=0):
    """Return a JAX PRNG key built from the integer ``seed``."""
    return jax.random.PRNGKey(seed)


def load_controlnet(controlnet_version, repo_id="Baptlem/baptlem-controlnet"):
    """Load one ControlNet variant (a subfolder of ``repo_id``) in float32.

    Returns the ``(model, params)`` pair from ``FlaxControlNetModel.from_pretrained``.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        repo_id,
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params


def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build the Flax ControlNet pipeline for ``controlnet_version``.

    The base Stable Diffusion weights come from ``sb_path`` (bfloat16, "flax"
    revision) and the scheduler is replaced by DPM-Solver multistep.
    Returns ``(pipe, params)`` with the ControlNet and scheduler params merged in.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path, subfolder="scheduler"
    )

    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path, controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
    )
    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params


controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"

# Constants
low_threshold = 100
high_threshold = 200

print(os.path.abspath('.'))
print(os.listdir("."))
print("Gradio version:", gr.__version__)
print("Loaded models...")


@functools.lru_cache(maxsize=1)
def _load_replicated_pipe(model):
    """Load the pipeline for ``model`` and replicate its params across devices.

    Cached so consecutive requests for the same model reuse the loaded weights
    instead of reloading them from pretrained every call (and presumably also
    avoid re-compiling the jitted pipeline for a fresh ``pipe`` object).
    ``maxsize=1`` keeps only one model resident to bound memory use.
    """
    pipe, params = load_sb_pipe(model)
    return pipe, replicate(params)


def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
):
    """Generate ``num_samples`` images conditioned on ``image`` and ``prompt``.

    Args:
        image: input image (array-like); converted to a numpy array.
        prompt: text prompt.
        is_canny: if True, ``image`` is already a Canny edge map and is only
            resized; otherwise a Canny edge map is computed from it.
        num_samples: number of images to return.
        resolution: length of the shorter image side after resizing.
        num_inference_steps: diffusion steps.
        guidance_scale: classifier-free guidance scale.
        model: ControlNet variant (subfolder of the model repo).
        seed: PRNG seed.
        negative_prompt: negative text prompt.

    Returns:
        A list of ``num_samples`` PIL images.
    """
    # Gradio sliders may deliver numbers as floats; list repetition, the step
    # count and PRNG seeding all need ints.
    num_samples = max(1, int(num_samples))
    num_inference_steps = int(num_inference_steps)
    resolution = int(resolution)
    seed = int(seed)

    print("Loading pipe")
    pipe, p_params = _load_replicated_pipe(model)

    if not isinstance(image, np.ndarray):
        image = np.array(image)

    processed_image = resize_image(image, resolution)  # -> PIL
    if not is_canny:
        _, processed_image = preprocess_canny(processed_image, resolution)

    # shard() reshapes the leading batch axis to (n_devices, -1), so the batch
    # must be a multiple of the device count: pad up, then drop the extras.
    num_devices = jax.local_device_count()
    batch_size = -(-num_samples // num_devices) * num_devices

    rng = create_key(seed)
    rng = jax.random.split(rng, num_devices)

    prompt_ids = pipe.prepare_text_inputs([prompt] * batch_size)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * batch_size)
    processed_image = pipe.prepare_image_inputs([processed_image] * batch_size)

    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Flatten the per-device axis back into a single batch of images.
    images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images[:num_samples])


def resize_image(image, resolution):
    """Resize so the shorter side equals ``resolution``, keeping the aspect ratio.

    Uses nearest-neighbour interpolation and returns a PIL image.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    h, w = image.shape[:2]
    ratio = w / h
    if ratio > 1:
        size = (int(resolution * ratio), resolution)  # cv2 expects (width, height)
    elif ratio < 1:
        size = (resolution, int(resolution / ratio))
    else:
        size = (resolution, resolution)
    resized_image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(resized_image)


def preprocess_canny(image, resolution=128):
    """Compute a 3-channel Canny edge map of ``image``.

    Uses the module-level ``low_threshold``/``high_threshold``. ``resolution``
    is unused and kept only for call compatibility.

    Returns ``(input_image_as_PIL, edge_map_as_PIL)``.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    edges = cv2.Canny(image, low_threshold, high_threshold)
    # Replicate the single edge channel into RGB for the ControlNet input.
    edges = np.concatenate([edges[:, :, None]] * 3, axis=2)

    resized_image = Image.fromarray(image)
    processed_image = Image.fromarray(edges)
    return resized_image, processed_image


def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI; ``process`` is called on Run/submit."""
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(label='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(
                        label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    model = gr.Dropdown(choices=MODEL_CHOICES,
                                        value="coyo-500k",
                                        label="Model used for inference",
                                        info="Find all models at https://huggingface.co/Baptlem/baptlem-controlnet")
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')
        with gr.Row():
            gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                gr.Markdown(traj_description)
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory.avi",
                         format="avi",
                         interactive=False)
        # One row per model: label on the left, processed trajectory on the right.
        for model_name in MODEL_CHOICES:
            with gr.Row():
                with gr.Column():
                    gr.Markdown(f"Trajectory processed with {model_name} model :")
                with gr.Column():
                    gr.Video(f"./trajectory_hf/trajectory_{model_name}.avi",
                             format="avi",
                             interactive=False)

        # Order must match pipe_inference's positional parameters.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')

    return demo


if __name__ == '__main__':
    demo = create_demo(pipe_inference)
    demo.queue().launch()