gojiteji commited on
Commit
e4174c1
1 Parent(s): 68979d7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flax.jax_utils import replicate
2
+ from jax import pmap
3
+ from flax.training.common_utils import shard
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from pathlib import Path
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+ from diffusers import FlaxStableDiffusionPipeline
12
+
13
+
14
+ import os
15
+ if 'TPU_NAME' in os.environ:
16
+ import requests
17
+ if 'TPU_DRIVER_MODE' not in globals():
18
+ url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
19
+ resp = requests.post(url)
20
+ TPU_DRIVER_MODE = 1
21
+
22
+
23
+ from jax.config import config
24
+ config.FLAGS.jax_xla_backend = "tpu_driver"
25
+ config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
26
+ print('Registered TPU:', config.FLAGS.jax_backend_target)
27
+ else:
28
+ print('No TPU detected. Can be changed under "Runtime/Change runtime type".')
29
+
30
+ import jax
31
+ jax.local_devices()
32
+ num_devices = jax.device_count()
33
+ device_type = jax.devices()[0].device_kind
34
+
35
+ print(f"Found {num_devices} JAX devices of type {device_type}.")
36
+
37
+ def sd2_inference(pipeline, prompts, params, seed = 42, num_inference_steps = 50 ):
38
+ prng_seed = jax.random.PRNGKey(seed)
39
+ prompt_ids = pipeline.prepare_inputs(prompts)
40
+ params = replicate(params)
41
+ prng_seed = jax.random.split(prng_seed, jax.device_count())
42
+ prompt_ids = shard(prompt_ids)
43
+ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
44
+ images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
45
+ images = pipeline.numpy_to_pil(images)
46
+ return images
47
+ def image_grid(imgs, rows, cols, down_sample = 1 ):
48
+ w,h = imgs[0].size
49
+ grid = Image.new('RGB', size=(cols*w, rows*h))
50
+ for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
51
+ grid = grid.resize( (grid.size[0]//down_sample, grid.size[1]//down_sample) )
52
+ return grid
53
+
54
+
55
+ HF_ACCESS_TOKEN = os.environ["HFAUTH"]
56
+
57
+ # Load Model
58
+ # - Reference: https://github.com/huggingface/diffusers/blob/main/README.md
59
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
60
+ "CompVis/stable-diffusion-v1-4",
61
+ use_auth_token = HF_ACCESS_TOKEN,
62
+ revision="bf16",
63
+ dtype=jnp.bfloat16,
64
+ )