# 王方懿康
# Initial commit
# ab2369a
import sys
import time
import torch
import os
import json
import argparse
sys.path.append(os.getcwd())
from diffusers import LDMPipeline, DDIMScheduler, PNDMScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler
def main() -> None:
    """Sample image batches from an LDM pipeline with a selectable scheduler.

    Parses the command-line options, loads ``LDMPipeline`` from
    ``--model_id``, installs the scheduler named by ``--sampler_type`` and
    then, for every seed in ``[start_index, start_index + test_num)``,
    generates ``batch_size`` images, writes them as PNG files into
    ``--save_dir`` and prints the wall-clock time spent on that seed
    (sampling plus saving).
    """
    parser = argparse.ArgumentParser(description="sampling script for CelebA-HQ.")
    parser.add_argument('--test_num', type=int, default=1)
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    # The default must be one of ``choices``: argparse does not validate
    # defaults against ``choices``, so the former default ('lag') matched no
    # scheduler branch and the pipeline's stock scheduler was used silently.
    parser.add_argument('--sampler_type', type=str, default='ddim',
                        choices=['pndm', 'ddim_lm', 'ddim', 'dpm++', 'dpm', 'dpm_lm', 'unipc'])
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--model_id', type=str,
                        default='/xxx/xxx/ddpm_ema_cifar10')
    parser.add_argument('--lamb', type=float, default=1.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--dtype', type=str, default='fp32')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    # Unrecognised --dtype strings map to None, which is passed through
    # unchanged as ``torch_dtype`` below (same as the original if/elif chain).
    torch_dtypes = {
        'fp32': torch.float32,
        'fp64': torch.float64,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    dtype = torch_dtypes.get(args.dtype)

    sampler_type = args.sampler_type
    num_inference_steps = args.num_inference_steps
    save_dir = args.save_dir

    # exist_ok=True already tolerates an existing directory, so no separate
    # os.path.exists() check is needed.
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        # load pipeline
        pipe = LDMPipeline.from_pretrained(args.model_id, torch_dtype=dtype)
        pipe.unet.to(args.device)
        pipe.vqvae.to(args.device)

        # load scheduler
        _configure_scheduler(pipe, sampler_type, args.lamb, args.kappa)

        first_seed = args.start_index
        for seed in range(first_seed, first_seed + args.test_num):
            print('prepare to sample')
            start_time = time.time()
            # One seed per batch, so a batch can be reproduced from its seed.
            torch.manual_seed(seed)

            # sampling process
            images = pipe(batch_size=args.batch_size,
                          num_inference_steps=num_inference_steps).images

            # store the generated images
            for i, image in enumerate(images):
                file_name = f"cifar10_{sampler_type}_inference{num_inference_steps}_seed{seed}_{i}.png"
                image.save(os.path.join(save_dir, file_name))
            print(f"{sampler_type} batch##{seed},done")

            # output the sampling time-costs
            time_difference = time.time() - start_time
            print(f"The code took {time_difference} seconds to run.")


def _configure_scheduler(pipe, sampler_type: str, lamb: float, kappa: float) -> None:
    """Replace ``pipe.scheduler`` with the scheduler named by *sampler_type*.

    The new scheduler is built from the configuration of the scheduler the
    pipeline was loaded with. ``lamb`` and ``kappa`` are only used by the
    ``*_lm`` variants, where they are stored as scheduler attributes.

    Raises:
        ValueError: if *sampler_type* is not a known sampler name.
    """
    base_config = pipe.scheduler.config

    if sampler_type == 'pndm':
        pipe.scheduler = PNDMScheduler.from_config(base_config)
    elif sampler_type == 'dpm++':
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(base_config)
        # NOTE(review): solver settings are written onto the config after
        # construction, exactly as before - confirm the scheduler reads them
        # lazily rather than caching them in __init__.
        pipe.scheduler.config.solver_order = 3
        pipe.scheduler.config.algorithm_type = "dpmsolver++"
    elif sampler_type in ('dpm_lm', 'dpm'):
        # Both variants share the project-local scheduler class; they differ
        # only in whether the ``lm`` flag is switched on.
        pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(base_config)
        pipe.scheduler.config.solver_order = 3
        pipe.scheduler.config.algorithm_type = "dpmsolver"
        if sampler_type == 'dpm_lm':
            pipe.scheduler.lamb = lamb
            pipe.scheduler.lm = True
            pipe.scheduler.kappa = kappa
        else:
            pipe.scheduler.lm = False
    elif sampler_type == 'ddim':
        pipe.scheduler = DDIMScheduler.from_config(base_config)
    elif sampler_type == 'ddim_lm':
        pipe.scheduler = DDIMLMScheduler.from_config(base_config)
        pipe.scheduler.lamb = lamb
        pipe.scheduler.lm = True
        pipe.scheduler.kappa = kappa
    elif sampler_type == 'unipc':
        pipe.scheduler = UniPCMultistepScheduler.from_config(base_config)
    else:
        # Unreachable through the CLI (argparse ``choices``), but fail loudly
        # instead of silently sampling with whatever scheduler was loaded.
        raise ValueError(f"unknown sampler_type: {sampler_type!r}")
# Script entry point: run the sampler only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()