import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2
# Load the ControlNet and the Stable Diffusion v1-5 pipeline (Flax weights, bfloat16).
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID", dtype=jnp.bfloat16, revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd"
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
)
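# The ControlNet checkpoint conditions Stable Diffusion v1-5 on depth maps (it was
# trained on the VIDIT and FAID datasets described in the info text below); bfloat16
# weights and the pinned revisions keep memory usage low and loading reproducible.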
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
def process_mask(image):
    # Convert the conditioning image to a single-channel 512x512 mask.
    mask = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    mask = cv2.resize(mask, (512, 512))
    return mask
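# infer() follows the usual Flax data-parallel pattern: the pipeline parameters are
# replicated to every device and the inputs are sharded along a leading device axis,
# then jit=True lets the pipeline pmap the sampling loop. On multi-device hardware
# num_samples therefore needs to be a multiple of jax.device_count().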
def infer(prompts, negative_prompts, image):
    params["controlnet"] = controlnet_params
    num_samples = 1  # jax.device_count()
    rng = create_key(0)
    rng = jax.random.split(rng, jax.device_count())  # one PRNG key per device

    # Turn the uploaded conditioning image into a PIL depth mask.
    im = process_mask(image)
    mask = Image.fromarray(im)

    # Tokenize the prompts and preprocess the conditioning image for the pipeline.
    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([mask] * num_samples)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)
    print(processed_image[0].shape)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the device axis back into the batch axis and convert to PIL images.
    output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
    return output_images
e_images = ['0.png',
            '0.png',
            '0.png',
            '0.png',
            '2.png',
            '2.png',
            '2.png',
            '2.png']
e_prompts = ['a dog in the middle of the road, shadow on the ground, light direction north-east',
             'a dog in the middle of the road, shadow on the ground, light direction north-west',
             'a dog in the middle of the road, shadow on the ground, light direction south-west',
             'a dog in the middle of the road, shadow on the ground, light direction south-east',
             'a red rural house, shadow on the ground, light direction north',
             'a red rural house, shadow on the ground, light direction east',
             'a red rural house, shadow on the ground, light direction south',
             'a red rural house, shadow on the ground, light direction west']
e_negative_prompts = ['monochromatic, unrealistic, bad looking, full of glitches'] * 8
examples = []
for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts):
    examples.append([prompt, negative_prompt, image])
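# Each cached example pairs a precomputed depth map ('0.png' for the dog scene,
# '2.png' for the house scene) with a prompt that varies only the light direction.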
title = "# ControlLight: Light Control through ControlNet and Depth Map Conditioning"
info = '''
# ControlLight: Light Control through ControlNet and Depth Map Conditioning
We propose a ControlNet that uses depth map conditioning to control the light direction in a scene while trying to preserve the scene's integrity.
The model was trained on the [VIDIT dataset](https://huggingface.co/datasets/Nahrawy/VIDIT-Depth-ControlNet) and [A Dataset of Flash and Ambient Illumination Pairs from the Crowd](https://huggingface.co/datasets/Nahrawy/FAID-Depth-ControlNet) as part of the [JAX Diffusers Event](https://huggingface.co/jax-diffusers-event).
Due to the limited data available, the model is clearly overfit, but it serves as a proof of concept of what can be achieved with enough data.
A large part of the training data is synthetic, so we encourage further training on synthetically generated scenes, for example using Unreal Engine.
The WandB training logs can be found [here](https://wandb.ai/hassanelnahrawy/controlnet-VIDIT-FAID). Note that the model was deliberately left to overfit for experimentation, so it is advised to use the 8k-step weights or an earlier checkpoint.
This project is a joint work between [ParityError](https://huggingface.co/ParityError) and [Nahrawy](https://huggingface.co/Nahrawy).
'''
with gr.Blocks() as demo:
    gr.Markdown(title)
    prompts = gr.Textbox(label='prompts')
    negative_prompts = gr.Textbox(label='negative_prompts')
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")
    with gr.Row():
        btn = gr.Button("Run")
    with gr.Row():
        gr.Markdown(info)
    gr.Examples(examples=examples,
                inputs=[prompts, negative_prompts, in_image],
                outputs=out_image,
                fn=infer,
                cache_examples=True)
    btn.click(fn=infer, inputs=[prompts, negative_prompts, in_image], outputs=out_image)

demo.launch()