Update app.py
Browse files
app.py
CHANGED
@@ -19,7 +19,6 @@ def image_to_base64(image_path):
|
|
19 |
st.markdown("""
|
20 |
<style>
|
21 |
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
|
22 |
-
|
23 |
/* Apply the font to everything */
|
24 |
html, body, [class*="st"] {
|
25 |
font-family: 'Roboto', sans-serif;
|
@@ -130,7 +129,15 @@ if 'generate' not in st.session_state:
|
|
130 |
|
131 |
# Inizializza inference_tester solo una volta
|
132 |
if 'inference_tester' not in st.session_state:
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
# Usa inference_tester dalla sessione
|
136 |
inference_tester = st.session_state['inference_tester']
|
@@ -202,12 +209,18 @@ if st.session_state['step'] == 2:
|
|
202 |
|
203 |
# Pulsante per provare un esempio
|
204 |
with col1:
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
if st.button("Try an example"):
|
206 |
st.session_state['step'] = 5 # Passa al passo 5
|
207 |
st.rerun()
|
208 |
|
209 |
# Pulsante per tornare all'inizio
|
210 |
-
with
|
211 |
if st.button("Return to the beginning"):
|
212 |
# Ripristina lo stato della sessione
|
213 |
st.session_state['step'] = 1
|
@@ -365,8 +378,79 @@ if st.session_state['step'] == 3:
|
|
365 |
st.rerun()
|
366 |
|
367 |
if st.session_state['step'] == 4:
|
368 |
-
|
369 |
-
st.session_state['generate']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
if st.button("Return to the beginning"):
|
372 |
# Ripristina lo stato della sessione
|
|
|
19 |
st.markdown("""
|
20 |
<style>
|
21 |
@import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
|
|
|
22 |
/* Apply the font to everything */
|
23 |
html, body, [class*="st"] {
|
24 |
font-family: 'Roboto', sans-serif;
|
|
|
129 |
|
130 |
# Inizializza inference_tester solo una volta
|
131 |
if 'inference_tester' not in st.session_state:
|
132 |
+
model_load_paths = ['CoDi_encoders.pth', 'CoDi_text_diffuser.pth', 'CoDi_video_diffuser_8frames.pth']
|
133 |
+
st.session_state['inference_tester'] = dani_model(model='thesis_model',
|
134 |
+
data_dir='/mimer/NOBACKUP/groups/snic2022-5-277/dmolino/checkpoints/',
|
135 |
+
pth=model_load_paths, load_weights=False)
|
136 |
+
inference_tester = st.session_state['inference_tester']
|
137 |
+
|
138 |
+
# Caricamento dei pesi Clip, Optimus, Frontal, Lateral e Text una sola volta
|
139 |
+
if 'weights_loaded' not in st.session_state:
|
140 |
+
st.session_state['weights_loaded'] = True # Indica che i pesi sono stati caricati
|
141 |
|
142 |
# Usa inference_tester dalla sessione
|
143 |
inference_tester = st.session_state['inference_tester']
|
|
|
209 |
|
210 |
# Pulsante per provare un esempio
|
211 |
with col1:
|
212 |
+
if st.button("Inference"):
|
213 |
+
st.session_state['step'] = 3 # Passa al passo 3
|
214 |
+
st.rerun()
|
215 |
+
|
216 |
+
# Pulsante per provare un esempio
|
217 |
+
with col2:
|
218 |
if st.button("Try an example"):
|
219 |
st.session_state['step'] = 5 # Passa al passo 5
|
220 |
st.rerun()
|
221 |
|
222 |
# Pulsante per tornare all'inizio
|
223 |
+
with col3:
|
224 |
if st.button("Return to the beginning"):
|
225 |
# Ripristina lo stato della sessione
|
226 |
st.session_state['step'] = 1
|
|
|
378 |
st.rerun()
|
379 |
|
380 |
if st.session_state['step'] == 4:
|
381 |
+
# Costruzione del prompt
|
382 |
+
if st.session_state['generate'] is True:
|
383 |
+
conditioning = []
|
384 |
+
for inp in st.session_state['inputs']:
|
385 |
+
if inp == 'frontal':
|
386 |
+
cim = inference_tester.net.clip_encode_vision(st.session_state['frontal'], encode_type='encode_vision').to(device)
|
387 |
+
uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['frontal']).to(device),
|
388 |
+
encode_type='encode_vision').to(device)
|
389 |
+
conditioning.append(torch.cat([uim, cim]))
|
390 |
+
elif inp == 'lateral':
|
391 |
+
cim = inference_tester.net.clip_encode_vision(st.session_state['lateral'], encode_type='encode_vision').to(device)
|
392 |
+
uim = inference_tester.net.clip_encode_vision(torch.zeros_like(st.session_state['lateral']).to(device),
|
393 |
+
encode_type='encode_vision').to(device)
|
394 |
+
conditioning.append(torch.cat([uim, cim]))
|
395 |
+
elif inp == 'text':
|
396 |
+
ctx = inference_tester.net.clip_encode_text(1 * [st.session_state['report']], encode_type='encode_text').to(device)
|
397 |
+
utx = inference_tester.net.clip_encode_text(1 * [""], encode_type='encode_text').to(device)
|
398 |
+
conditioning.append(torch.cat([utx, ctx]))
|
399 |
+
|
400 |
+
# Costruzione delle shapes
|
401 |
+
shapes = []
|
402 |
+
for out in st.session_state['outputs']:
|
403 |
+
if out == 'frontal' or out == 'lateral':
|
404 |
+
shape = [1, 4, 256 // 8, 256 // 8]
|
405 |
+
shapes.append(shape)
|
406 |
+
elif out == 'text':
|
407 |
+
shape = [1, 768]
|
408 |
+
shapes.append(shape)
|
409 |
+
|
410 |
+
progress_bar = st.progress(0)
|
411 |
+
|
412 |
+
# Inferenza
|
413 |
+
z, _ = inference_tester.sampler.sample(
|
414 |
+
steps=50,
|
415 |
+
shape=shapes,
|
416 |
+
condition=conditioning,
|
417 |
+
unconditional_guidance_scale=7.5,
|
418 |
+
xtype=st.session_state['outputs'],
|
419 |
+
condition_types=st.session_state['inputs'],
|
420 |
+
eta=1,
|
421 |
+
verbose=False,
|
422 |
+
mix_weight={'lateral': 1, 'text': 1, 'frontal': 1},
|
423 |
+
progress_bar=progress_bar)
|
424 |
+
|
425 |
+
# Decoder e visualizzazione dei risultati
|
426 |
+
output_cols = st.columns(len(st.session_state['outputs']))
|
427 |
+
|
428 |
+
# Definire due colonne per le immagini
|
429 |
+
col1, col2 = st.columns(2)
|
430 |
+
|
431 |
+
# Iterare sugli output e assegnare le immagini alle colonne corrispondenti
|
432 |
+
for i, out in enumerate(st.session_state['outputs']):
|
433 |
+
if out == 'frontal':
|
434 |
+
x = inference_tester.net.autokl_decode(z[i])
|
435 |
+
x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
436 |
+
im = x[0].cpu().numpy()
|
437 |
+
with col1: # Mostrare la frontal image nella prima colonna
|
438 |
+
st.image(im, caption="Generated Frontal Image")
|
439 |
+
elif out == 'lateral':
|
440 |
+
x = inference_tester.net.autokl_decode(z[i])
|
441 |
+
x = torch.clamp((x[0] + 1.0) / 2.0, min=0.0, max=1.0)
|
442 |
+
im = x[0].cpu().numpy()
|
443 |
+
with col2: # Mostrare la lateral image nella seconda colonna
|
444 |
+
st.image(im, caption="Generated Lateral Image")
|
445 |
+
elif out == 'text':
|
446 |
+
x = inference_tester.net.optimus_decode(z[i], max_length=100)
|
447 |
+
x = [a.tolist() for a in x]
|
448 |
+
rec_text = [inference_tester.net.optimus.tokenizer_decoder.decode(a) for a in x]
|
449 |
+
rec_text = rec_text[0].replace('<BOS>', '').replace('<EOS>', '')
|
450 |
+
st.write(f"Generated Report: {rec_text}")
|
451 |
+
|
452 |
+
st.write("Generation completed successfully!")
|
453 |
+
st.session_state['generate'] = False
|
454 |
|
455 |
if st.button("Return to the beginning"):
|
456 |
# Ripristina lo stato della sessione
|