Trace4SIRM2024 / app.py
aelius's picture
added diffusion model
7638e59
raw
history blame
745 Bytes
# Import convention
import streamlit as st
from diffusers import DiffusionPipeline
import matplotlib.pyplot as plt
import torch
@st.cache_resource
def load_pipeline() -> DiffusionPipeline:
    """Load the SDXL base pipeline once and keep it for the server's lifetime.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. Without caching, the model weights would be re-read and
    moved to the GPU on each rerun; ``st.cache_resource`` keeps a single
    shared pipeline instance instead.

    Returns:
        The half-precision (fp16) pipeline placed on the ``"cuda"`` device.
    """
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    # NOTE: hard-codes the CUDA device, so a GPU is required to run the app.
    pipe.to("cuda")
    return pipe


# Prompt ingredients. `index=None` leaves each box empty until the user picks
# a value, in which case the widget returns None.
organ = st.selectbox('Organ', ['Brain', 'Thorax'], index=None)
modality = st.selectbox('Modality', ['Magnetic Resonance Imaging', 'Computed Tomography'], index=None)
style = st.selectbox('Style', ['Picasso', 'Van Gogh'], index=None)
prompt_lst = [organ, modality, style]

# Generate only once all three selections have been made.
if None not in prompt_lst:
    prompt = ','.join(prompt_lst)
    print(prompt)
    pipe = load_pipeline()
    # Leading comma keeps the last selection from fusing with the quality
    # keywords (e.g. "Picasso high resolution").
    prompt += ", high resolution, photorealistic"
    image = pipe(prompt=prompt).images[0]
    st.image(image)