Use CUDA to generate images faster

#1
Files changed (1) hide show
  1. inference_code.py +5 -2
inference_code.py CHANGED
@@ -5,13 +5,16 @@ from flax.training.common_utils import shard
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
7
  model_path = "sabman/map-diffuser-v3"
8
- pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
 
 
 
9
 
10
  # prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
11
  def generate_images(prompt):
12
 
13
  prng_seed = jax.random.PRNGKey(-1)
14
- num_inference_steps = 50
15
 
16
  num_samples = jax.device_count()
17
  prompt = num_samples * [prompt]
 
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
7
  model_path = "sabman/map-diffuser-v3"
8
+ # pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
9
+ pipeline = DiffusionPipeline.from_pretrained(
10
+ model_path,
11
+ from_flax=True, safety_checker=None).to("cuda")
12
 
13
  # prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
14
  def generate_images(prompt):
15
 
16
  prng_seed = jax.random.PRNGKey(-1)
17
+ num_inference_steps = 20
18
 
19
  num_samples = jax.device_count()
20
  prompt = num_samples * [prompt]