# File size: 6,563 Bytes
# rev: ab2369a
import sys
import torch
import os
import json
import argparse
sys.path.append(os.getcwd())
from diffusers import StableDiffusionPipeline, PNDMScheduler, UniPCMultistepScheduler, DDIMScheduler, DiffusionPipeline, PixArtAlphaPipeline
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler
from tqdm import tqdm
# Characters that are illegal (or hazardous) in file names on common
# filesystems; prompts are free text and may contain any of them.
_FILENAME_BAD_CHARS = '/\\:*?"<>|'

# Supported precision flags mapped to torch dtypes.
_DTYPE_MAP = {
    'fp32': torch.float32,
    'fp64': torch.float64,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}


def _sanitize_filename(text):
    """Return *text* with filename-unsafe characters replaced by '_'.

    NOTE(review): the saved filenames embed the raw prompt; if a downstream
    evaluator matches files by exact prompt text, confirm it tolerates this
    substitution (it only fires for prompts that would otherwise crash save()).
    """
    return ''.join('_' if c in _FILENAME_BAD_CHARS else c for c in text)


def _load_pipeline(model, model_dir, dtype, device):
    """Load the diffusion pipeline selected by *model* and move it to *device*.

    The sd15/sd2_base branches in the original were byte-identical except for
    the log message, so both map to StableDiffusionPipeline here.
    """
    pipeline_classes = {
        'sd15': (StableDiffusionPipeline, "sd-1.5 model loaded"),
        'sd2_base': (StableDiffusionPipeline, "sd-2-base model loaded"),
        'sdxl': (DiffusionPipeline, "sd-xl-base model loaded"),
        'pixart': (PixArtAlphaPipeline, "PixArt-XL-2-512x512 model loaded"),
    }
    cls, message = pipeline_classes[model]
    pipe = cls.from_pretrained(model_dir, torch_dtype=dtype, safety_checker=None)
    pipe = pipe.to(device)
    print(message)
    return pipe


def _configure_scheduler(sd_pipe, sampler_type, lamb, kappa, freeze):
    """Replace the pipeline's scheduler with the one named by *sampler_type*.

    Scheduler-specific parameters are applied as attributes when the scheduler
    instance exposes them, falling back to its config object otherwise.

    Raises:
        ValueError: if *sampler_type* is not a known sampler name.
    """
    sampler_config = {
        'dpm_lm': {
            'scheduler': DPMSolverMultistepLMScheduler,
            'params': {'solver_order': 3, 'algorithm_type': "dpmsolver",
                       'lm': True, 'lamb': lamb, 'kappa': kappa, 'freeze': freeze},
        },
        'dpm': {
            'scheduler': DPMSolverMultistepLMScheduler,
            'params': {'solver_order': 3, 'algorithm_type': "dpmsolver", 'lm': False},
        },
        'dpm++': {
            'scheduler': DPMSolverMultistepLMScheduler,
            'params': {'solver_order': 3, 'algorithm_type': "dpmsolver++", 'lm': False},
        },
        'dpm++_lm': {
            'scheduler': DPMSolverMultistepLMScheduler,
            'params': {'solver_order': 3, 'algorithm_type': "dpmsolver++",
                       'lm': True, 'lamb': lamb, 'kappa': kappa, 'freeze': freeze},
        },
        'pndm': {'scheduler': PNDMScheduler, 'params': {}},
        'ddim': {'scheduler': DDIMScheduler, 'params': {}},
        'ddim_lm': {
            'scheduler': DDIMLMScheduler,
            'params': {'lm': True, 'lamb': lamb, 'kappa': kappa, 'freeze': freeze},
        },
        'unipc': {'scheduler': UniPCMultistepScheduler, 'params': {}},
    }
    if sampler_type not in sampler_config:
        raise ValueError(f"invalid: '{sampler_type}'.")
    config = sampler_config[sampler_type]
    sd_pipe.scheduler = config['scheduler'].from_config(sd_pipe.scheduler.config)
    for param, value in config['params'].items():
        if hasattr(sd_pipe.scheduler, param):
            setattr(sd_pipe.scheduler, param, value)
        elif hasattr(sd_pipe.scheduler.config, param):
            setattr(sd_pipe.scheduler.config, param, value)


