merve HF staff commited on
Commit
72bed91
1 Parent(s): 58bad2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -36,9 +36,6 @@ def create_key(seed=0):
36
  return jax.random.PRNGKey(seed)
37
 
38
 
39
- rng = create_key(0)
40
-
41
-
42
 
43
 
44
  def canny_filter(image):
@@ -48,19 +45,21 @@ def canny_filter(image):
48
  edges_image = cv2.Canny(blurred_image, 50, 150)
49
  return edges_image
50
 
 
 
 
 
 
 
 
 
51
  def infer(prompts, negative_prompts, image):
52
 
53
 
54
- # load control net and stable diffusion v1-5
55
- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
56
- "jax-diffusers-event/canny-coyo1m", dtype=jnp.float32
57
- )
58
- pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
59
- "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
60
- )
61
  params["controlnet"] = controlnet_params
62
 
63
  num_samples = jax.device_count()
 
64
  rng = jax.random.split(rng, jax.device_count())
65
  canny_image = canny_filter(image)
66
 
 
36
  return jax.random.PRNGKey(seed)
37
 
38
 
 
 
 
39
 
40
 
41
  def canny_filter(image):
 
45
  edges_image = cv2.Canny(blurred_image, 50, 150)
46
  return edges_image
47
 
48
+ # load control net and stable diffusion v1-5
49
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
50
+ "jax-diffusers-event/canny-coyo1m", dtype=jnp.float32
51
+ )
52
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
53
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
54
+ )
55
+
56
  def infer(prompts, negative_prompts, image):
57
 
58
 
 
 
 
 
 
 
 
59
  params["controlnet"] = controlnet_params
60
 
61
  num_samples = jax.device_count()
62
+ rng = create_key(0)
63
  rng = jax.random.split(rng, jax.device_count())
64
  canny_image = canny_filter(image)
65