|
|
from generator import Generator |
|
|
import json |
|
|
import os |
|
|
import torch |
|
|
import gc |
|
|
from utils.pipelines import * |
|
|
import argparse |
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the image-generation script.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses the real command line (``sys.argv[1:]``).

    Returns:
        argparse.Namespace with ``json_path``, ``out_dir``, ``num_devices``,
        ``batch_size``, ``num_machine``, ``machine_id``, ``pipeline_name``
        (list of str), ``enable_availabel_check`` and ``reverse``.

    Exits via ``parser.error`` (SystemExit) when a required option is
    missing or the machine / batch settings are out of range.
    """
    parser = argparse.ArgumentParser(description="生成图片")
    # json_path, out_dir and pipeline_name have no usable default: main()
    # opens, joins and iterates them. Requiring them turns a late TypeError
    # into an immediate usage message.
    parser.add_argument("--json_path", type=str, required=True, help="json路径")
    parser.add_argument("--out_dir", type=str, required=True, help="输出目录")
    parser.add_argument("--num_devices", type=int, default=8, help="设备数量")
    parser.add_argument("--batch_size", type=int, default=1, help="批量大小")
    parser.add_argument("--num_machine", type=int, default=1, help="机器数量")
    parser.add_argument("--machine_id", type=int, default=0, help="机器id")
    parser.add_argument(
        "--pipeline_name", type=str, nargs="+", required=True, help="pipeline名称"
    )
    # Keep the historical (misspelled) flag working; the correctly spelled
    # alias stores into the same attribute so callers see no difference.
    parser.add_argument(
        "--enable_availabel_check",
        "--enable_available_check",
        dest="enable_availabel_check",
        action="store_true",
    )
    parser.add_argument("--reverse", action="store_true")

    args = parser.parse_args(argv)

    # Validate the sharding settings up front; an out-of-range machine_id
    # would otherwise leave this machine with no work assigned at all.
    if args.num_machine < 1:
        parser.error("--num_machine must be >= 1")
    if not 0 <= args.machine_id < args.num_machine:
        parser.error("--machine_id must be in the range [0, num_machine)")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    return args
|
|
|
|
|
|
|
|
def _resolve_pipelines(names):
    """Map each requested pipeline name to its ``<name>_pipe`` parameter object.

    The parameter objects are looked up among this module's globals, which is
    where ``from utils.pipelines import *`` places them.

    Raises:
        SystemExit: if no names were given or a name has no ``<name>_pipe``.
    """
    if not names:
        raise SystemExit("No pipelines requested: pass --pipeline_name NAME [NAME ...]")
    module_globals = globals()
    resolved = []
    for name in names:
        key = f"{name}_pipe"
        if key not in module_globals:
            raise SystemExit(f"Unknown pipeline {name!r}: no global named {key!r}")
        resolved.append(module_globals[key])
    return resolved


def _load_entries(json_path):
    """Read and return the parsed JSON entry list stored at ``json_path``."""
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _machine_shard(entries, num_machine, machine_id):
    """Return the contiguous slice of ``entries`` assigned to ``machine_id``.

    Chunk ``i`` spans ``[i * n // num_machine, (i + 1) * n // num_machine)``
    with ``n = len(entries)``, so the chunks tile the list exactly and the
    last one ends at ``n``.

    Raises:
        SystemExit: if ``num_machine < 1`` or ``machine_id`` is out of range.
    """
    if num_machine < 1:
        raise SystemExit(f"--num_machine must be >= 1, got {num_machine}")
    if not 0 <= machine_id < num_machine:
        raise SystemExit(
            f"--machine_id must be in [0, {num_machine}), got {machine_id}"
        )
    total = len(entries)
    start = machine_id * total // num_machine
    end = (machine_id + 1) * total // num_machine
    return entries[start:end]


def _build_generator(pipeline_param, num_devices):
    """Construct a ``Generator`` for one pipeline parameter object."""
    return Generator(
        pipe_name=pipeline_param.pipeline_name,
        pipe_type=pipeline_param.pipeline_type,
        pipe_init_kwargs=pipeline_param.pipe_init_kwargs,
        num_devices=num_devices,
    )


def _run_generation(generator, pipeline_param, entries, args, seed=42):
    """Run ``generator.generate`` on ``entries``, writing to ``out_dir/<generation_path>``."""
    generator.generate(
        entries,
        os.path.join(args.out_dir, pipeline_param.generation_path),
        batch_size=args.batch_size,
        num_processes=args.num_devices,
        seed=seed,
        weight_dtype=pipeline_param.pipe_init_kwargs["torch_dtype"],
        generation_kwargs=pipeline_param.generation_kwargs,
        base_resolution=pipeline_param.base_resolution,
        force_aspect_ratio=pipeline_param.force_aspect_ratio,
    )


def _offload_pipelines(generator):
    """Move every pipeline held by ``generator`` to CPU before it is dropped."""
    # getattr keeps this safe if a Generator never populated ``pipelines``.
    for pipeline in getattr(generator, "pipelines", ()):
        pipeline.to("cpu")


def _free_cuda_memory():
    """Collect garbage, then return PyTorch's cached CUDA blocks to the driver.

    ``gc.collect()`` must run first: ``torch.cuda.empty_cache()`` only releases
    blocks whose tensors are already freed, and tensors kept alive by reference
    cycles are freed only by the collector.
    """
    gc.collect()
    torch.cuda.empty_cache()


def main():
    """Generate images for every requested pipeline over this machine's shard.

    1. Parse CLI args and resolve each ``--pipeline_name`` to its
       ``<name>_pipe`` parameter object, reversed when ``--reverse`` is set.
    2. Load the JSON entries once and take this machine's contiguous shard
       (``--machine_id`` of ``--num_machine``), sorted by ``aspect_ratio``.
       Doing this first fails fast on bad input before any model is loaded.
    3. With ``--enable_availabel_check``, run one batch
       (``entries[:batch_size]``) through each pipeline first.
    4. For each pipeline, build a ``Generator``, generate the shard into
       ``out_dir/<generation_path>``, then release its memory before the next.
    """
    import copy

    args = parse_args()
    num_devices = args.num_devices

    pipeline_params = _resolve_pipelines(args.pipeline_name)
    if args.reverse:
        pipeline_params = pipeline_params[::-1]

    entries = _load_entries(args.json_path)
    shard = sorted(
        _machine_shard(entries, args.num_machine, args.machine_id),
        key=lambda entry: entry["aspect_ratio"],
    )

    if args.enable_availabel_check:
        print(f"Checking {len(pipeline_params)} pipelines")
        for pipeline_param in pipeline_params:
            generator = _build_generator(pipeline_param, num_devices)
            # Deep copy so each pipeline sees pristine entries even if
            # generate() mutates the dicts it is given.
            _run_generation(
                generator, pipeline_param, copy.deepcopy(entries[: args.batch_size]), args
            )
            _offload_pipelines(generator)
            # Drop our reference before the next Generator loads, to keep
            # peak GPU memory at one pipeline set.
            del generator
            _free_cuda_memory()
            print(f"Finished Checking {pipeline_param.pipeline_name}")

    if not shard:
        # Nothing assigned to this machine: skip loading every pipeline.
        print(f"No entries assigned to machine {args.machine_id}; nothing to generate")
        return

    for pipeline_param in pipeline_params:
        generator = _build_generator(pipeline_param, num_devices)
        print(f"Generating {len(shard)} images")
        _run_generation(generator, pipeline_param, copy.deepcopy(shard), args)
        print(f"Finished generating {pipeline_param.pipeline_name}")
        _offload_pipelines(generator)
        del generator
        _free_cuda_memory()
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|