Update app.py
Browse files
app.py
CHANGED
@@ -22,24 +22,26 @@ def sample_latent(batch, key):
|
|
22 |
def to_img(normalized):
|
23 |
return ((normalized+1)*255./2.).astype(np.uint8)
|
24 |
|
|
|
|
|
|
|
|
|
25 |
ROWS = 4
|
26 |
COLUMNS = 4
|
27 |
-
|
|
|
28 |
unique_id = int(1_000_000 * time.time())
|
29 |
latents = sample_latent(ROWS * COLUMNS, jax.random.PRNGKey(unique_id))
|
|
|
30 |
if previous:
|
31 |
latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
|
32 |
(g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
|
33 |
img = np.array(to_img(g_out128))
|
34 |
for row in range(ROWS):
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
|
44 |
-
if st.button('Generate Random'):
|
45 |
-
generate_images()
|
|
|
22 |
def to_img(normalized):
|
23 |
return ((normalized+1)*255./2.).astype(np.uint8)
|
24 |
|
25 |
+
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
|
26 |
+
if st.button('Generate Random'):
|
27 |
+
st.session_state['generate'] = None
|
28 |
+
|
29 |
ROWS = 4
|
30 |
COLUMNS = 4
|
31 |
+
|
32 |
+
if 'generate' in st.session_state:
|
33 |
unique_id = int(1_000_000 * time.time())
|
34 |
latents = sample_latent(ROWS * COLUMNS, jax.random.PRNGKey(unique_id))
|
35 |
+
previous = st.session_state['generate']
|
36 |
if previous:
|
37 |
latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
|
38 |
(g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
|
39 |
img = np.array(to_img(g_out128))
|
40 |
for row in range(ROWS):
|
41 |
+
with st.container():
|
42 |
+
for (col_idx, col) in enumerate(st.columns(COLUMNS)):
|
43 |
+
with col:
|
44 |
+
idx = row*COLUMNS + col_idx
|
45 |
+
st.image(Image.fromarray(img[idx]))
|
46 |
+
if st.button(label="Generate similar", key="%d_%d" % (unique_id, idx)):
|
47 |
+
st.session_state['generate'] = latents[idx]
|
|
|
|
|
|
|
|