File size: 4,251 Bytes
ab2369a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import sys
import time
import torch
import os

import json
import argparse
sys.path.append(os.getcwd())
from diffusers import DDPMPipeline, DDIMScheduler, PNDMScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler

# Maps the --dtype CLI value to the torch dtype the pipeline is loaded with.
_DTYPES = {
    'fp32': torch.float32,
    'fp64': torch.float64,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}

_SAMPLER_TYPES = ['pndm', 'ddim', 'dpm++', 'dpm', 'dpm_lm', 'unipc']


def _parse_args(argv=None):
    """Parse the command-line options of the CIFAR-10 sampling script.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="sampling script for CIFAR-10.")
    parser.add_argument('--test_num', type=int, default=1,
                        help='number of batches; batch k is sampled with seed start_index + k')
    parser.add_argument('--start_index', type=int, default=0,
                        help='seed of the first batch')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    # The default must be one of the choices: argparse does not validate
    # defaults against `choices`, and the previous default ('lag') matched no
    # scheduler branch, so the stock scheduler ran under a misleading file name.
    parser.add_argument('--sampler_type', type=str, default='dpm_lm', choices=_SAMPLER_TYPES)
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--model_id', type=str,
                        default='/xxx/xxx/ddpm_ema_cifar10')
    parser.add_argument('--lamb', type=float, default=1.0,
                        help="'lamb' attribute of the LM scheduler (dpm_lm only)")
    parser.add_argument('--kappa', type=float, default=0.0,
                        help="'kappa' attribute of the LM scheduler (dpm_lm only)")
    parser.add_argument('--dtype', type=str, default='fp32', choices=list(_DTYPES))
    parser.add_argument('--device', type=str, default='cuda')
    return parser.parse_args(argv)


def _build_scheduler(sampler_type, base_config, lamb, kappa):
    """Create the scheduler named by ``sampler_type`` from ``base_config``.

    Solver settings are passed as ``from_config`` overrides, so the scheduler's
    constructor sees them and they are recorded in its config. Assigning them on
    the already-built scheduler's config afterwards does neither.

    Args:
        sampler_type: one of ``_SAMPLER_TYPES``.
        base_config: config of the pipeline's original scheduler.
        lamb, kappa: values stored on the LM scheduler for ``'dpm_lm'``.

    Returns:
        The new scheduler instance.

    Raises:
        ValueError: if ``sampler_type`` is not a known sampler.
    """
    if sampler_type == 'pndm':
        return PNDMScheduler.from_config(base_config)
    if sampler_type == 'ddim':
        return DDIMScheduler.from_config(base_config)
    if sampler_type == 'unipc':
        return UniPCMultistepScheduler.from_config(base_config)
    if sampler_type == 'dpm++':
        return DPMSolverMultistepScheduler.from_config(
            base_config, solver_order=3, algorithm_type="dpmsolver++")
    if sampler_type in ('dpm', 'dpm_lm'):
        # 'dpm' is the same LM scheduler class with the LM flag switched off.
        scheduler = DPMSolverMultistepLMScheduler.from_config(
            base_config, solver_order=3, algorithm_type="dpmsolver")
        if sampler_type == 'dpm_lm':
            scheduler.lamb = lamb
            scheduler.lm = True
            scheduler.kappa = kappa
        else:
            scheduler.lm = False
        return scheduler
    raise ValueError(f"unknown sampler_type: {sampler_type!r}")


def main():
    """Sample CIFAR-10 images with the selected scheduler and save them as PNGs.

    For every seed in ``[start_index, start_index + test_num)`` one batch of
    ``batch_size`` images is generated with ``num_inference_steps`` steps and
    written to ``save_dir`` as
    ``cifar10_{sampler_type}_inference{num_inference_steps}_seed{seed}_{i}.png``.
    The wall-clock time of each batch (sampling plus saving) is printed.
    """
    args = _parse_args()
    dtype = _DTYPES[args.dtype]  # always present: --dtype is restricted by `choices`
    sampler_type = args.sampler_type
    num_inference_steps = args.num_inference_steps
    save_dir = args.save_dir

    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        # load pipeline
        pipe = DDPMPipeline.from_pretrained(args.model_id, torch_dtype=dtype)
        pipe.to(args.device)

        # load scheduler
        pipe.scheduler = _build_scheduler(sampler_type, pipe.scheduler.config, args.lamb, args.kappa)

        for seed in range(args.start_index, args.start_index + args.test_num):
            print('prepare to sample')
            start_time = time.perf_counter()
            # No generator is passed to the pipeline, so its initial noise
            # presumably comes from torch's global RNG seeded here.
            torch.manual_seed(seed)

            # sampling process
            images = pipe(batch_size=args.batch_size, num_inference_steps=num_inference_steps).images

            # store the generated images
            for i, image in enumerate(images):
                image.save(
                    os.path.join(save_dir, f"cifar10_{sampler_type}_inference{num_inference_steps}_seed{seed}_{i}.png"))
            print(f"{sampler_type} batch##{seed},done")

            # output the sampling time-costs
            time_difference = time.perf_counter() - start_time
            print(f"The code took {time_difference} seconds to run.")

# Run the sampling script only when executed directly, not when imported.
if __name__ == "__main__":
    main()