tsungtao committed
Commit adc523e
1 Parent(s): 7585942

Update app.py

Files changed (1):
app.py  (+26, -55)
app.py CHANGED
@@ -1,60 +1,48 @@
 import gradio as gr
-from datasets import load_dataset
 import jax
 import numpy as np
 import jax.numpy as jnp
 from flax.jax_utils import replicate
 from flax.training.common_utils import shard
-#from diffusers.utils import load_image
-from diffusers.utils.testing_utils import load_image
 from PIL import Image
 from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
-
-def image_grid(imgs, rows, cols):
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
+import cv2
 
 def create_key(seed=0):
     return jax.random.PRNGKey(seed)
 
-
-def infer(prompt, negative_prompt, image):
-    rng = create_key(0)
-
-    # canny_image = load_image(
-    #     "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
-    # )
-    canny_image = load_image(image)
-
-    #prompts = "a living room fan"
-    prompts = prompt
-    negative_prompts = negative_prompt
-
-    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
-        "tsungtao/controlnet-mlsd-202305011046", from_flax=True, dtype=jnp.float32
-    )
-
-    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
-    )
-
+def canny_filter(image):
+    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    blurred_image = cv2.GaussianBlur(gray_image, (5, 5), 0)
+    edges_image = cv2.Canny(blurred_image, 50, 150)
+    return edges_image
+
+# load control net and stable diffusion v1-5
+controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+    "tsungtao/controlnet-mlsd-202305011046", from_flax=True, dtype=jnp.bfloat16
+)
+pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
+)
+
+def infer(prompts, negative_prompts, image):
     params["controlnet"] = controlnet_params
-
-    num_samples = jax.device_count()
+
+    num_samples = 1 #jax.device_count()
+    rng = create_key(0)
     rng = jax.random.split(rng, jax.device_count())
-
+    im = canny_filter(image)
+    canny_image = Image.fromarray(im)
+
     prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
     negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
     processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
-
+
     p_params = replicate(params)
     prompt_ids = shard(prompt_ids)
     negative_prompt_ids = shard(negative_prompt_ids)
     processed_image = shard(processed_image)
-
+
     output = pipe(
         prompt_ids=prompt_ids,
         image=processed_image,
@@ -64,25 +52,8 @@ def infer(prompt, negative_prompt, image):
         neg_prompt_ids=negative_prompt_ids,
         jit=True,
     ).images
-
+
     output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
-    output_images = image_grid(output_images, num_samples // 4, 4)
-    #output_images.save("tao/image.png")
-    #dataset = load_dataset('imagefolder', data_dir='tao')
-    #dataset.push_to_hub('tsungtao/tmp')
     return output_images
 
-    #infer('','','')
-
-def infer2(prompt, negative_prompt, image):
-    output_image = infer(prompt, negative_prompt, image)
-    #output_image = "https://datasets-server.huggingface.co/assets/tsungtao/tmp/--/tsungtao--tmp/train/0/image/image.jpg"
-    return output_image
-
-title = "ControlNet on MLSD Filter"
-description = "This is a demo on ControlNet based on mlsd filter."
-
-#examples = [["living room with TV", "fan", "https://datasets-server.huggingface.co/assets/tsungtao/diffusers-testing/--/tsungtao--diffusers-testing/train/0/images/image.jpg"]]
-
-interface = gr.Interface(fn = infer2, inputs = ["text", "text", "text"], outputs = "image",title = title, description = description, theme='gradio/soft')
-interface.launch(enable_queue=True)
+gr.Interface(infer, inputs=["text", "text", "image"], outputs="gallery").launch()
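
For reference, here is a minimal sketch of the preprocessing the new canny_filter applies to a Gradio image input, and of how the updated infer() could be exercised. The synthetic input array, the prompts, and the output filename are illustrative assumptions, not part of the commit; the infer() call is only meaningful in app.py's own process (for example appended before the gr.Interface(...).launch() line), since the pipeline and params are loaded at module level.

import cv2
import numpy as np
from PIL import Image

# Stand-in for the RGB NumPy array Gradio's "image" input passes to infer() (illustrative only).
rgb = np.zeros((512, 512, 3), dtype=np.uint8)
rgb[128:384, 128:384] = 255  # a white square so Canny has some edges to detect

# Same steps as canny_filter(): grayscale -> Gaussian blur -> Canny edge map.
gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
blurred = cv2.GaussianBlur(gray, (5, 5), 0)
edges = cv2.Canny(blurred, 50, 150)
control_image = Image.fromarray(edges)  # the conditioning image handed to prepare_image_inputs

# End-to-end call, assuming it runs where pipe/params are already loaded (hypothetical prompts/path):
# images = infer("a living room with a ceiling fan", "low quality, blurry", rgb)
# images[0].save("controlnet_mlsd_output.png")  # numpy_to_pil returns a list of PIL images

The commented-out lines only illustrate the call signature; in the deployed Space the Gradio gallery receives the same list of PIL images that pipe.numpy_to_pil returns.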