File size: 2,303 Bytes
3f98781
5156e7a
 
 
 
3f98781
 
 
 
 
 
5156e7a
3f98781
 
5156e7a
 
 
 
 
3f98781
5156e7a
 
3f98781
5156e7a
 
3f98781
 
5156e7a
3f98781
 
5156e7a
 
3f98781
5156e7a
 
 
3f98781
5156e7a
 
 
 
 
 
 
 
3f98781
 
 
 
 
 
5156e7a
3f98781
 
 
5156e7a
3f98781
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# !pip install diffusers
import torch
from diffusers import DDIMPipeline, DDPMPipeline, PNDMPipeline
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
from diffusers import UNetUnconditionalModel
import gradio as gr
import PIL.Image
import numpy as np
import random

# Pretrained checkpoint id (presumably a Hugging Face Hub repo): the "unet"
# subfolder holds the denoising network, the "scheduler" subfolder its config.
model_id = "google/ddpm-celebahq-256"
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder="unet")

# Load one scheduler per sampling algorithm, all from the same config.
ddpm_scheduler = DDPMScheduler.from_config(model_id, subfolder="scheduler")
ddim_scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
pndm_scheduler = PNDMScheduler.from_config(model_id, subfolder="scheduler")

# Every pipeline is handed the very same `model` object, so the UNet weights
# are loaded once and shared; only the scheduler differs between them.
ddpm_pipeline = DDPMPipeline(unet=model, scheduler=ddpm_scheduler)
ddim_pipeline = DDIMPipeline(unet=model, scheduler=ddim_scheduler)
pndm_pipeline = PNDMPipeline(unet=model, scheduler=pndm_scheduler)
# run pipeline in inference (sample random noise and denoise)
def predict(seed=42, scheduler="ddim"):
    """Sample one image from the unconditional diffusion model.

    Args:
        seed: Seed for ``torch.manual_seed``; it is converted to ``int`` so a
            float coming from a UI slider is accepted.
        scheduler: Which pipeline to run: ``"ddim"`` (100 steps),
            ``"ddpm"`` (pipeline default steps) or ``"pndm"`` (11 steps).

    Returns:
        A ``PIL.Image.Image`` made from the first image in the sampled batch.

    Raises:
        ValueError: If ``scheduler`` is not one of the supported names.
    """
    # Release cached GPU memory held by PyTorch's allocator before sampling.
    torch.cuda.empty_cache()
    # torch.manual_seed seeds the global RNG and returns the generator object,
    # which is passed on to the pipeline so sampling is reproducible per seed.
    generator = torch.manual_seed(int(seed))

    if scheduler == "ddim":
        output = ddim_pipeline(generator=generator, num_inference_steps=100)
    elif scheduler == "ddpm":
        output = ddpm_pipeline(generator=generator)
    elif scheduler == "pndm":
        output = pndm_pipeline(generator=generator, num_inference_steps=11)
    else:
        # Previously an unknown name fell through and crashed later with an
        # UnboundLocalError; fail early with an explicit message instead.
        raise ValueError(
            f"Unknown scheduler {scheduler!r}; expected 'ddim', 'ddpm' or 'pndm'"
        )

    # The pipelines do not agree on a return type: DDIM's result is indexed
    # with ["sample"], while (per the original notes) DDPM/PNDM return the
    # tensor directly. Accept either form.
    image = output["sample"] if isinstance(output, dict) else output

    # Move axis 1 (presumably channels) last so each image is H x W x C for
    # PIL.
    pixels = image.detach().cpu().permute(0, 2, 3, 1)
    # Map [-1, 1] -> [0, 255] and clamp for every scheduler: out-of-range
    # floats would otherwise wrap around in the uint8 cast below and show up
    # as speckle artefacts. Equivalent to the previous PNDM-only path.
    pixels = torch.clamp((pixels + 1.0) * 127.5, 0.0, 255.0)
    pixels = pixels.numpy().astype(np.uint8)
    return PIL.Image.fromarray(pixels[0])


# Default seed shown in the UI. It is drawn once when the script starts, so it
# changes between app restarts but not between page loads.
# 2147483647 == 2**31 - 1, the largest signed 32-bit integer.
random_seed = random.randint(0, 2147483647)
gr.Interface(
    predict,
    inputs=[
        # step=1 makes every integer seed selectable. NOTE(review): without an
        # explicit step, Gradio derives one from the slider range, which for
        # 0..2**31-1 is far coarser than 1 -- confirm on the installed version.
        gr.inputs.Slider(0, 2147483647, step=1, label='Seed', default=random_seed),
        gr.inputs.Radio(["ddim", "ddpm", "pndm"], default="ddpm", label="Diffusion scheduler"),
    ],
    outputs=gr.Image(shape=[256, 256], type="pil"),
).launch()