saburq commited on
Commit
f205d49
1 Parent(s): 2c655c4
Files changed (3) hide show
  1. app.py +17 -0
  2. inference_code.py +27 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference_code import generate_images
3
+
4
+ def generate_image_predictions(prompt):
5
+ images = generate_images(prompt)
6
+ return images
7
+
8
+ iface = gr.Interface(
9
+ fn=generate_image_predictions,
10
+ inputs=gr.inputs.Textbox("Enter a text prompt here"),
11
+ outputs=[gr.outputs.Image() for i in range(4)],
12
+ title="Map Diffuser",
13
+ description="Generates four images from a given text prompt.",
14
+ examples=[["Satellite image of amsterdam with industrial area and highways"], ["Satellite image with forests and residential, no water"], ["A person playing guitar on a stage"]]
15
+ )
16
+
17
+ iface.launch()
inference_code.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import numpy as np
3
+ from flax.jax_utils import replicate
4
+ from flax.training.common_utils import shard
5
+ from diffusers import FlaxStableDiffusionPipeline
6
+
7
+ model_path = "sabman/map-diffuser-v3"
8
+ pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
9
+
10
+ #prompt = "create a map with traffic signals, busway and residential buildings, in water color style"
11
+ def generate_images(prompt):
12
+ prng_seed = jax.random.PRNGKey(-1)
13
+ num_inference_steps = 50
14
+
15
+ num_samples = jax.device_count()
16
+ prompt = num_samples * [prompt]
17
+ prompt_ids = pipeline.prepare_inputs(prompt)
18
+
19
+ # shard inputs and rng
20
+ params = replicate(params)
21
+ prng_seed = jax.random.split(prng_seed, jax.device_count())
22
+ prompt_ids = shard(prompt_ids)
23
+
24
+ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
25
+ images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
26
+ return images
27
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.25.1
2
+ datasets
3
+ flax
4
+ optax
5
+ torch
6
+ torchvision
7
+ ftfy
8
+ tensorboard
9
+ Jinja2