# T2I-Adapter / test_adapter.py
# Adapter's picture
# support composable adapter (#5)
# b3478e4
import os
import cv2
import torch
from basicsr.utils import tensor2img
from pytorch_lightning import seed_everything
from torch import autocast
from ldm.inference_base import (diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models)
from ldm.modules.extra_condition import api
from ldm.modules.extra_condition.api import (ExtraCondition, get_adapter_feature, get_cond_model)
# This script only runs inference, so turn off autograd tracking globally.
torch.set_grad_enabled(False)
def _read_batch_file(list_path):
    """Parse a batch-test list file into parallel path and prompt lists.

    Every non-blank line must have the form ``<cond_path>; <prompt>``.

    Args:
        list_path: Path of the text file to read.

    Returns:
        A ``(image_paths, prompts)`` tuple of two equally long lists.

    Raises:
        ValueError: If a non-blank line does not contain the ``'; '``
            separator.
    """
    image_paths = []
    prompts = []
    with open(list_path, 'r') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate empty lines instead of crashing on them.
                continue
            # Split on the first separator only, so that a prompt which itself
            # contains '; ' is kept intact rather than truncated.
            cond_path, sep, prompt = line.partition('; ')
            if not sep:
                raise ValueError(
                    f"{list_path}:{line_no}: expected '<cond_path>; <prompt>', got {line!r}")
            image_paths.append(cond_path)
            prompts.append(prompt)
    return image_paths, prompts


def main():
    """Test the adapter of one condition modality and write the images to disk.

    Command-line options come from ``get_base_argument_parser()`` plus the
    required ``--which_cond`` flag, which must be the name of an
    ``ExtraCondition`` member.

    Two test modes are supported:

    * single test: ``--prompt`` is the text prompt and ``--cond_path`` the
      condition input;
    * batch test: ``--prompt`` is the path of a ``.txt`` file whose lines have
      the form ``<cond_path>; <prompt>``.

    For every sample two PNG files are written to ``opt.outdir``:
    ``<index>_<which_cond>.png`` (the condition) and ``<index>_result.png``
    (the generated image).

    Raises:
        ValueError: If a line of the batch file is malformed.
    """
    supported_cond = [e.name for e in ExtraCondition]
    parser = get_base_argument_parser()
    parser.add_argument(
        '--which_cond',
        type=str,
        required=True,
        choices=supported_cond,
        help='which condition modality you want to test',
    )
    opt = parser.parse_args()
    which_cond = opt.which_cond
    cond_type = getattr(ExtraCondition, which_cond)

    if opt.outdir is None:
        opt.outdir = f'outputs/test-{which_cond}'
    os.makedirs(opt.outdir, exist_ok=True)
    if opt.resize_short_edge is None:
        print(f"you don't specify the resize_shot_edge, so the maximum resolution is set to {opt.max_resolution}")
    opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # support two test mode: single image test, and batch test (through a txt file)
    if opt.prompt.endswith('.txt'):
        image_paths, prompts = _read_batch_file(opt.prompt)
    else:
        image_paths = [opt.cond_path]
        prompts = [opt.prompt]
    print(image_paths)

    # prepare models
    sd_model, sampler = get_sd_models(opt)
    adapter = get_adapters(opt, cond_type)
    cond_model = None
    if opt.cond_inp_type == 'image':
        # A condition model is only loaded when the input is a plain image.
        cond_model = get_cond_model(opt, cond_type)

    # Look up the modality-specific preprocessing function by name.
    process_cond_module = getattr(api, f'get_cond_{which_cond}')

    # inference
    # NOTE(review): autocast is pinned to 'cuda' although opt.device may be
    # the CPU -- confirm this is intended for CPU-only runs.
    with torch.inference_mode(), \
            sd_model.ema_scope(), \
            autocast('cuda'):
        for cond_path, prompt in zip(image_paths, prompts):
            # Re-seed once per test case, so every (condition, prompt) pair
            # starts from the same seed wherever it sits in the batch.
            seed_everything(opt.seed)
            for _ in range(opt.n_samples):
                cond = process_cond_module(opt, cond_path, opt.cond_inp_type, cond_model)

                # Two files are written per sample, so half of the directory's
                # entry count is used as the next free index.
                # NOTE(review): assumes outdir only holds files written by
                # this script; stray files would shift the numbering.
                base_count = len(os.listdir(opt.outdir)) // 2
                cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}.png'), tensor2img(cond))

                adapter_features, append_to_context = get_adapter_feature(cond, adapter)
                # In batch mode opt.prompt held the list-file path until now;
                # replace it with the prompt of the current test case.
                opt.prompt = prompt
                result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
                cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(result))
# Only run the test when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()