|
import torch |
|
import os |
|
from PIL import Image |
|
import numpy as np |
|
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler |
|
from diffusion_module.utils.Pipline import SDMLDMPipeline |
|
|
|
def log_validation(vae, unet, noise_scheduler, accelerator, weight_dtype, data_ld,
                   resolution=512, g_step=2, save_dir="cityspace_test"):
    """Run a short validation pass: sample images from the diffusion pipeline
    for the first few batches of ``data_ld`` and save merged preview strips.

    Args:
        vae: VAE model (wrapped by the accelerator; unwrapped here).
        unet: UNet model (wrapped by the accelerator; unwrapped here).
        noise_scheduler: training scheduler whose config seeds the sampler.
        accelerator: HF Accelerate object; provides ``device``,
            ``unwrap_model`` and ``logging_dir`` (used by merge_images).
        weight_dtype: torch dtype the pipeline weights/inputs are cast to.
        data_ld: iterable of batches shaped like ``(img, {'label': ...})``.
        resolution: output resolution passed to the pipeline.
        g_step: global training step; names the output folder.
        save_dir: currently unused; kept for backward compatibility —
            output actually goes under ``accelerator.logging_dir``.
    """
    # Swap the training scheduler for a fast UniPC sampler at validation time.
    scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=scheduler,
        torch_dtype=weight_dtype,
        resolution=resolution,
        resolution_type="crack",
    )

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=False)
    pipeline.enable_xformers_memory_efficient_attention()

    generator = None
    for i, batch in enumerate(data_ld):
        # Only sample the first 3 batches to keep validation cheap.
        if i > 2:
            break
        images = []
        with torch.autocast("cuda"):
            # One-hot encode the integer label map (151 semantic classes).
            segmap = preprocess_input(batch[1]['label'], num_classes=151)
            # Fix: place inputs on the accelerator's device with the same
            # dtype as the pipeline weights, instead of hard-coded cuda/fp16.
            segmap = segmap.to(accelerator.device).to(weight_dtype)

            image = pipeline(segmap=segmap[0][None, :], generator=generator, batch_size=1,
                             num_inference_steps=50, s=1.5).images

        images.extend(image)
        merge_images(images, i, accelerator, g_step)

    # Release the pipeline and cached GPU memory before training resumes.
    del pipeline
    torch.cuda.empty_cache()
|
|
|
|
|
def merge_images(images, val_step, accelerator, step):
    """Save each image individually, then save one horizontal strip of all.

    Output layout under ``accelerator.logging_dir``:
        step_<step>/singles/<val_step>_<k>.png   -- each image separately
        step_<step>/merges/<val_step>_merge.png  -- all images side by side

    Args:
        images: list of PIL images to save (assumed non-empty).
        val_step: index of the validation batch, used in file names.
        accelerator: provides ``logging_dir`` as the output root.
        step: global training step, used as the folder name.
    """
    # Save every image on its own under .../singles/.
    # (Removed dead commented-out code that special-cased k == 0.)
    for k, image in enumerate(images):
        filename = "{}_{}.png".format(val_step, k)
        path = os.path.join(accelerator.logging_dir, "step_{}".format(step), "singles", filename)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        image.save(path)

    # Build a single canvas wide enough for all images laid side by side.
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    combined_image = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for img in images:
        # paste() needs a consistent mode; convert grayscale/palette images.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        combined_image.paste(img, (x_offset, 0))
        x_offset += img.width

    merge_filename = "{}_merge.png".format(val_step)
    merge_path = os.path.join(accelerator.logging_dir, "step_{}".format(step), "merges", merge_filename)
    os.makedirs(os.path.dirname(merge_path), exist_ok=True)
    combined_image.save(merge_path)
|
|
|
def preprocess_input(data, num_classes):
    """One-hot encode an integer segmentation label map.

    Args:
        data: tensor of shape ``(bs, 1, h, w)`` holding class ids in
            ``[0, num_classes)``; any dtype convertible to int64.
        num_classes: number of semantic classes (one-hot channel count).

    Returns:
        Float tensor of shape ``(bs, num_classes, h, w)`` on the same device
        as ``data``, with 1.0 in channel ``c`` at pixels labelled ``c`` and
        0.0 elsewhere.
    """
    # scatter_ requires int64 indices.
    label_map = data.to(dtype=torch.int64)
    bs, _, h, w = label_map.size()
    # Fix: allocate directly on the source device. The original built a CPU
    # FloatTensor and then copied it, which is wasteful for GPU inputs.
    input_label = torch.zeros(bs, num_classes, h, w,
                              dtype=torch.float32, device=data.device)
    # Write 1.0 into the channel selected by each pixel's class id.
    return input_label.scatter_(1, label_map, 1.0)
|
|
|
|