aelius commited on
Commit
17b483c
·
1 Parent(s): 9c45784

added submit button

Browse files
Files changed (1) hide show
  1. app.py +33 -8
app.py CHANGED
@@ -1,21 +1,46 @@
1
- # Import convention
2
  import streamlit as st
 
3
  from diffusers import DiffusionPipeline
4
  import matplotlib.pyplot as plt
5
  import torch
6
 
 
 
 
 
 
 
 
7
  organ = st.selectbox('Organ', ['Brain', 'Thorax'], index=None)
8
  modality = st.selectbox('Modality', ['Magnetic Resonance Imaging', 'Computed Tomography'], index=None)
9
  style = st.selectbox('Style', ['Picasso', 'Van Gogh'], index=None)
10
 
11
  prompt_lst = [organ, modality, style]
 
12
  if None not in prompt_lst:
13
- prompt = ','.join(prompt_lst)
14
- print(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16")
17
- pipe.to("cuda")
 
 
 
 
 
18
 
19
- prompt += " high resolution, photorealistic"
20
- image = pipe(prompt=prompt).images[0]
21
- st.image(image)
 
 
1
  import streamlit as st
2
+ import time
3
  from diffusers import DiffusionPipeline
4
  import matplotlib.pyplot as plt
5
  import torch
6
 
7
+ if 'button_clicked' not in st.session_state:
8
+ st.session_state.button_clicked = False
9
+
10
+ # Define a function to handle the button click
11
+ def on_button_click():
12
+ st.session_state.button_clicked = True
13
+
14
  organ = st.selectbox('Organ', ['Brain', 'Thorax'], index=None)
15
  modality = st.selectbox('Modality', ['Magnetic Resonance Imaging', 'Computed Tomography'], index=None)
16
  style = st.selectbox('Style', ['Picasso', 'Van Gogh'], index=None)
17
 
18
  prompt_lst = [organ, modality, style]
19
+
20
  if None not in prompt_lst:
21
+ st.session_state.button_disabled = False
22
+ else:
23
+ st.session_state.button_disabled = True
24
+
25
+
26
+ if st.session_state.button_clicked:
27
+ st.session_state.button_disabled = True
28
+ st.session_state.button_clicked = False
29
+ st.button('Submit', disabled=st.session_state.button_disabled)
30
+ with st.spinner('Processing...'):
31
+ prompt = ','.join(prompt_lst)
32
+ print(prompt)
33
+
34
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16")
35
+ pipe.to("cuda")
36
 
37
+ prompt += " high resolution, photorealistic"
38
+ image = pipe(prompt=prompt).images[0]
39
+
40
+ st.image(image)
41
+
42
+ st.session_state.button_disabled = False
43
+
44
 
45
+ else:
46
+ st.button('Submit', on_click=on_button_click, disabled=st.session_state.button_disabled)