foz commited on
Commit
8b57b0e
1 Parent(s): 883d009

Small fixes

Browse files
Files changed (2) hide show
  1. app.py +80 -0
  2. pose.jpg +0 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ from flax.jax_utils import replicate
6
+ from flax.training.common_utils import shard
7
+ from PIL import Image
8
+ from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
+ import cv2
10
+
11
+ def create_key(seed=0):
12
+ return jax.random.PRNGKey(seed)
13
+
14
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
15
+ "JFoz/dog-cat-pose", dtype=jnp.bfloat16
16
+ )
17
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
18
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
19
+ )
20
+
21
+ def infer(prompts, negative_prompts, image):
22
+ params["controlnet"] = controlnet_params
23
+
24
+ num_samples = 1 #jax.device_count()
25
+ rng = create_key(0)
26
+ rng = jax.random.split(rng, jax.device_count())
27
+ im = image
28
+ image = Image.fromarray(im)
29
+
30
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
31
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
32
+ processed_image = pipe.prepare_image_inputs([image] * num_samples)
33
+
34
+ p_params = replicate(params)
35
+ prompt_ids = shard(prompt_ids)
36
+ negative_prompt_ids = shard(negative_prompt_ids)
37
+ processed_image = shard(processed_image)
38
+
39
+ output = pipe(
40
+ prompt_ids=prompt_ids,
41
+ image=processed_image,
42
+ params=p_params,
43
+ prng_seed=rng,
44
+ num_inference_steps=50,
45
+ neg_prompt_ids=negative_prompt_ids,
46
+ jit=True,
47
+ ).images
48
+
49
+ output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
50
+ return output_images
51
+
52
+ #gr.Interface(infer, inputs=["text", "text", "image"], outputs="gallery").launch()
53
+
54
+ title = "Animal Pose Control Net"
55
+ description = "This is a demo of Animal Pose ControlNet, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning."
56
+
57
+ #with gr.Blocks(theme=gr.themes.Default(font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])) as demo:
58
+ #gr.Markdown(
59
+ # """
60
+ # Animal Pose Control Net
61
+ # This is a demo of Animal Pose Control Net, which is a model trained on runwayml/stable-diffusion-v1-5 with new type of conditioning.
62
+ #""")
63
+
64
+ theme = gr.themes.Default(primary_hue="green").set(
65
+ button_primary_background_fill="*primary_200",
66
+ button_primary_background_fill_hover="*primary_300",
67
+ )
68
+
69
+ gr.Interface(fn = infer, inputs = ["text", "text", "image"], outputs = "gallery",
70
+ title = title, description = description, theme='gradio/soft',
71
+ examples=[["a Labrador crossing the road", "low quality", "pose.jpg"]]
72
+ ).launch()
73
+
74
+
75
+ gr.Markdown(
76
+ """
77
+ * [Dataset](https://huggingface.co/datasets/JFoz/dog-poses-controlnet-dataset)
78
+ * [Diffusers model](), [Web UI model](https://huggingface.co/JFoz/dog-pose)
79
+ * [Training Report](https://wandb.ai/john-fozard/dog-cat-pose/runs/kmwcvae5))
80
+ """)
pose.jpg ADDED