import argparse
import hashlib
import json
import os.path

import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers.models import T2IAdapter
from PIL import Image

from mixofshow.pipelines.pipeline_regionally_t2iadapter import RegionallyT2IAdapterPipeline


def sample_image(pipe,
                 input_prompt,
                 input_neg_prompt=None,
                 generator=None,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 sketch_adaptor_weight=1.0,
                 region_sketch_adaptor_weight='',
                 keypose_adaptor_weight=1.0,
                 region_keypose_adaptor_weight='',
                 **extra_kargs):
    '''Run the pipeline once and return the generated images.

    Args:
        pipe: callable pipeline; it is invoked with the keyword arguments
            assembled below and must return an object exposing ``.images``.
        input_prompt: sequence of prompts; its length decides how many times
            each condition image is repeated.
        input_neg_prompt: negative prompt(s), forwarded as ``negative_prompt``.
        generator: random generator forwarded to the pipeline.
        num_inference_steps: forwarded to the pipeline.
        guidance_scale: forwarded to the pipeline.
        sketch_adaptor_weight: forwarded to the pipeline.
        region_sketch_adaptor_weight: forwarded to the pipeline.
        keypose_adaptor_weight: forwarded to the pipeline.
        region_keypose_adaptor_weight: forwarded to the pipeline.
        **extra_kargs: may contain ``keypose_condition`` and
            ``sketch_condition`` (each a condition image or None); both are
            removed here and every remaining entry is forwarded unchanged.

    Returns:
        The ``images`` attribute of the pipeline output.
    '''
    # Both conditions are optional: a missing key means "no condition" rather
    # than a KeyError.
    keypose_condition = extra_kargs.pop('keypose_condition', None)
    sketch_condition = extra_kargs.pop('sketch_condition', None)

    # The pipeline receives one condition image per prompt.
    num_prompts = len(input_prompt)
    keypose_adapter_input = None if keypose_condition is None else [keypose_condition] * num_prompts
    sketch_adapter_input = None if sketch_condition is None else [sketch_condition] * num_prompts

    return pipe(
        prompt=input_prompt,
        negative_prompt=input_neg_prompt,
        keypose_adapter_input=keypose_adapter_input,
        keypose_adaptor_weight=keypose_adaptor_weight,
        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
        sketch_adapter_input=sketch_adapter_input,
        sketch_adaptor_weight=sketch_adaptor_weight,
        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
        generator=generator,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        **extra_kargs).images


def build_model(pretrained_model, device):
    '''Build the regionally controllable T2I-Adapter pipeline.

    Args:
        pretrained_model: directory holding the pipeline weights and a
            ``new_concept_cfg.json`` file.
        device: device the pipeline and both adapters are moved to.

    Returns:
        The pipeline with the concept config applied, a
        ``DPMSolverMultistepScheduler`` and the keypose / sketch adapters
        attached.

    Raises:
        FileNotFoundError: if ``new_concept_cfg.json`` is missing from
            ``pretrained_model``.
    '''
    # Validate and read the concept config before loading any weights, so a
    # wrong path fails immediately (and also when asserts are disabled by -O).
    cfg_path = os.path.join(pretrained_model, 'new_concept_cfg.json')
    if not os.path.exists(cfg_path):
        raise FileNotFoundError(f'new_concept_cfg.json not found in: {pretrained_model}')
    with open(cfg_path, 'r', encoding='utf-8') as json_file:
        new_concept_cfg = json.load(json_file)

    pipe = RegionallyT2IAdapterPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16).to(device)
    pipe.set_new_concept_cfg(new_concept_cfg)
    pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_model, subfolder='scheduler')
    pipe.keypose_adapter = T2IAdapter.from_pretrained(
        'TencentARC/t2iadapter_openpose_sd14v1', torch_dtype=torch.float16).to(device)
    pipe.sketch_adapter = T2IAdapter.from_pretrained(
        'TencentARC/t2iadapter_sketch_sd14v1', torch_dtype=torch.float16).to(device)
    return pipe


def prepare_text(prompt, region_prompts, height, width):
    '''Parse the region description string into normalized region entries.

    Args:
        prompt: global prompt, returned unchanged.
        region_prompts: regions separated by ``|``; each region is written as
            ``[region prompt]-*-[region negative prompt]-*-[position]``.
            ``position`` is a Python list literal, either empty (``[]``) or
            four numbers. Parsing stops at the first empty region.
        height: value that position items 0 and 2 are divided by.
        width: value that position items 1 and 3 are divided by.

    Returns:
        ``(prompt, region_collection)`` where ``region_collection`` is a list
        of ``(region_prompt, region_negative_prompt, position)`` tuples. An
        empty position becomes ``[0, 0, 1, 1]``.

    Raises:
        ValueError: if a region does not consist of exactly three ``-*-``
            separated fields.
    '''
    region_collection = []
    for region in region_prompts.split('|'):
        if region == '':
            break

        fields = region.split('-*-')
        if len(fields) != 3:
            raise ValueError('each region must be "[prompt]-*-[negative prompt]-*-[position]", '
                             f'got: {region!r}')
        prompt_region, neg_prompt_region, pos = fields
        prompt_region = prompt_region.replace('[', '').replace(']', '')
        neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')

        # SECURITY NOTE(review): eval() executes arbitrary code taken from the
        # --prompt_rewrite argument. Kept so that existing inputs behave the
        # same; consider ast.literal_eval if only plain literals are needed.
        # list() makes a tuple literal usable for the item assignment below.
        pos = list(eval(pos))
        if len(pos) == 0:
            pos = [0, 0, 1, 1]
        else:
            # Presumably [y0, x0, y1, x1] in pixels (items 0/2 are scaled by
            # height, items 1/3 by width) -- TODO confirm against the pipeline.
            pos[0], pos[2] = pos[0] / height, pos[2] / height
            pos[1], pos[3] = pos[1] / width, pos[3] / width

        region_collection.append((prompt_region, neg_prompt_region, pos))
    return (prompt, region_collection)


