PrakhAI commited on
Commit
697c26b
·
1 Parent(s): ea32c56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -12
app.py CHANGED
@@ -22,24 +22,26 @@ def sample_latent(batch, key):
22
  def to_img(normalized):
23
  return ((normalized+1)*255./2.).astype(np.uint8)
24
 
 
 
 
 
25
  ROWS = 4
26
  COLUMNS = 4
27
- def generate_images(previous=None):
 
28
  unique_id = int(1_000_000 * time.time())
29
  latents = sample_latent(ROWS * COLUMNS, jax.random.PRNGKey(unique_id))
 
30
  if previous:
31
  latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
32
  (g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
33
  img = np.array(to_img(g_out128))
34
  for row in range(ROWS):
35
- with st.container():
36
- for (col_idx, col) in enumerate(st.columns(COLUMNS)):
37
- with col:
38
- idx = row*COLUMNS + col_idx
39
- st.image(Image.fromarray(img[idx]))
40
- if st.button(label="Generate similar", key="%d_%d" % (unique_id, idx)):
41
- generate_images(latents[idx])
42
-
43
- st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
44
- if st.button('Generate Random'):
45
- generate_images()
 
22
  def to_img(normalized):
23
  return ((normalized+1)*255./2.).astype(np.uint8)
24
 
25
+ st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
26
+ if st.button('Generate Random'):
27
+ st.session_state['generate'] = None
28
+
29
  ROWS = 4
30
  COLUMNS = 4
31
+
32
+ if 'generate' in st.session_state:
33
  unique_id = int(1_000_000 * time.time())
34
  latents = sample_latent(ROWS * COLUMNS, jax.random.PRNGKey(unique_id))
35
+ previous = st.session_state['generate']
36
  if previous:
37
  latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
38
  (g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
39
  img = np.array(to_img(g_out128))
40
  for row in range(ROWS):
41
+ with st.container():
42
+ for (col_idx, col) in enumerate(st.columns(COLUMNS)):
43
+ with col:
44
+ idx = row*COLUMNS + col_idx
45
+ st.image(Image.fromarray(img[idx]))
46
+ if st.button(label="Generate similar", key="%d_%d" % (unique_id, idx)):
47
+ st.session_state['generate'] = latents[idx]