File size: 4,530 Bytes
ab2369a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import sys
import time
import torch
import os

import json
import argparse
sys.path.append(os.getcwd())
from diffusers import LDMPipeline, DDIMScheduler, PNDMScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler

def main():
    """Sample images from a pretrained latent-diffusion pipeline and save them.

    Parses CLI arguments, installs the requested scheduler variant on the
    pipeline, then for each seed in [start_index, start_index + test_num)
    generates a batch of images, writes them as PNGs under --save_dir, and
    prints the wall-clock time per batch.
    """
    # NOTE(review): the description says CelebA-HQ but the default model id and
    # the output filenames say cifar10 — confirm which dataset this targets.
    parser = argparse.ArgumentParser(description="sampling script for CelebA-HQ.")
    parser.add_argument('--test_num', type=int, default=1,
                        help='number of seeds (i.e. batches) to sample')
    parser.add_argument('--start_index', type=int, default=0,
                        help='first RNG seed to use')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    # BUG FIX: the original default 'lag' was not in `choices`, so argparse
    # rejected its own default whenever --sampler_type was omitted.
    parser.add_argument('--sampler_type', type=str, default='dpm_lm',
                        choices=['pndm', 'ddim_lm', 'ddim', 'dpm++', 'dpm', 'dpm_lm', 'unipc'])
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--model_id', type=str,
                        default='/xxx/xxx/ddpm_ema_cifar10')
    parser.add_argument('--lamb', type=float, default=1.0,
                        help='lambda passed to the *_lm schedulers')
    parser.add_argument('--kappa', type=float, default=0.0,
                        help='kappa passed to the *_lm schedulers')
    # choices added: an unrecognized dtype previously fell through silently
    # with dtype=None instead of failing fast at the command line.
    parser.add_argument('--dtype', type=str, default='fp32',
                        choices=['fp32', 'fp64', 'fp16', 'bf16'])
    parser.add_argument('--device', type=str, default='cuda')

    args = parser.parse_args()

    # Map the CLI dtype string to a torch dtype; `choices` guarantees a hit.
    dtype = {
        'fp32': torch.float32,
        'fp64': torch.float64,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }[args.dtype]

    start_index = args.start_index
    batch_size = args.batch_size
    sampler_type = args.sampler_type
    test_num = args.test_num
    num_inference_steps = args.num_inference_steps
    device = args.device
    lamb = args.lamb
    kappa = args.kappa
    model_id = args.model_id

    save_dir = args.save_dir
    # exist_ok=True already tolerates a pre-existing directory; the original
    # os.path.exists() pre-check was redundant (and race-prone).
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        # Load the pipeline and move its compute components to the device.
        pipe = LDMPipeline.from_pretrained(model_id, torch_dtype=dtype)
        pipe.unet.to(device)
        pipe.vqvae.to(device)

        # Swap in the requested scheduler, configured from the pipeline's
        # existing scheduler config so noise-schedule settings carry over.
        if sampler_type == 'pndm':
            pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
        elif sampler_type == 'dpm++':
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
            pipe.scheduler.config.solver_order = 3
            pipe.scheduler.config.algorithm_type = "dpmsolver++"
        elif sampler_type == 'dpm_lm':
            pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(pipe.scheduler.config)
            pipe.scheduler.config.solver_order = 3
            pipe.scheduler.config.algorithm_type = "dpmsolver"
            pipe.scheduler.lamb = lamb
            pipe.scheduler.lm = True
            pipe.scheduler.kappa = kappa
        elif sampler_type == 'dpm':
            # Same LM scheduler class but with the LM correction disabled,
            # i.e. plain DPM-Solver behavior.
            pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(pipe.scheduler.config)
            pipe.scheduler.config.solver_order = 3
            pipe.scheduler.config.algorithm_type = "dpmsolver"
            pipe.scheduler.lm = False
        elif sampler_type == 'ddim':
            pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        elif sampler_type == 'ddim_lm':
            pipe.scheduler = DDIMLMScheduler.from_config(pipe.scheduler.config)
            pipe.scheduler.lamb = lamb
            pipe.scheduler.lm = True
            pipe.scheduler.kappa = kappa
        elif sampler_type == 'unipc':
            pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

        for seed in range(start_index, start_index + test_num):
            print('prepare to sample')
            start_time = time.time()
            # Seed per batch so each batch is independently reproducible.
            torch.manual_seed(seed)

            # Sampling process.
            images = pipe(batch_size=batch_size, num_inference_steps=num_inference_steps).images

            # Store the generated images, one PNG per image in the batch.
            for i, image in enumerate(images):
                image.save(
                    os.path.join(save_dir, f"cifar10_{sampler_type}_inference{num_inference_steps}_seed{seed}_{i}.png"))
            print(f"{sampler_type} batch##{seed},done")

            # Report the per-batch sampling time.
            end_time = time.time()
            time_difference = end_time - start_time
            print(f"The code took {time_difference} seconds to run.")

# Script entry point: run the sampling loop only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()