sabman commited on
Commit
9f788a3
1 Parent(s): e84224a

Update inference_code.py

Browse files
Files changed (1) hide show
  1. inference_code.py +1 -16
inference_code.py CHANGED
@@ -5,7 +5,6 @@ from flax.training.common_utils import shard
5
  from diffusers import DiffusionPipeline
6
 
7
  model_path = "sabman/map-diffuser-v3"
8
- # pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
9
  pipeline = DiffusionPipeline.from_pretrained(
10
  model_path,
11
  from_flax=True, safety_checker=None).to("cuda")
@@ -16,20 +15,6 @@ def generate_images(prompt):
16
  prng_seed = jax.random.PRNGKey(-1)
17
  num_inference_steps = 20
18
 
19
- images = pipeline(prompt, width=512, num_inference_steps=20, num_images_per_prompt=1).images
20
- # images = pipeline.numpy_to_pil(np.asarray(images.reshape((1,) + images.shape[-3:])))
21
-
22
 
23
- # num_samples = jax.device_count()
24
- # prompt = num_samples * [prompt]
25
- # prompt_ids = pipeline.prepare_inputs(prompt)
26
-
27
- # # shard inputs and rng
28
- # params = replicate(_params)
29
- # prng_seed = jax.random.split(prng_seed, jax.device_count())
30
- # prompt_ids = shard(prompt_ids)
31
-
32
- # images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
33
- # images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
34
-
35
  return images[0]
 
5
  from diffusers import DiffusionPipeline
6
 
7
  model_path = "sabman/map-diffuser-v3"
 
8
  pipeline = DiffusionPipeline.from_pretrained(
9
  model_path,
10
  from_flax=True, safety_checker=None).to("cuda")
 
15
  prng_seed = jax.random.PRNGKey(-1)
16
  num_inference_steps = 20
17
 
18
+ images = pipeline(prompt, width=512, num_inference_steps=num_inference_steps, num_images_per_prompt=1).images
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  return images[0]