dmolino commited on
Commit
78b8f59
·
verified ·
1 Parent(s): 57b57ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -5
app.py CHANGED
@@ -19,7 +19,6 @@ def image_to_base64(image_path):
19
  st.markdown("""
20
  <style>
21
  @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
22
-
23
  /* Apply the font to everything */
24
  html, body, [class*="st"] {
25
  font-family: 'Roboto', sans-serif;
@@ -130,7 +129,15 @@ if 'generate' not in st.session_state:
130
 
131
  # Inizializza inference_tester solo una volta
132
  if 'inference_tester' not in st.session_state:
133
- st.session_state['inference_tester'] = 1
 
 
 
 
 
 
 
 
134
 
135
  # Usa inference_tester dalla sessione
136
  inference_tester = st.session_state['inference_tester']
@@ -202,12 +209,18 @@ if st.session_state['step'] == 2:
202
 
203
  # Pulsante per provare un esempio
204
  with col1:
 
 
 
 
 
 
205
  if st.button("Try an example"):
206
  st.session_state['step'] = 5 # Passa al passo 5
207
  st.rerun()
208
 
209
  # Pulsante per tornare all'inizio
210
- with col2:
211
  if st.button("Return to the beginning"):
212
  # Ripristina lo stato della sessione
213
  st.session_state['step'] = 1
@@ -365,8 +378,79 @@ if st.session_state['step'] == 3:
365
  st.rerun()
366
 
367
  if st.session_state['step'] == 4:
368
- st.write("Generation completed successfully!")
369
- st.session_state['generate'] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  if st.button("Return to the beginning"):
372
  # Ripristina lo stato della sessione
 
19
  st.markdown("""
20
  <style>
21
  @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
 
22
  /* Apply the font to everything */
23
  html, body, [class*="st"] {
24
  font-family: 'Roboto', sans-serif;
 
129
 
130
  # Inizializza inference_tester solo una volta
131
  if 'inference_tester' not in st.session_state:
132
+ model_load_paths = ['CoDi_encoders.pth', 'CoDi_text_diffuser.pth', 'CoDi_video_diffuser_8frames.pth']
133
+ st.session_state['inference_tester'] = dani_model(model='thesis_model',
134
+ data_dir='/mimer/NOBACKUP/groups/snic2022-5-277/dmolino/checkpoints/',
135
+ pth=model_load_paths, load_weights=False)
136
+ inference_tester = st.session_state['inference_tester']
137
+
138
+ # Caricamento dei pesi Clip, Optimus, Frontal, Lateral e Text una sola volta
139
+ if 'weights_loaded' not in st.session_state:
140
+ st.session_state['weights_loaded'] = True # Indica che i pesi sono stati caricati
141
 
142
  # Usa inference_tester dalla sessione
143
  inference_tester = st.session_state['inference_tester']
 
209
 
210
  # Pulsante per provare un esempio
211
  with col1:
212
+ if st.button("Inference"):
213
+ st.session_state['step'] = 3 # Passa al passo 3
214
+ st.rerun()
215
+
216
+ # Pulsante per provare un esempio
217
+ with col2:
218
  if st.button("Try an example"):
219
  st.session_state['step'] = 5 # Passa al passo 5
220
  st.rerun()
221
 
222
  # Pulsante per tornare all'inizio
223
+ with col3:
224
  if st.button("Return to the beginning"):
225
  # Ripristina lo stato della sessione
226
  st.session_state['step'] = 1
 
378
  st.rerun()
379
 
380
  if st.session_state['step'] == 4:
381
+ # Costruzione del prompt
382
+ if st.session_state['generate'] is True:
383
+ conditioning = []
384
+ for inp in st.session_state['inputs']:
385
+ if inp == 'frontal':
386
+ cim = inference_tester.net.clip_encode_vision(st.session_state['frontal'], encode_type='encode_vision').to(device)
387
+ uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['frontal']).to(device),
388
+ encode_type='encode_vision').to(device)
389
+ conditioning.append(torch.cat([uim, cim]))
390
+ elif inp == 'lateral':
391
+ cim = inference_tester.net.clip_encode_vision(st.session_state['lateral'], encode_type='encode_vision').to(device)
392
+ uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['lateral']).to(device),
393
+ encode_type='encode_vision').to(device)
394
+ conditioning.append(torch.cat([uim, cim]))
395
+ elif inp == 'text':
396
+ ctx = inference_tester.net.clip_encode_text(1 * [st.session_state['report']], encode_type='encode_text').to(device)
397
+ utx = inference_tester.net.clip_encode_text(1 * [""], encode_type='encode_text').to(device)
398
+ conditioning.append(torch.cat([utx, ctx]))
399
+
400
+ # Costruzione delle shapes
401
+ shapes = []
402
+ for out in st.session_state['outputs']:
403
+ if out == 'frontal' or out == 'lateral':
404
+ shape = [1, 4, 256 // 8, 256 // 8]
405
+ shapes.append(shape)
406
+ elif out == 'text':
407
+ shape = [1, 768]
408
+ shapes.append(shape)
409
+
410
+ progress_bar = st.progress(0)
411
+
412
+ # Inferenza
413
+ z, _ = inference_tester.sampler.sample(
414
+ steps=50,
415
+ shape=shapes,
416
+ condition=conditioning,
417
+ unconditional_guidance_scale=7.5,
418
+ xtype=st.session_state['outputs'],
419
+ condition_types=st.session_state['inputs'],
420
+ eta=1,
421
+ verbose=False,
422
+ mix_weight={'lateral': 1, 'text': 1, 'frontal': 1},
423
+ progress_bar=progress_bar)
424
+
425
+ # Decoder e visualizzazione dei risultati
426
+ output_cols = st.columns(len(st.session_state['outputs']))
427
+
428
+ # Definire due colonne per le immagini
429
+ col1, col2 = st.columns(2)
430
+
431
+ # Iterare sugli output e assegnare le immagini alle colonne corrispondenti
432
+ for i, out in enumerate(st.session_state['outputs']):
433
+ if out == 'frontal':
434
+ x = inference_tester.net.autokl_decode(z[i])
435
+ x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
436
+ im = x[0].cpu().numpy()
437
+ with col1: # Mostrare la frontal image nella prima colonna
438
+ st.image(im, caption="Generated Frontal Image")
439
+ elif out == 'lateral':
440
+ x = inference_tester.net.autokl_decode(z[i])
441
+ x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
442
+ im = x[0].cpu().numpy()
443
+ with col2: # Mostrare la lateral image nella seconda colonna
444
+ st.image(im, caption="Generated Lateral Image")
445
+ elif out == 'text':
446
+ x = inference_tester.net.optimus_decode(z[i], max_length=100)
447
+ x = [a.tolist() for a in x]
448
+ rec_text = [inference_tester.net.optimus.tokenizer_decoder.decode(a) for a in x]
449
+ rec_text = rec_text[0].replace('<BOS>', '').replace('<EOS>', '')
450
+ st.write(f"Generated Report: {rec_text}")
451
+
452
+ st.write("Generation completed successfully!")
453
+ st.session_state['generate'] = False
454
 
455
  if st.button("Return to the beginning"):
456
  # Ripristina lo stato della sessione