saburq commited on
Commit
6a64212
1 Parent(s): 1cbc96b

try moving everything to inference funtion

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. inference_code.py +8 -7
app.py CHANGED
@@ -9,8 +9,8 @@ def generate_image_predictions(prompt):
9
 
10
  iface = gr.Interface(
11
  fn=generate_image_predictions,
12
- inputs=gr.components.Textbox("Enter a text prompt here"),
13
- outputs=[gr.components.Image() for i in range(4)],
14
  title="Map Diffuser",
15
  description="Generates four images from a given text prompt.",
16
  examples=[["Satellite image of amsterdam with industrial area and highways"], [
 
9
 
10
  iface = gr.Interface(
11
  fn=generate_image_predictions,
12
+ inputs=gr.components.Textbox(label="Enter a text prompt here"),
13
+ outputs=[gr.components.Image(label="Output Image") for i in range(4)],
14
  title="Map Diffuser",
15
  description="Generates four images from a given text prompt.",
16
  examples=[["Satellite image of amsterdam with industrial area and highways"], [
inference_code.py CHANGED
@@ -4,24 +4,25 @@ from flax.jax_utils import replicate
4
  from flax.training.common_utils import shard
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
7
- model_path = "sabman/map-diffuser-v3"
8
- pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
9
 
10
- #prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
11
  def generate_images(prompt):
 
 
 
12
  prng_seed = jax.random.PRNGKey(-1)
13
  num_inference_steps = 50
14
-
15
  num_samples = jax.device_count()
16
  prompt = num_samples * [prompt]
17
  prompt_ids = pipeline.prepare_inputs(prompt)
18
-
19
  # shard inputs and rng
20
  params = replicate(_params)
21
  prng_seed = jax.random.split(prng_seed, jax.device_count())
22
  prompt_ids = shard(prompt_ids)
23
-
24
  images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
25
  images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
26
- return images
27
 
 
 
4
  from flax.training.common_utils import shard
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
 
 
7
 
8
+ # prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
9
  def generate_images(prompt):
10
+ model_path = "sabman/map-diffuser-v3"
11
+ pipeline, _params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
12
+
13
  prng_seed = jax.random.PRNGKey(-1)
14
  num_inference_steps = 50
15
+
16
  num_samples = jax.device_count()
17
  prompt = num_samples * [prompt]
18
  prompt_ids = pipeline.prepare_inputs(prompt)
19
+
20
  # shard inputs and rng
21
  params = replicate(_params)
22
  prng_seed = jax.random.split(prng_seed, jax.device_count())
23
  prompt_ids = shard(prompt_ids)
24
+
25
  images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
26
  images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
 
27
 
28
+ return [images[0], images[1], images[2], images[3]]