Spaces:
Build error
Build error
| import gradio as gr | |
| import jax | |
| import numpy as np | |
| import jax.numpy as jnp | |
| from flax.jax_utils import replicate | |
| from flax.training.common_utils import shard | |
| from PIL import Image | |
| from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel | |
| import cv2 | |
# Load the ControlNet weights. The checkpoint is requested at a fixed commit
# revision and in bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID",
    revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd",
    dtype=jnp.bfloat16,
)

# Load the Stable Diffusion v1-5 pipeline (its "flax" revision) and attach the
# ControlNet module loaded above.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Seed forwarded to ``jax.random.PRNGKey``. Defaults to 0.

    Returns:
        The PRNG key produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key
def process_mask(image):
    """Convert a conditioning image into a 512x512 single-channel mask.

    Args:
        image: NumPy image array. Colour input is treated as RGB (or RGBA),
            which is the channel order the Gradio image input delivers;
            2-D and single-channel arrays are accepted as already grayscale.

    Returns:
        The grayscale image resized to 512x512.
    """
    if image.ndim == 2:
        # Already single-channel: nothing to convert.
        mask = image
    elif image.shape[2] == 1:
        # Drop the trailing channel axis so the result is a plain 2-D mask.
        mask = image[:, :, 0]
    elif image.shape[2] == 4:
        mask = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        # RGB, not BGR: using the BGR code here would swap the red and blue
        # luminance weights for any non-gray input.
        mask = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return cv2.resize(mask, (512, 512))
def infer(prompts, negative_prompts, image):
    """Generate images for a prompt pair conditioned on a depth-map image.

    Args:
        prompts: Text prompt describing the desired scene.
        negative_prompts: Text prompt describing what to steer away from.
        image: Conditioning image as delivered by the Gradio image input.

    Returns:
        A list of PIL images, one per generated sample (one sample per
        JAX device).

    Raises:
        gr.Error: If no conditioning image was supplied.
    """
    if image is None:
        # Fail with a readable UI message instead of an opaque OpenCV error.
        raise gr.Error("Please provide a depth map image before running.")

    # The ControlNet weights are loaded separately from the pipeline weights,
    # so they are attached to the shared parameter dict before replication.
    params["controlnet"] = controlnet_params

    # The PRNG key is split per device and the inputs are sharded across
    # devices, so the batch size has to match the device count. A fixed
    # batch of 1 only works on single-device hosts.
    num_devices = jax.device_count()
    num_samples = num_devices
    rng = jax.random.split(create_key(0), num_devices)

    mask = Image.fromarray(process_mask(image))

    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([mask] * num_samples)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the leading device/batch axes into a single sample axis while
    # keeping the last three (per-image) axes untouched.
    flat_output = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(flat_output))
# Conditioning images for the demo examples, index-aligned with the prompt
# lists below. The list must have exactly as many entries as there are
# prompts: zip() stops at the shortest input, so a surplus image would shift
# every later pairing and silently drop the final entry.
e_images = ['0.png'] * 4 + ['2.png'] * 4

e_prompts = [
    'a dog in the middle of the road, shadow on the ground,light direction north-east',
    'a dog in the middle of the road, shadow on the ground,light direction north-west',
    'a dog in the middle of the road, shadow on the ground,light direction south-west',
    'a dog in the middle of the road, shadow on the ground,light direction south-east',
    'a red rural house, shadow on the ground, light direction north',
    'a red rural house, shadow on the ground, light direction east',
    'a red rural house, shadow on the ground, light direction south',
    'a red rural house, shadow on the ground, light direction west',
]

# Every example shares the same negative prompt.
e_negative_prompts = [
    'monochromatic, unrealistic, bad looking, full of glitches',
] * len(e_prompts)

# Rows are ordered [prompt, negative prompt, image] to match the order of the
# Gradio input components the examples are bound to.
examples = [
    [prompt, negative_prompt, image]
    for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts)
]
# Heading rendered at the top of the demo page (Markdown).
title = ' # ControlLight: Light control through ControlNet and Depth Maps conditioning'

# Long-form Markdown description rendered below the controls.
info = """
# ControlLight: Light control through ControlNet and Depth Maps conditioning
We propose a ControlNet using depth maps conditioning that is capable of controlling the light direction in a scene while trying to maintain the scene integrity.
The model was trained on [VIDIT dataset](https://huggingface.co/datasets/Nahrawy/VIDIT-Depth-ControlNet) and [
A Dataset of Flash and Ambient Illumination Pairs from the Crowd](https://huggingface.co/datasets/Nahrawy/FAID-Depth-ControlNet) as a part of the [Jax Diffusers Event](https://huggingface.co/jax-diffusers-event).
Due to the limited available data the model is clearly overfit, but it serves as a proof of concept to what can be further achieved using enough data.
A large part of the training data is synthetic so we encourage further training using synthetically generated scenes, using Unreal engine for example.
The WandB training logs can be found [here](https://wandb.ai/hassanelnahrawy/controlnet-VIDIT-FAID), it's worth noting that the model was left to overfit for experimentation and it's advised to use the 8K steps weights or prior weights.
This project is a joint work between [ParityError](https://huggingface.co/ParityError) and [Nahrawy](https://huggingface.co/Nahrawy).
"""
# Assemble the demo page: heading, two prompt boxes, input/output images side
# by side, a run button, the project description, and cached examples.
with gr.Blocks() as demo:
    gr.Markdown(title)

    prompts = gr.Textbox(label='prompts')
    negative_prompts = gr.Textbox(label='negative_prompts')

    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")

    with gr.Row():
        btn = gr.Button("Run")

    with gr.Row():
        gr.Markdown(info)

    # The examples and the button feed the same components, in the order
    # infer() expects its arguments.
    infer_inputs = [prompts, negative_prompts, in_image]

    gr.Examples(
        examples=examples,
        inputs=infer_inputs,
        outputs=out_image,
        fn=infer,
        cache_examples=True,
    )
    btn.click(fn=infer, inputs=infer_inputs, outputs=out_image)

demo.launch()