sushi-diffusion / app.py
shionhonda's picture
change
bbce684
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import streamlit as st
import torch
def get_model(checkpoint_path="./model-final.pt", device="cpu"):
    """Build the sushi diffusion model and load its trained weights.

    Args:
        checkpoint_path: Path of the file passed to ``torch.load``; its
            contents are fed to ``GaussianDiffusion.load_state_dict``.
            Defaults to the previously hard-coded ``./model-final.pt``.
        device: Target for ``torch.load``'s ``map_location`` and for the
            final ``model.to(...)``. Defaults to ``"cpu"`` (the original
            behavior).

    Returns:
        The ``GaussianDiffusion`` wrapper around a ``Unet``, in eval mode.
    """
    unet = Unet(
        dim=64,
        dim_mults=(1, 2, 4, 8),
    )
    model = GaussianDiffusion(
        unet,
        image_size=64,
        timesteps=1000,  # number of diffusion steps
        sampling_timesteps=250,  # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
        loss_type='l1',  # L1 or L2
        p2_loss_weight_gamma=1.,
    )
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on another device (e.g. a GPU) can still be loaded here.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def scale_to_255(x):
    """Map an image array from [-1, 1] to 8-bit pixel values.

    Args:
        x: Array with values nominally in [-1, 1]; expected to be a NumPy
            array (``.clip`` / ``.astype`` are used).

    Returns:
        A ``uint8`` array of the same shape. ``-1`` maps to ``0`` and ``1`` to
        ``255`` (fractional values are truncated, as before).
    """
    # Clip before the cast: converting out-of-range floats to uint8 wraps
    # around (or is undefined), which would turn values that drift slightly
    # outside [-1, 1] into wrong pixels instead of saturating them.
    return ((x + 1) / 2 * 255).clip(0, 255).astype('uint8')
if __name__ == "__main__":
    st.title("Sushi Diffusion")
    st.text("The generation process takes about 10 mins.")
    st.text("If you don't want to wait, please visit: https://thissushidoesnotexist.com/")
    st.text('Press the button below to generate sushi!')
    if st.button('🍣'):
        # Load only when a sample is requested. Streamlit re-runs this script
        # on every interaction, so loading at the top would block the initial
        # page render on reading the checkpoint.
        model = get_model()
        num_steps = model.num_timesteps
        bar = st.progress(0)
        # (1, 3, 64, 64): one RGB image at the image_size=64 used in get_model.
        img = torch.randn((1, 3, 64, 64), device="cpu")
        # Sampling needs no autograd graph; no_grad also ensures the result
        # does not track gradients, so .numpy() below is always allowed.
        with torch.no_grad():
            # Walk the reverse diffusion from t = num_steps - 1 down to 0.
            for t in reversed(range(num_steps)):
                img, _ = model.p_sample(img, t, None)
                bar.progress((num_steps - t) / num_steps)
        # Drop the batch axis and convert CHW -> HWC for st.image.
        img = scale_to_255(img.squeeze().numpy().transpose(1, 2, 0))
        st.image(img, caption='This sushi does not exist.')