sabman nabeelraza commited on
Commit
e01847c
1 Parent(s): cb38940

Use CUDA to generate images faster (#1)

Browse files

- Use CUDA to generate images faster (557e09045478b9d700094b3b07ea3a71f02269bf)


Co-authored-by: Muhammad Nabeel Raza <nabeelraza@users.noreply.huggingface.co>

Files changed (1) hide show
  1. inference_code.py +5 -2
inference_code.py CHANGED
@@ -5,13 +5,16 @@ from flax.training.common_utils import shard
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
7
  model_path = "sabman/map-diffuser-v3"
8
- pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
 
 
 
9
 
10
  # prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
11
  def generate_images(prompt):
12
 
13
  prng_seed = jax.random.PRNGKey(-1)
14
- num_inference_steps = 50
15
 
16
  num_samples = jax.device_count()
17
  prompt = num_samples * [prompt]
 
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
7
  model_path = "sabman/map-diffuser-v3"
8
+ # pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
9
+ pipeline = DiffusionPipeline.from_pretrained(
10
+ model_path,
11
+ from_flax=True, safety_checker=None).to("cuda")
12
 
13
  # prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
14
  def generate_images(prompt):
15
 
16
  prng_seed = jax.random.PRNGKey(-1)
17
+ num_inference_steps = 20
18
 
19
  num_samples = jax.device_count()
20
  prompt = num_samples * [prompt]