# File size: 6,563 Bytes
# ab2369a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import sys

import torch
import os
import json
import argparse

sys.path.append(os.getcwd())

from diffusers import StableDiffusionPipeline, PNDMScheduler, UniPCMultistepScheduler, DDIMScheduler, DiffusionPipeline, PixArtAlphaPipeline
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler

from tqdm import tqdm

def main():
    """Sample T2I-CompBench prompts with a configurable diffusion sampler.

    Parses CLI options, loads the requested pipeline, swaps in the requested
    scheduler, then generates ``test_num`` images per prompt (one per seed)
    and saves them under ``save_dir/<model>/<category>/<sampler>/samples``.
    """
    parser = argparse.ArgumentParser(description="sampling script for T2I-Bench.")
    parser.add_argument('--test_num', type=int, default=10)        # images (seeds) per prompt
    parser.add_argument('--start_index', type=int, default=0)      # first seed value
    parser.add_argument('--num_inference_steps', type=int, default=10)
    parser.add_argument('--guidance', type=float, default=7.5)     # classifier-free guidance scale
    parser.add_argument('--sampler_type', type=str, default='dpm_lm')
    parser.add_argument('--model', type=str, default='sd15', choices=['sd15', 'sd2_base', 'sdxl', 'pixart'])
    parser.add_argument('--model_dir', type=str, default='XXX')
    parser.add_argument('--save_dir', type=str, default='results/')
    parser.add_argument('--lamb', type=float, default=5.0)         # LM-scheduler lambda
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--freeze', type=float, default=0.0)
    parser.add_argument('--dataset_category', type=str, default="color")
    parser.add_argument('--dataset_path', type=str, default="../T2I-CompBench-main")
    # Fix: constrain --dtype. Previously an unrecognized value fell through the
    # if/elif chain, leaving dtype=None, which was then passed to from_pretrained.
    parser.add_argument('--dtype', type=str, default='fp32', choices=['fp32', 'fp64', 'fp16', 'bf16'])
    parser.add_argument('--device', type=str, default='cuda')

    args = parser.parse_args()

    # Dict lookup replaces the if/elif chain; `choices` guarantees the key exists.
    dtype = {
        'fp32': torch.float32,
        'fp64': torch.float64,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }[args.dtype]

    device = args.device
    start_index = args.start_index
    sampler_type = args.sampler_type
    test_num = args.test_num
    guidance_scale = args.guidance
    num_inference_steps = args.num_inference_steps
    lamb = args.lamb
    freeze = args.freeze
    kappa = args.kappa
    model_dir = args.model_dir

    # Load the pipeline. The four original branches were identical apart from
    # the class and the log message, so collapse them into a lookup table.
    # NOTE(review): `safety_checker=None` is only meaningful for the SD
    # pipelines; diffusers warns about and ignores unknown kwargs for the rest.
    PIPELINES = {
        'sd15': (StableDiffusionPipeline, "sd-1.5 model loaded"),
        'sd2_base': (StableDiffusionPipeline, "sd-2-base model loaded"),
        'sdxl': (DiffusionPipeline, "sd-xl-base model loaded"),
        'pixart': (PixArtAlphaPipeline, "PixArt-XL-2-512x512 model loaded"),
    }
    pipe_cls, loaded_msg = PIPELINES[args.model]  # --model choices match the keys
    sd_pipe = pipe_cls.from_pretrained(
        model_dir,
        torch_dtype=dtype, safety_checker=None)
    sd_pipe = sd_pipe.to(device)
    print(loaded_msg)

    # Scheduler class + the parameters to push onto it, keyed by sampler name.
    SAMPLER_CONFIG = {
        'dpm_lm': {
            'scheduler': DPMSolverMultistepLMScheduler,
            'params': {'solver_order': 3, 'algorithm_type': "dpmsolver", 'lm': True, 'lamb': lamb, 'kappa': kappa, 'freeze': freeze}
        },
        'dpm': {
            'scheduler': DPMSolverMultistepLMScheduler,
            'params': {'solver_order': 3, 'algorithm_type': "dpmsolver", 'lm': False}
        },
        'dpm++': {
            'scheduler': DPMSolverMultistepLMScheduler,
            'params': {'solver_order': 3, 'algorithm_type': "dpmsolver++", 'lm': False}
        },
        'dpm++_lm': {
            'scheduler': DPMSolverMultistepLMScheduler,
            'params': {'solver_order': 3, 'algorithm_type': "dpmsolver++", 'lm': True, 'lamb': lamb, 'kappa': kappa, 'freeze': freeze}
        },
        'pndm': {'scheduler': PNDMScheduler, 'params': {}},
        'ddim': {'scheduler': DDIMScheduler, 'params': {}},
        'ddim_lm': {
            'scheduler': DDIMLMScheduler,
            'params': {'lm': True, 'lamb': lamb, 'kappa': kappa, 'freeze': freeze}
        },
        'unipc': {'scheduler': UniPCMultistepScheduler, 'params': {}},
    }

    # Guard clause instead of wrapping the happy path in an if-block.
    if sampler_type not in SAMPLER_CONFIG:
        raise ValueError(f"invalid: '{sampler_type}'.")
    config = SAMPLER_CONFIG[sampler_type]
    sd_pipe.scheduler = config['scheduler'].from_config(sd_pipe.scheduler.config)
    # Set each sampler parameter on the scheduler instance when it exists there,
    # otherwise on its config object (some attributes only live on the config).
    for param, value in config['params'].items():
        if hasattr(sd_pipe.scheduler, param):
            setattr(sd_pipe.scheduler, param, value)
        elif hasattr(sd_pipe.scheduler.config, param):
            setattr(sd_pipe.scheduler.config, param, value)

    # Output layout: save_dir/<model>/<category>/<sampler>[_lambda_X]/samples
    if sampler_type in ('ddim_lm', 'dpm++_lm', 'dpm_lm'):
        leaf = sampler_type + "_lambda_" + str(lamb)
    else:
        leaf = sampler_type
    save_dir = os.path.join(args.save_dir, args.model, args.dataset_category, leaf, "samples")
    # Fix: exist_ok=True already tolerates an existing directory, so the
    # previous os.path.exists() pre-check was redundant (and race-prone).
    os.makedirs(save_dir, exist_ok=True)

    def getT2IDataset(file_path):
        """Yield non-empty, stripped prompt lines from a text file."""
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                stripped_line = line.strip()
                if stripped_line:
                    yield stripped_line

    # T2I-CompBench validation prompt file for the chosen category.
    dataset_path = os.path.join(args.dataset_path, 'examples/dataset', args.dataset_category + '_val.txt')
    count = 0
    # NOTE(review): the total assumes 300 prompts per *_val.txt file — TODO
    # confirm against the dataset; the bar is cosmetic either way.
    with tqdm(total=300 * test_num, desc="Generating Images") as pbar:
        try:
            for prompt in getT2IDataset(dataset_path):
                # Fix: a '/' inside a prompt would be treated as a directory
                # separator and make res.save() fail. Only '/' is replaced so
                # downstream evaluation can still match filenames to prompts.
                fname_prompt = prompt.replace('/', '_')
                for seed in range(start_index, start_index + test_num):
                    # Seed the global torch RNG; generator=None makes the
                    # pipeline fall back to that global RNG, so each
                    # (prompt, seed) pair is reproducible.
                    torch.manual_seed(seed)
                    res = sd_pipe(prompt=prompt, num_inference_steps=num_inference_steps,
                            guidance_scale=guidance_scale, generator=None).images[0]
                    res.save(os.path.join(save_dir, f"{fname_prompt}_{count:06d}.png"))
                    count += 1
                    pbar.update(1)
        except FileNotFoundError:
            print(f"dataset can not be found: {dataset_path}")
        except Exception as e:
            # Deliberate best-effort catch-all (kept from the original): report
            # the error and fall through so the summary line still prints.
            print(f"unknown error: {str(e)}")
    print(f"{dataset_path} finish")

# Run the sampling script only when executed directly, not when imported.
if __name__ == '__main__':
    main()