merve HF staff commited on
Commit
192134b
1 Parent(s): 163ea2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -4,10 +4,25 @@ import numpy as np
4
  import jax.numpy as jnp
5
  from flax.jax_utils import replicate
6
  from flax.training.common_utils import shard
7
- from diffusers.utils import load_image
8
  from PIL import Image
9
  from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
10
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def image_grid(imgs, rows, cols):
13
  w, h = imgs[0].size
 
4
  import jax.numpy as jnp
5
  from flax.jax_utils import replicate
6
  from flax.training.common_utils import shard
 
7
  from PIL import Image
8
  from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
  import cv2
10
+ import os
11
+
12
+
13
+ def load_image(image):
14
+ if isinstance(image, str):
15
+ if image.startswith("http://") or image.startswith("https://"):
16
+ image = PIL.Image.open(requests.get(image, stream=True).raw)
17
+ elif os.path.isfile(image):
18
+ image = PIL.Image.open(image)
19
+ elif isinstance(image, PIL.Image.Image):
20
+ image = image
21
+ image = PIL.ImageOps.exif_transpose(image)
22
+ image = image.convert("RGB")
23
+ return image
24
+
25
+
26
 
27
  def image_grid(imgs, rows, cols):
28
  w, h = imgs[0].size