# T2I-Adapter / test_composable_adapters.py
# support composable adapter (#5) (b3478e4)
import cv2
import os
import torch
from pytorch_lightning import seed_everything
from torch import autocast
from basicsr.utils import tensor2img
from ldm.inference_base import diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models
from ldm.modules.extra_condition import api
from ldm.modules.extra_condition.api import ExtraCondition, get_adapter_feature, get_cond_model
# This script only runs inference, so turn autograd off globally at import time.
torch.set_grad_enabled(mode=False)
def _add_condition_arguments(parser, cond_names):
    """Register the per-condition command-line flags on *parser*.

    For every condition name ``c`` in *cond_names* this adds ``--c_path``,
    ``--c_inp_type``, ``--c_adapter_ckpt`` and ``--c_weight``.
    """
    for cond_name in cond_names:
        parser.add_argument(
            f'--{cond_name}_path',
            type=str,
            default=None,
            help=f'condition image path for {cond_name}',
        )
        parser.add_argument(
            f'--{cond_name}_inp_type',
            type=str,
            default='image',
            help=f'the type of the input condition image, can be image or {cond_name}',
            choices=['image', cond_name],
        )
        parser.add_argument(
            f'--{cond_name}_adapter_ckpt',
            type=str,
            default=None,
            help=f'path to checkpoint of the {cond_name} adapter, '
            f'if {cond_name}_path is not None, this should not be None too',
        )
        parser.add_argument(
            f'--{cond_name}_weight',
            type=float,
            default=1.0,
            help=f'the {cond_name} adapter features are multiplied by the {cond_name}_weight and then summed up together',
        )


def _collect_activated_conds(opt, parser, cond_names):
    """Return ``(activated_conds, cond_paths)`` for the conditions given on the CLI.

    A condition counts as activated when its ``--<cond>_path`` flag is set.
    Every activated condition must also have ``--<cond>_adapter_ckpt``, and at
    least one condition must be activated.

    Violations are reported through ``parser.error`` instead of ``assert``, so
    the checks still run under ``python -O``. This presumes *parser* is an
    ``argparse.ArgumentParser``: it exposes ``add_argument``/``parse_args``
    above, but that type is not confirmed here.
    """
    activated_conds = []
    cond_paths = []
    for cond_name in cond_names:
        cond_path = getattr(opt, f'{cond_name}_path')
        if cond_path is None:
            continue
        if getattr(opt, f'{cond_name}_adapter_ckpt') is None:
            parser.error(f'you should specify the {cond_name}_adapter_ckpt')
        activated_conds.append(cond_name)
        cond_paths.append(cond_path)
    if not activated_conds:
        parser.error('you did not input any condition')
    return activated_conds, cond_paths


def main():
    """Run composable T2I-Adapter inference from the command line.

    Every ``ExtraCondition`` member gets its own set of flags. For each
    condition whose ``--<cond>_path`` is set, the script:

    * loads that condition's adapter;
    * loads a condition model, but only when the input type is ``image``;
    * preprocesses the input with ``api.get_cond_<cond>``.

    The adapter features are then merged by ``get_adapter_feature``.
    ``opt.n_samples`` images are generated and written to ``opt.outdir``.
    """
    supported_cond = [e.name for e in ExtraCondition]
    parser = get_base_argument_parser()
    _add_condition_arguments(parser, supported_cond)
    opt = parser.parse_args()

    # process arguments
    activated_conds, cond_paths = _collect_activated_conds(opt, parser, supported_cond)

    if opt.outdir is None:
        opt.outdir = 'outputs/test-composable-adapters'
    os.makedirs(opt.outdir, exist_ok=True)
    if opt.resize_short_edge is None:
        print(f"you didn't specify the resize_short_edge, so the maximum resolution is set to {opt.max_resolution}")
    opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # prepare models: one adapter per activated condition. A condition model
    # is only needed when the input is a raw image that must be converted.
    adapters = []
    cond_models = []
    cond_inp_types = []
    process_cond_modules = []
    for cond_name in activated_conds:
        cond_type = ExtraCondition[cond_name]
        adapters.append(get_adapters(opt, cond_type))
        cond_inp_type = getattr(opt, f'{cond_name}_inp_type', 'image')
        cond_models.append(get_cond_model(opt, cond_type) if cond_inp_type == 'image' else None)
        cond_inp_types.append(cond_inp_type)
        process_cond_modules.append(getattr(api, f'get_cond_{cond_name}'))

    sd_model, sampler = get_sd_models(opt)

    # inference
    with torch.inference_mode(), sd_model.ema_scope(), autocast('cuda'):
        seed_everything(opt.seed)
        conds = [
            process_cond(opt, cond_path, cond_inp_type, cond_model)
            for process_cond, cond_path, cond_inp_type, cond_model in zip(
                process_cond_modules, cond_paths, cond_inp_types, cond_models,
            )
        ]
        adapter_features, append_to_context = get_adapter_feature(conds, adapters)
        for _ in range(opt.n_samples):
            result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
            # Name each output by the current file count so successive
            # samples (and runs) get increasing indices.
            base_count = len(os.listdir(opt.outdir))
            cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(result))
if __name__ == '__main__':
    # Launch the CLI only when this file is executed directly, not on import.
    main()