def _build_save_dir(save_root, model, category, sampler_type, lamb):
    """Compose and create the output directory for generated samples."""
    if sampler_type in ('ddim_lm', 'dpm++_lm', 'dpm_lm'):
        leaf = sampler_type + "_lambda_" + str(lamb)
    else:
        leaf = sampler_type
    save_dir = os.path.join(save_root, model, category, leaf, "samples")
    # exist_ok=True makes a prior existence check redundant.
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def _read_prompts(file_path):
    """Read the prompt file, returning one stripped non-empty prompt per line.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file if line.strip()]


def main():
    """Sample images for the T2I-CompBench prompt set with a chosen sampler.

    Parses CLI arguments, loads the selected diffusion pipeline, installs the
    requested scheduler, then generates `test_num` seeded images per prompt
    into `<save_dir>/<model>/<category>/<sampler>/samples/`.
    """
    parser = argparse.ArgumentParser(description="sampling script for T2I-Bench.")
    parser.add_argument('--test_num', type=int, default=10)
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--num_inference_steps', type=int, default=10)
    parser.add_argument('--guidance', type=float, default=7.5)
    parser.add_argument('--sampler_type', type=str, default='dpm_lm')
    parser.add_argument('--model', type=str, default='sd15',
                        choices=['sd15', 'sd2_base', 'sdxl', 'pixart'])
    parser.add_argument('--model_dir', type=str, default='XXX')
    parser.add_argument('--save_dir', type=str, default='results/')
    parser.add_argument('--lamb', type=float, default=5.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--freeze', type=float, default=0.0)
    parser.add_argument('--dataset_category', type=str, default="color")
    parser.add_argument('--dataset_path', type=str, default="../T2I-CompBench-main")
    # choices= rejects typos up front instead of silently passing dtype=None.
    parser.add_argument('--dtype', type=str, default='fp32',
                        choices=sorted(_DTYPE_MAP))
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    dtype = _DTYPE_MAP[args.dtype]
    sd_pipe = _load_pipeline(args.model, args.model_dir, dtype, args.device)
    _configure_scheduler(sd_pipe, args.sampler_type, args.lamb, args.kappa,
                         args.freeze)
    save_dir = _build_save_dir(args.save_dir, args.model, args.dataset_category,
                               args.sampler_type, args.lamb)

    # T2I prompts
    dataset_path = os.path.join(args.dataset_path, 'examples/dataset',
                                args.dataset_category + '_val.txt')
    try:
        prompts = _read_prompts(dataset_path)
    except FileNotFoundError:
        print(f"dataset can not be found: {dataset_path}")
        return

    count = 0
    # Total reflects the actual dataset size instead of a hard-coded 300.
    with tqdm(total=len(prompts) * args.test_num, desc="Generating Images") as pbar:
        try:
            for prompt in prompts:
                for seed in range(args.start_index, args.start_index + args.test_num):
                    # Seed the global RNG per image for reproducible sampling.
                    torch.manual_seed(seed)
                    image = sd_pipe(prompt=prompt,
                                    num_inference_steps=args.num_inference_steps,
                                    guidance_scale=args.guidance).images[0]
                    fname = f"{_sanitize_filename(prompt)}_{count:06d}.png"
                    image.save(os.path.join(save_dir, fname))
                    count += 1
                    pbar.update(1)
        except Exception as e:
            # Best-effort batch job: report and fall through so partial
            # results are kept (mirrors the original behavior).
            print(f"unknown error: {str(e)}")
    print(f"{dataset_path} finish")
# Script entry guard: run the sampler only when executed directly, not on
# import. (Removed a stray trailing '|' artifact that broke the syntax.)
if __name__ == '__main__':
    main()