import einops import gradio as gr import numpy as np import torch import random import os import subprocess import shlex from huggingface_hub import hf_hub_url, hf_hub_download from share import * from pytorch_lightning import seed_everything from annotator.util import resize_image, HWC3 from annotator.OneFormer import OneformerSegmenter from cldm.model import create_model, load_state_dict from cldm.ddim_hacked import DDIMSamplerSpaCFG from ldm.models.autoencoder import DiagonalGaussianDistribution urls = { 'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'], 'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['pair_diffusion_epoch62.ckpt'] } WTS_DICT = { } if os.path.exists('checkpoints') == False: os.mkdir('checkpoints') for repo in urls: files = urls[repo] for file in files: url = hf_hub_url(repo, file) name_ckp = url.split('/')[-1] save_path = os.path.join('checkpoints', name_ckp) if os.path.exists(save_path) == False: WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file, token=os.environ.get("ACCESS_TOKEN")) print(WTS_DICT) apply_segmentor = OneformerSegmenter(WTS_DICT['shi-labs/oneformer_coco_swin_large']) model = create_model('./configs/sap_fixed_hintnet_v15.yaml').cpu() model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda')) model = model.cuda() ddim_sampler = DDIMSamplerSpaCFG(model) _COLORS = [] save_memory = False def gen_color(): color = tuple(np.round(np.random.choice(range(256), size=3), 3)) if color not in _COLORS and np.mean(color) != 0.0: _COLORS.append(color) else: gen_color() for _ in range(300): gen_color() class ImageComp: def __init__(self, edit_operation): self.input_img = None self.input_pmask = None self.input_segmask = None self.ref_img = None self.ref_pmask = None self.ref_segmask = None self.H = None self.W = None self.baseoutput = None self.kernel = np.ones((5, 5), np.uint8) self.edit_operation = edit_operation def init_input_canvas(self, img): img = HWC3(img) img = resize_image(img, 512) detected_mask = apply_segmentor(img, 'panoptic')[0] detected_seg = apply_segmentor(img, 'semantic') self.input_img = img self.input_pmask = detected_mask self.input_segmask = detected_seg self.H = img.shape[0] self.W = img.shape[1] detected_mask = detected_mask.cpu().numpy() uni = np.unique(detected_mask) color_mask = np.zeros((detected_mask.shape[0], detected_mask.shape[1], 3)) for i in uni: color_mask[detected_mask == i] = _COLORS[i] output = color_mask*0.8 + img * 0.2 self.baseoutput = output.astype(np.uint8) return self.baseoutput def init_ref_canvas(self, img): img = HWC3(img) img = resize_image(img, 512) detected_mask = apply_segmentor(img, 'panoptic')[0] detected_seg = apply_segmentor(img, 'semantic') self.ref_img = img self.ref_pmask = detected_mask self.ref_segmask = detected_seg detected_mask = detected_mask.cpu().numpy() uni = np.unique(detected_mask) color_mask = np.zeros((detected_mask.shape[0], detected_mask.shape[1], 3)) for i in uni: color_mask[detected_mask == i] = _COLORS[i] output = color_mask*0.8 + img * 0.2 self.baseoutput = output.astype(np.uint8) return self.baseoutput def process_mask(self, mask, panoptic_mask, segmask): panoptic_mask_ = panoptic_mask + 1 mask_ = resize_image(mask['mask'][:, :, 0], min(panoptic_mask.shape)) mask_ = torch.tensor(mask_) maski = torch.zeros_like(mask_).cuda() maski[mask_ > 127] = 1 mask = maski * panoptic_mask_ unique_ids, counts = torch.unique(mask, return_counts=True) # print(unique_ids, counts) mask_id = unique_ids[torch.argmax(counts[1:]) + 1] final_mask = torch.zeros(mask.shape).cuda() final_mask[panoptic_mask_ == mask_id] = 1 obj_class = maski * (segmask + 1) unique_ids, counts = torch.unique(obj_class, return_counts=True) obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1 return final_mask, obj_class def edit_app(self, input_mask, ref_mask, whole_ref): input_pmask = self.input_pmask input_segmask = self.input_segmask if whole_ref: reference_mask = torch.ones(self.ref_pmask.shape).cuda() else: reference_mask, _ = self.process_mask(ref_mask, self.ref_pmask, self.ref_segmask) edit_mask, _ = self.process_mask(input_mask, self.input_pmask, self.input_segmask) ma = torch.max(input_pmask) input_pmask[edit_mask == 1] = ma + 1 return reference_mask, input_pmask, input_segmask, edit_mask, ma def edit(self, input_mask, ref_mask, whole_ref=False, inter=1): input_img = (self.input_img/127.5 - 1) input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2) reference_img = (self.ref_img/127.5 - 1) reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2) reference_mask, input_pmask, input_segmask, region_mask, ma = self.edit_app(input_mask, ref_mask, whole_ref) input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1) _, mean_feat_inpt, one_hot_inpt, empty_mask_flag_inpt = model.get_appearance(input_img, input_pmask, return_all=True) reference_mask = reference_mask.float().cuda().unsqueeze(0).unsqueeze(1) _, mean_feat_ref, _, _ = model.get_appearance(reference_img, reference_mask, return_all=True) if mean_feat_ref.shape[1] > 1: mean_feat_inpt[:, ma + 1] = (1 - inter) * mean_feat_inpt[:, ma + 1] + inter*mean_feat_ref[:, 1] splatted_feat = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt, one_hot_inpt) appearance = torch.nn.functional.normalize(splatted_feat) #l2 normaliz input_segmask = ((input_segmask+1)/ 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1) structure = torch.nn.functional.interpolate(input_segmask, (self.H, self.W)) appearance = torch.nn.functional.interpolate(appearance, (self.H, self.W)) return structure, appearance, region_mask, input_img def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength, scale_s, scale_f, scale_t, seed, eta, masking=True,whole_ref=False,inter=1): structure, appearance, mask, img = self.edit(input_mask, ref_mask, whole_ref=whole_ref, inter=inter) null_structure = torch.zeros(structure.shape).cuda() - 1 null_appearance = torch.zeros(appearance.shape).cuda() null_control = torch.cat([null_structure, null_appearance], dim=1) structure_control = torch.cat([structure, null_appearance], dim=1) full_control = torch.cat([structure, appearance], dim=1) null_control = torch.cat([null_control for _ in range(num_samples)], dim=0) structure_control = torch.cat([structure_control for _ in range(num_samples)], dim=0) full_control = torch.cat([full_control for _ in range(num_samples)], dim=0) #Masking for local edit if not masking: mask, x0 = None, None else: x0 = model.encode_first_stage(img) x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0 # todo: check if we can set random number x0 = x0 * model.scale_factor mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda() mask = torch.nn.functional.interpolate(mask, x0.shape[2:]).float() if seed == -1: seed = random.randint(0, 65535) seed_everything(seed) scale = [scale_s, scale_f, scale_t] print(scale) if save_memory: model.low_vram_shift(is_diffusing=False) # uc_cross = model.get_unconditional_conditioning(num_samples) uc_cross = model.get_learned_conditioning([n_prompt] * num_samples) cond = {"c_concat": [full_control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]} un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]} un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]} un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]} shape = (4, self.H // 8, self.W // 8) if save_memory: model.low_vram_shift(is_diffusing=True) model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01 samples, _ = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta, unconditional_guidance_scale=scale, mask=mask, x0=x0, unconditional_conditioning=[un_cond, un_cond_struct, un_cond_struct_app ]) if save_memory: model.low_vram_shift(is_diffusing=False) x_samples = (model.decode_first_stage(samples) + 1) * 127.5 x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')).cpu().numpy().clip(0, 255).astype(np.uint8) results = [x_samples[i] for i in range(num_samples)] return [] + results css = """ h1 { text-align: center; } """ def create_app_demo(): with gr.Row(): gr.Markdown("## Object Level Appearance Editing") print('first row') with gr.Row(): gr.HTML( """