Nahrawy committed on
Commit a46c388
1 Parent(s): 9be3344

Create app.py

Files changed (1)
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
+ import gradio as gr
+ import jax
+ import numpy as np
+ import jax.numpy as jnp
+ from flax.jax_utils import replicate
+ from flax.training.common_utils import shard
+ from PIL import Image
+ from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
+ import cv2
+
+ # load ControlNet and Stable Diffusion v1-5
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+     "Nahrawy/controlnet-VIDIT-FAID", dtype=jnp.bfloat16, revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd"
+ )
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
+ )
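+ # note: bfloat16 weights keep the memory footprint small, the pinned revision
+ # makes the Space reproducible, and the "flax" revision of SD v1-5 ships
+ # Flax-format weights rather than PyTorch checkpoints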
+
+ def create_key(seed=0):
+     return jax.random.PRNGKey(seed)
+
+ def process_mask(image):
+     # Gradio passes images as RGB arrays, so convert accordingly and resize
+     # the conditioning mask to the 512x512 resolution SD v1-5 expects
+     mask = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+     mask = cv2.resize(mask, (512, 512))
+     return mask
+
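+ # the resized grayscale mask is the ControlNet conditioning image
+ # (the UI below labels it "Depth Map Conditioning")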
+ def infer(prompts, negative_prompts, image):
+     params["controlnet"] = controlnet_params
+
+     num_samples = 1  # jax.device_count()
+     rng = create_key(0)
+     rng = jax.random.split(rng, jax.device_count())
+     im = process_mask(image)
+     mask = Image.fromarray(im)
+
+     prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
+     negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
+     processed_image = pipe.prepare_image_inputs([mask] * num_samples)
+
+     # replicate the parameters to every device and shard the inputs across them
+     p_params = replicate(params)
+     prompt_ids = shard(prompt_ids)
+     negative_prompt_ids = shard(negative_prompt_ids)
+     processed_image = shard(processed_image)
+     print(processed_image[0].shape)
+     output = pipe(
+         prompt_ids=prompt_ids,
+         image=processed_image,
+         params=p_params,
+         prng_seed=rng,
+         num_inference_steps=50,
+         neg_prompt_ids=negative_prompt_ids,
+         jit=True,
+     ).images
+
+     # drop the device axis and convert the raw array to PIL images
+     output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
+     return output_images
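+ # data layout note: shard() adds a leading device axis, e.g. prompt_ids goes
+ # from (num_samples, 77) to (n_devices, num_samples // n_devices, 77), and
+ # jit=True makes the pipeline pmap the sampling loop over that axis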
+
+ e_images = ['examples/0.png',
+             'examples/1.png',
+             'examples/2.png']
+ e_prompts = ['a dog in the middle of the road, shadow on the ground, light direction north-east',
+              'a skyscraper in the middle of an intersection, shadow on the ground, light direction east',
+              'a red rural house, light temperature 5500, shadow on the ground, light direction south-west']
+ # repeat the list (not the string) so every example gets the same negative prompt
+ e_negative_prompts = ['monochromatic, unrealistic, bad looking, full of glitches'] * 3
+ examples = []
+ for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts):
+     examples.append([prompt, negative_prompt, image])
+
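+ # note: cache_examples=True below runs infer on these examples at startup,
+ # so the files listed above must exist in the repo's examples/ folder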
+ title = "# ControlNet VIDIT-FAID Demo"  # heading for the page; exact wording assumed
+
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     prompts = gr.Textbox(label='prompts')
+     negative_prompts = gr.Textbox(label='negative_prompts')
+     with gr.Row():
+         with gr.Column():
+             in_image = gr.Image(label="Depth Map Conditioning")
+         with gr.Column():
+             out_image = gr.Image(label="Generated Image")
+     with gr.Row():
+         btn = gr.Button("Run")
+     gr.Examples(examples=examples,
+                 fn=infer,  # Examples needs fn when cache_examples=True
+                 inputs=[prompts, negative_prompts, in_image],
+                 outputs=out_image,
+                 cache_examples=True)
+     btn.click(fn=infer, inputs=[prompts, negative_prompts, in_image], outputs=out_image)
+
+ demo.launch()
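
For reference, a minimal smoke test of `infer` outside the Gradio UI (a sketch: it assumes the Space's bundled examples/0.png is available locally and that the model loading above has completed):

    import numpy as np
    from PIL import Image

    # load an example conditioning mask as an RGB uint8 array, as Gradio would pass it
    mask = np.asarray(Image.open('examples/0.png').convert('RGB'))
    out = infer('a dog in the middle of the road, shadow on the ground, light direction north-east',
                'monochromatic, unrealistic, bad looking, full of glitches',
                mask)
    out[0].save('generated.png')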