Edit model card

ControlLight: Light Control through ControlNet and Depth-Map Conditioning

We propose a ControlNet conditioned on depth maps that can control the direction of light in a scene while trying to preserve the scene's integrity. The model was trained on the VIDIT dataset and on "A Dataset of Flash and Ambient Illumination Pairs from the Crowd" (FAID) as part of the JAX Diffusers Event.

Due to the limited amount of available data, the model is clearly overfit, but it serves as a proof of concept of what can be achieved with enough data.

A large part of the training data is synthetic, so we encourage further training on synthetically generated scenes, for example scenes rendered with Unreal Engine.

The WandB training logs can be found here. Note that the model was deliberately left to overfit for experimentation, so we advise using the weights from training step 8K or earlier.

This project is joint work by ParityError and Nahrawy.

To use the model, run the following code:

import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2

def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key

def process_mask(image, size=(512, 512)):
    """Turn a conditioning image into a single-channel map resized for the ControlNet.

    Args:
        image: image as a NumPy-compatible array. It is either already
            single-channel (``(H, W)`` or ``(H, W, 1)``) or colour
            (``(H, W, 3)`` / ``(H, W, 4)``) in RGB(A) channel order, which is
            what gradio/PIL deliver.
        size: target ``(width, height)`` handed to ``cv2.resize``; defaults
            to 512x512 as before.

    Returns:
        A 2-D array resized to ``size``.

    Raises:
        ValueError: if ``image`` is None or has an unsupported shape.
    """
    if image is None:
        raise ValueError("process_mask: no conditioning image was provided")
    image = np.asarray(image)

    if image.ndim == 2:
        # Already grayscale; cvtColor would reject a single-channel input.
        mask = image
    elif image.ndim == 3 and image.shape[2] == 1:
        # Drop the singleton channel; keep memory contiguous for OpenCV.
        mask = np.ascontiguousarray(image[:, :, 0])
    elif image.ndim == 3 and image.shape[2] in (3, 4):
        # Input arrives in RGB(A) order, so use the RGB weights rather than
        # BGR (which would swap the red/blue luminance coefficients).
        mask = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        raise ValueError(
            f"process_mask: unsupported image shape {image.shape}; "
            "expected (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)"
        )
    return cv2.resize(mask, size)

# Load the ControlNet and the Stable Diffusion v1-5 ControlNet pipeline, both
# in bfloat16. This runs at module import time (top-level statements).
#  - The ControlNet repo is pinned to an exact commit hash via `revision`, so
#    later pushes to that repo do not change the weights used here.
#  - The SD v1-5 pipeline is loaded from its "flax" revision and wired to the
#    ControlNet model via `controlnet=controlnet`.
# Each `from_pretrained` call returns a (model, params) pair. `infer` assigns
# `controlnet_params` into params["controlnet"] before every run (presumably
# because the pipeline loader does not include them itself -- confirm).
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID", dtype=jnp.bfloat16, revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd"
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
)

def infer(prompts, negative_prompts, image, *, seed=0, num_inference_steps=50):
    """Relight ``image`` with the depth-conditioned ControlNet pipeline.

    Args:
        prompts: text prompt (a single string, repeated once per sample).
        negative_prompts: negative text prompt (a single string, repeated
            once per sample).
        image: conditioning image as an array; ``process_mask`` converts it
            to a 512x512 single-channel map.
        seed: seed for the JAX PRNG (default 0, the previously fixed value).
        num_inference_steps: number of denoising steps (default 50).

    Returns:
        A list of PIL images, one per local device (a single image on a
        single-device host, as before).
    """
    # Attach the separately loaded ControlNet weights to the pipeline params.
    params["controlnet"] = controlnet_params

    # `shard` splits the leading batch axis across the *local* devices, so the
    # batch must contain one sample per local device. A hard-coded batch of 1
    # crashed on multi-device hosts. The PRNG keys are split to match.
    num_devices = jax.local_device_count()
    num_samples = num_devices
    rng = jax.random.split(create_key(seed), num_devices)

    mask = Image.fromarray(process_mask(image))

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompts] * num_samples)
    )
    processed_image = shard(pipe.prepare_image_inputs([mask] * num_samples))

    p_params = replicate(params)
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the leading device/per-device axes into one batch axis,
    # keeping the trailing (H, W, C) image dimensions.
    flat = np.asarray(output.reshape((-1,) + output.shape[-3:]))
    return pipe.numpy_to_pil(flat)
Downloads last month
70

Adapter for

Space using Nahrawy/controlnet-VIDIT-FAID 1