def parse_args():
    '''Parse the command line options of this script.'''
    parser = argparse.ArgumentParser('', add_help=False)
    parser.add_argument(
        '--pretrained_model',
        default='experiments/composed_edlora/anythingv4/hina+kario+tezuka+mitsuha+son_anythingv4/combined_model_base',
        type=str)
    parser.add_argument('--sketch_condition', default=None, type=str)
    parser.add_argument('--sketch_adaptor_weight', default=1.0, type=float)
    parser.add_argument('--region_sketch_adaptor_weight', default='', type=str)
    parser.add_argument('--keypose_condition', default=None, type=str)
    parser.add_argument('--keypose_adaptor_weight', default=1.0, type=float)
    parser.add_argument('--region_keypose_adaptor_weight', default='', type=str)
    parser.add_argument('--save_dir', default=None, type=str)
    parser.add_argument('--prompt', default='photo of a toy', type=str)
    parser.add_argument('--negative_prompt', default='', type=str)
    parser.add_argument('--prompt_rewrite', default='', type=str)
    parser.add_argument('--seed', default=16141, type=int)
    parser.add_argument('--suffix', default='', type=str)
    return parser.parse_args()


def _load_condition(path, mode, name):
    '''Load an optional condition image; return ``(image, width, height)``.

    Returns ``(None, 0, 0)`` when ``path`` is None or does not exist.
    '''
    if path is not None and os.path.exists(path):
        # convert() returns a new image, so the source file can be closed
        # right away instead of staying open for the rest of the run.
        with Image.open(path) as source:
            condition = source.convert(mode)
        width, height = condition.size
        print(f'use {name} condition')
        return condition, width, height

    print(f'skip {name} condition')
    return None, 0, 0


def main():
    '''Sample one image from the command line options and save it with its config.'''
    args = parse_args()
    # The output directory is needed at the very end; check it before the
    # model is loaded and the sampling is run.
    if args.save_dir is None:
        raise SystemExit('--save_dir is required')

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    pipe = build_model(args.pretrained_model, device)

    sketch_condition, width_sketch, height_sketch = _load_condition(args.sketch_condition, 'L', 'sketch')
    keypose_condition, width_pose, height_pose = _load_condition(args.keypose_condition, 'RGB', 'pose')

    if width_sketch != 0 and width_pose != 0:
        if width_sketch != width_pose or height_sketch != height_pose:
            raise ValueError('conditions should be same size')
    # NOTE(review): both values are 0 when no condition image is given; region
    # positions in --prompt_rewrite would then divide by zero in prepare_text.
    width, height = max(width_pose, width_sketch), max(height_pose, height_sketch)

    kwargs = {
        'sketch_condition': sketch_condition,
        'keypose_condition': keypose_condition,
        'height': height,
        'width': width,
    }

    prompts = [args.prompt]
    prompts_rewrite = [args.prompt_rewrite]
    input_prompt = [prepare_text(p, p_w, height, width) for p, p_w in zip(prompts, prompts_rewrite)]
    save_prompt = input_prompt[0][0]

    image = sample_image(
        pipe,
        input_prompt=input_prompt,
        input_neg_prompt=[args.negative_prompt] * len(input_prompt),
        generator=torch.Generator(device).manual_seed(args.seed),
        sketch_adaptor_weight=args.sketch_adaptor_weight,
        region_sketch_adaptor_weight=args.region_sketch_adaptor_weight,
        keypose_adaptor_weight=args.keypose_adaptor_weight,
        region_keypose_adaptor_weight=args.region_keypose_adaptor_weight,
        **kwargs)

    print(f'save to: {args.save_dir}')

    # These lines are written next to the image and also feed the hash in the
    # file name, so their text and order must stay stable.
    configs = [
        f'pretrained_model: {args.pretrained_model}\n',
        f'context_prompt: {args.prompt}\n',
        f'neg_context_prompt: {args.negative_prompt}\n',
        f'sketch_condition: {args.sketch_condition}\n',
        f'sketch_adaptor_weight: {args.sketch_adaptor_weight}\n',
        f'region_sketch_adaptor_weight: {args.region_sketch_adaptor_weight}\n',
        f'keypose_condition: {args.keypose_condition}\n',
        f'keypose_adaptor_weight: {args.keypose_adaptor_weight}\n',
        f'region_keypose_adaptor_weight: {args.region_keypose_adaptor_weight}\n',
        f'random seed: {args.seed}\n',
        f'prompt_rewrite: {args.prompt_rewrite}\n',
    ]
    hash_code = hashlib.sha256(''.join(configs).encode('utf-8')).hexdigest()[:8]

    save_prompt = save_prompt.replace(' ', '_')
    # Build both file names from one stem so that a '.png' inside the prompt
    # or suffix is never rewritten.
    save_stem = f'{save_prompt}---{args.suffix}(seed{args.seed})---{hash_code}'
    save_dir = args.save_dir
    save_path = os.path.join(save_dir, f'{save_stem}.png')
    save_config_path = os.path.join(save_dir, f'{save_stem}.txt')

    os.makedirs(save_dir, exist_ok=True)
    image[0].save(save_path)

    # Same encoding as the one used for the hash above, independent of locale.
    with open(save_config_path, 'w', encoding='utf-8') as fw:
        fw.writelines(configs)


if __name__ == '__main__':
    main()