# Captured from a Hugging Face Spaces page (Space status at capture time: "Runtime error").
from denoising_diffusion_pytorch import Unet, GaussianDiffusion | |
import streamlit as st | |
import torch | |
def get_model(checkpoint_path="./model-final.pt"):
    """Build the sushi diffusion model and load its trained weights.

    The architecture configured below must match the one the checkpoint was
    trained with: ``load_state_dict`` is strict by default and raises on any
    missing or unexpected keys.

    Args:
        checkpoint_path: Path to a file containing a ``state_dict`` for the
            ``GaussianDiffusion`` wrapper. Defaults to ``./model-final.pt``.

    Returns:
        The ``GaussianDiffusion`` model with weights loaded onto the CPU,
        switched to eval mode.
    """
    unet = Unet(
        dim=64,
        dim_mults=(1, 2, 4, 8),
    )
    # NOTE(review): these keyword arguments (notably loss_type and
    # p2_loss_weight_gamma) depend on the denoising_diffusion_pytorch
    # release; pin the version the checkpoint was trained with — confirm.
    model = GaussianDiffusion(
        unet,
        image_size=64,
        timesteps=1000,  # number of diffusion steps
        sampling_timesteps=250,  # number of sampling timesteps (DDIM, for faster inference)
        loss_type='l1',  # L1 or L2
        p2_loss_weight_gamma=1.,
    )
    # torch.load unpickles the file: only point it at checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
def scale_to_255(x):
    """Map an image array from the model's [-1, 1] range to 8-bit pixels.

    Values are clipped to [-1, 1] before scaling. A float -> ``uint8`` cast is
    not saturating, so a slightly over- or under-shot pixel would otherwise
    wrap around to a wildly wrong intensity. In-range values map exactly as
    before: -1 -> 0 and 1 -> 255, with the fractional part truncated.

    Args:
        x: NumPy array of floats, nominally in [-1, 1].

    Returns:
        NumPy ``uint8`` array with the same shape as ``x``.
    """
    clipped = x.clip(-1, 1)
    return ((clipped + 1) / 2 * 255).astype('uint8')
if __name__ == "__main__":
    st.title("Sushi Diffusion")
    st.text("The generation process takes about 10 mins.")
    st.text("If you don't want to wait, please visit: https://thissushidoesnotexist.com/")

    # Streamlit re-executes this whole script on every widget interaction, so
    # an uncached get_model() would re-read the checkpoint on each button press.
    # st.cache_resource keeps a single loaded model per server process; on
    # Streamlit versions without it, fall back to a plain load.
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        model = cache_resource(get_model)()
    else:
        model = get_model()

    st.text('Press the button below to generate sushi!')
    if st.button('🍣'):
        bar = st.progress(0)
        num_steps = model.num_timesteps
        # Batch of one 64x64 image; 64 matches image_size in get_model.
        # NOTE(review): 3 channels assumes the Unet default — confirm.
        img = torch.randn((1, 3, 64, 64), device="cpu")
        # Sampling never needs gradients; no_grad also guarantees that the
        # .numpy() conversion below works on the final tensor.
        with torch.no_grad():
            # Walk every diffusion step from the noisiest (num_steps - 1) to 0.
            for t in reversed(range(num_steps)):
                img, _ = model.p_sample(img, t, None)
                bar.progress((num_steps - t) / num_steps)
        # Take the single batch element explicitly (squeeze() would also drop
        # a size-1 channel axis), then reorder CHW -> HWC for st.image.
        img = scale_to_255(img[0].cpu().numpy().transpose(1, 2, 0))
        st.image(img, caption='This sushi does not exist.')