Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,989 Bytes
cb0e352 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
"""
Script for scaled cross-reference customization inference with Calligrapher.
"""
import os
import json
import random
from PIL import Image
import numpy as np
from datetime import datetime
import torch
from diffusers.utils import load_image
from pipeline_calligrapher import CalligrapherPipeline
from models.calligrapher import Calligrapher
from models.transformer_flux_inpainting import FluxTransformer2DModel
from utils import resize_img_and_pad, generate_context_reference_image
def infer_calligrapher(test_image_dir, result_save_dir,
                       target_h=512, target_w=512,
                       gen_num_per_case=2, model_paths=None):
    """Run cross-reference customization inference over a benchmark directory.

    The function reads ``cross_bench.txt`` from ``test_image_dir``. Each
    non-empty line has the form ``<id>-<ref_id>[,<ref_id>...]-<text>``. For
    every file in ``test_image_dir`` whose name contains ``source.png`` and
    whose id appears in the bench file, and for every listed reference id,
    ``gen_num_per_case`` images are generated with random seeds.

    Outputs are written to a new timestamped sub-directory of
    ``result_save_dir``: the cropped result as ``result_*.png`` and a
    side-by-side (source | reference context | result) strip as
    ``vis_result_*.jpg``.

    Args:
        test_image_dir: Directory holding the source / mask / ref images and
            ``cross_bench.txt``.
        result_save_dir: Parent directory for the timestamped output folder.
        target_h: Height the source and mask images are resized to.
        target_w: Width the source and mask images are resized to.
        gen_num_per_case: Number of samples generated per (source, reference)
            pair.
        model_paths: Optional mapping with the keys ``base_model_path``,
            ``image_encoder_path`` and ``calligrapher_path``. When omitted,
            the module-level ``path_dict`` (set by the ``__main__`` entry
            point) is used, which is the original behavior.

    Raises:
        NameError: If ``model_paths`` is omitted and no module-level
            ``path_dict`` exists.
        ValueError: If a bench line or a source file name does not follow the
            expected format.
    """
    # Set job dir.
    job_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    result_save_path = os.path.join(result_save_dir, job_name)
    os.makedirs(result_save_path, exist_ok=True)

    # Load models.
    if model_paths is None:
        # Fall back to the global populated when this file runs as a script.
        model_paths = path_dict
    base_model_path = model_paths['base_model_path']
    image_encoder_path = model_paths['image_encoder_path']
    calligrapher_path = model_paths['calligrapher_path']
    transformer = FluxTransformer2DModel.from_pretrained(
        base_model_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = CalligrapherPipeline.from_pretrained(
        base_model_path, transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")
    model = Calligrapher(pipe, image_encoder_path, calligrapher_path,
                         device="cuda", num_tokens=128)

    source_image_names = [
        name for name in os.listdir(test_image_dir) if 'source.png' in name
    ]

    # Loading prompts from the bench txt and printing them.
    info_dict = {}
    with open(os.path.join(test_image_dir, 'cross_bench.txt'), 'r') as file:
        for line in file:
            line = line.strip()
            if line:
                key, value = line.split('-', 1)
                info_dict[int(key)] = value
    print(info_dict)
    print('Printing given prompts...')
    for sample_no, img_id in enumerate(sorted(info_dict), start=1):
        print(f'Sample #{sample_no}: {img_id} - {info_dict[img_id]}')

    for source_image_name in sorted(source_image_names):
        img_id = int(source_image_name.split("test")[1].split("_")[0])
        if img_id not in info_dict:
            continue
        # Split only on the first '-' so the text itself may contain hyphens.
        ref_ids, text = info_dict[img_id].split('-', 1)
        ref_ids = [ref_id.strip() for ref_id in ref_ids.split(',')]
        prompt = f"The text is '{text}'."
        # File-name-safe prompt fragment; depends only on the prompt.
        safe_prompt = prompt.replace(" ", "_").replace("'", "").replace(",", "").replace('"', '').replace('?', '')[:50]
        source_image_path = os.path.join(test_image_dir, source_image_name)
        mask_image_name = source_image_name.replace('source', 'mask')
        mask_image_path = os.path.join(test_image_dir, mask_image_name)

        for ref_id in ref_ids:
            # NOTE(review): this replaces every occurrence of the id digits in
            # the file name, not just the id token — confirm the naming scheme
            # never repeats those digits elsewhere.
            reference_image_name = source_image_name.replace('source', 'ref').replace(f'{img_id}', f'{ref_id}')
            reference_image_path = os.path.join(test_image_dir, reference_image_name)
            print('source_image_path: ', source_image_path)
            print('mask_image_path: ', mask_image_path)
            print('reference_image_path: ', reference_image_path)
            print(f'prompt: {prompt}')

            source_image = load_image(source_image_path)
            mask_image = load_image(mask_image_path)
            # Resize source and mask.
            source_image = source_image.resize((target_w, target_h))
            mask_image = mask_image.resize((target_w, target_h), Image.NEAREST)
            # Binarize the mask: every non-zero value becomes 255.
            mask_np = np.array(mask_image)
            mask_np[mask_np > 0] = 255
            mask_image = Image.fromarray(mask_np.astype(np.uint8))
            source_img_w, source_img_h = source_image.size

            # resize reference to fit CLIP.
            reference_image = Image.open(reference_image_path).convert("RGB")
            reference_image_to_encoder = resize_img_and_pad(reference_image, target_size=[512, 512])
            reference_context = generate_context_reference_image(reference_image, source_img_w)
            context_h = reference_context.size[1]

            # Concat the context image on the top.
            source_with_context = Image.new(source_image.mode, (source_img_w, context_h + source_img_h))
            source_with_context.paste(reference_context, (0, 0))
            source_with_context.paste(source_image, (0, context_h))

            # Concat the 0 mask on the top of the mask image.
            # The canvas height must use the mask's height (size[1]); using
            # its width breaks non-square targets.
            mask_with_context = Image.new(
                mask_image.mode,
                (mask_image.size[0], context_h + mask_image.size[1]),
                color=0,
            )
            mask_with_context.paste(mask_image, (0, context_h))

            # Identifiers in filename.
            ref_tag = reference_image_name.split('_')[0]

            for _ in range(gen_num_per_case):
                seed = random.randint(0, 2 ** 32 - 1)
                images = model.generate(
                    image=source_with_context,
                    mask_image=mask_with_context,
                    ref_image=reference_image_to_encoder,
                    prompt=prompt,
                    scale=1.0,
                    num_inference_steps=50,
                    width=source_with_context.size[0],
                    height=source_with_context.size[1],
                    seed=seed,
                )
                # The index counts every file already in the job dir (results
                # and visualizations alike), which keeps names unique.
                index = len(os.listdir(result_save_path))
                output_filename = f"result_{index}_{ref_tag}_{safe_prompt}_{seed}.png"

                # Drop the context strip from the top of the generated image.
                result_img = images[0]
                result_img_vis = result_img.crop((0, context_h, result_img.width, result_img.height))
                result_img_vis.save(os.path.join(result_save_path, output_filename))

                # Side-by-side strip: source | reference context | result.
                target_size = source_image.size
                vis_img = Image.new('RGB', (target_size[0] * 3, target_size[1]))
                vis_img.paste(source_image, (0, 0))
                vis_img.paste(reference_context.resize(target_size), (target_size[0], 0))
                vis_img.paste(result_img_vis.resize(target_size), (target_size[0] * 2, 0))
                vis_img_save_path = os.path.join(result_save_path, f'vis_{output_filename}'.replace('.png', '.jpg'))
                vis_img.save(vis_img_save_path)
                print(f"Generated images saved to {vis_img_save_path}.")
if __name__ == '__main__':
    # Read the path configuration stored next to this script.
    # `path_dict` must keep this exact module-level name: infer_calligrapher
    # looks it up as a global to find the model paths.
    config_path = os.path.join(os.path.dirname(__file__), 'path_dict.json')
    with open(config_path, 'r') as config_file:
        path_dict = json.load(config_file)

    # Set directory paths.
    test_image_dir = path_dict['data_dir']
    result_save_dir = path_dict['cli_save_dir']

    # Run the benchmark at 512x512 with two samples per case.
    infer_calligrapher(
        test_image_dir,
        result_save_dir,
        target_h=512,
        target_w=512,
        gen_num_per_case=2,
    )
    print('Finished!')
|