# Hugging Face Spaces page header captured with this source (Space status: Sleeping); not part of the program.
import math
import os

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import spaces

from utils.sdxl import sdxl
from utils.inversion import Inversion
import utils.utils as utils
# Maximum prompt length in tokens; presumably the CLIP tokenizer's 77-token limit.
# NOTE(review): not referenced anywhere in the code shown here — confirm it is still needed.
MAX_NUM_WORDS: int = 77
class LayerFusion:
    """Latent blending applied once per denoising step (via ``Control.step_callback``).

    Batch layout used by ``__call__``: index 0 is the reference latent, index 1
    the background layer being inpainted, index 2 the output canvas, and
    indices 3.. the foreground layers that are moved onto the canvas.
    """

    def __init__(self, remove_mask, fg_mask_list, refine_mask=None,
                 blend_time=(0, 40),
                 mode="removal", op_list=None):
        """
        Args:
            remove_mask: mask tensor blended into latent 1 by ``__call__``
                (value 1 keeps latent 1, value 0 takes reference latent 0).
            fg_mask_list: per-foreground-layer masks, indexed by ``fg_id - 3``.
            refine_mask: optional extra mask; when given, ``new_mask`` is the
                union ``remove_mask + refine_mask`` with positive entries set to 1.
            blend_time: inclusive ``[start, end]`` step window for the
                background blend; layers are composited at step ``end + 1``.
                (A tuple default avoids sharing one mutable list across instances.)
            mode: ``"removal"`` skips the layer-compositing step.
            op_list: per-foreground-layer list of ``[direction, scale]`` moves,
                indexed by ``fg_id - 3``.
        """
        self.counter = 0  # number of __call__ invocations (denoising steps) so far
        self.mode = mode
        self.op_list = op_list
        self.blend_time = blend_time
        self.remove_mask = remove_mask
        self.refine_mask = refine_mask
        if self.refine_mask is not None:
            self.new_mask = self.remove_mask + self.refine_mask
            self.new_mask[self.new_mask > 0] = 1
        else:
            self.new_mask = None
        self.fg_mask_list = fg_mask_list

    def get_mask(self, maps, alpha, use_pool, x_t):
        """Legacy: build a thresholded mask from weighted attention maps.

        NOTE(review): reads ``self.mask_threshold``, which ``__init__`` never
        sets; callers must assign it first.
        """
        k = 1
        maps = (maps * alpha).sum(-1).mean(1)
        if use_pool:
            maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        mask = F.interpolate(maps, size=(x_t.shape[2:]))  # [2, 1, 128, 128]
        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
        mask = (mask - mask.min()) / (mask.max() - mask.min())
        mask = mask.gt(self.mask_threshold)
        self.mask = mask
        mask = mask[:1] + mask
        return mask

    def get_one_mask(self, maps, use_pool, x_t, idx_lst, i=None, sav_img=False):
        """Legacy: union of per-token masks from 32x32 attention maps.

        With ``sav_img is False`` the OR of all thresholded masks (resized to
        ``x_t``'s spatial size) is returned. Otherwise each mask is resized to
        1024x1024, thresholded at 0.6, written as a JPEG under
        ``./img/sam_mask/``, and the last one is returned as a uint8 array.

        NOTE(review): reads ``self.mask_threshold`` and ``self.blend_list``,
        which ``__init__`` never sets.
        """
        k = 1
        if sav_img is False:
            mask_tot = 0
            for obj in idx_lst:
                mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32)
                if use_pool:
                    mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
                mask = F.interpolate(mask, size=(x_t.shape[2:]))
                mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
                mask = (mask - mask.min()) / (mask.max() - mask.min())
                mask = mask.gt(self.mask_threshold[int(self.counter / 10)])
                mask_tot |= mask
            mask = mask_tot
            return mask
        else:
            for obj in idx_lst:
                mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32)
                if use_pool:
                    mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
                mask = F.interpolate(mask, size=(1024, 1024))  # [1, 1, 1024, 1024]
                mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
                mask = (mask - mask.min()) / (mask.max() - mask.min())
                mask = mask.gt(0.6)
                mask = np.array(mask[0][0].clone().cpu()).astype(np.uint8) * 255
                cv2.imwrite(f'./img/sam_mask/{self.blend_list[i][0]}_{self.counter}.jpg', mask)
            return mask

    def mv_op(self, mp, op, scale=0.2, ones=False, flip=None):
        """Return a copy of the 4-D tensor ``mp`` shifted along one spatial axis.

        Args:
            mp: tensor of shape (N, C, H, W).
            op: 'right'/'left' shift by ``int(scale * W)`` columns,
                'down'/'up' shift by ``int(scale * H)`` rows. Any other value
                yields a tensor filled entirely with the fill value.
            scale: shift as a fraction of the axis length; values >= 1 shift
                everything out of frame.
            ones: fill vacated pixels with ones instead of zeros.
            flip: optional ``dims`` passed to ``torch.flip`` after shifting.
        """
        _, _, H, W = mp.shape
        new_mp = torch.ones_like(mp) if ones else torch.zeros_like(mp)
        if op in ('right', 'left'):
            # Clamp so an oversized scale empties the tensor instead of
            # producing mismatched slice shapes.
            K = min(int(scale * W), W)
            if op == 'right':
                new_mp[:, :, :, K:] = mp[:, :, :, :W - K]
            else:
                new_mp[:, :, :, :W - K] = mp[:, :, :, K:]
        elif op in ('down', 'up'):
            # Vertical moves are measured against the height, not the width,
            # so non-square latents shift correctly.
            K = min(int(scale * H), H)
            if op == 'down':
                new_mp[:, :, K:, :] = mp[:, :, :H - K, :]
            else:
                new_mp[:, :, :H - K, :] = mp[:, :, K:, :]
        if flip is not None:
            new_mp = torch.flip(new_mp, dims=flip)
        return new_mp

    def mv_layer(self, x_t, bg_id, fg_id, op_id):
        """Move foreground layer ``fg_id`` and paste it over ``bg_id`` into ``x_t[op_id]``.

        The layer's latent and its mask (``fg_mask_list[fg_id - 3]``) are
        shifted by each non-zero ``[direction, scale]`` in ``op_list[fg_id - 3]``,
        then alpha-composited over the background slot in place.
        """
        bg_img = x_t[bg_id:(bg_id + 1)].clone()
        fg_img = x_t[fg_id:(fg_id + 1)].clone()
        fg_mask = self.fg_mask_list[fg_id - 3]
        for op, scale in ((item[0], item[1]) for item in self.op_list[fg_id - 3]):
            if scale != 0:
                fg_img = self.mv_op(fg_img, op=op, scale=scale)
                fg_mask = self.mv_op(fg_mask, op=op, scale=scale)
        x_t[op_id:(op_id + 1)] = bg_img * (1 - fg_mask) + fg_img * fg_mask

    def __call__(self, x_t):
        """Blend and composite the latent batch ``x_t`` in place, then return it."""
        self.counter += 1
        start, end = self.blend_time[0], self.blend_time[1]
        # Inpainting: inside the blend window keep latent 1 where remove_mask
        # is 1 and take the reference latent 0 elsewhere.
        if start <= self.counter <= end:
            x_t[1:2] = x_t[1:2] * self.remove_mask + x_t[0:1] * (1 - self.remove_mask)
        # Right after the window closes, composite the foreground layers onto
        # the canvas (index 2) one after another.
        if self.counter == end + 1 and self.mode != "removal":
            bg_id, op_id = 1, 2  # background layer, canvas
            for fg_id in range(3, x_t.shape[0]):  # foreground layers
                self.mv_layer(x_t, bg_id=bg_id, fg_id=fg_id, op_id=op_id)
                bg_id = op_id  # later layers stack on top of the canvas
        return x_t
class Control:
    """Thin controller handed to the pipeline; forwards step latents to LayerFusion."""

    def __init__(self, layer_fusion):
        # Callable applied to the latents after each denoising step, or None to disable.
        self.layer_fusion = layer_fusion

    def step_callback(self, x_t):
        """Return ``x_t`` passed through ``layer_fusion``, or unchanged when it is None."""
        fusion = self.layer_fusion
        return x_t if fusion is None else fusion(x_t)
def register_attention_control(model, controller, mask_time=(0, 40), refine_time=(0, 25)):
    """Replace ``forward`` on every ``Attention`` module under ``model.unet``.

    The hook recomputes plain softmax attention. For self-attention calls with
    ``mask_time[0] < call_count < mask_time[1]``, it also zeroes the *keys* of
    batch row ``b // 2 + 1`` inside the removal mask. That row is presumably the
    background latent of the conditional half of a CFG batch (TODO confirm
    against the pipeline). Within ``refine_time`` (exclusive start, inclusive
    end), when a refine mask exists, ``layer_fusion.new_mask`` is used instead.

    Args:
        model: pipeline exposing ``unet``; its children whose names contain
            "down", "up" or "mid" are searched recursively.
        controller: object with ``layer_fusion`` (may be None) and optionally
            ``count_layers(place_in_unet, is_cross)``. ``num_att_layers`` is set
            to the number of hooked modules.
        mask_time: per-module call-count window for key masking.
        refine_time: per-module call-count window for using the refine mask.
    """

    def _keep_mask(mask, H, W, device):
        """Resize a (1, 1, h, w) mask to (H, W) and return ``1 - mask`` as (1, H, W, 1)."""
        mask = F.interpolate(mask.to(dtype=torch.float32), size=(H, W), mode='bilinear').to(device)
        return (1 - mask).reshape(1, H, W).unsqueeze(-1)

    def ca_forward(self, place_in_unet):
        to_out = self.to_out[0] if isinstance(self.to_out, torch.nn.ModuleList) else self.to_out
        self.counter = 0  # number of forward calls made by this module so far

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
            # attention_mask and any extra diffusers kwargs are accepted but ignored.
            x = hidden_states.clone()
            context = encoder_hidden_states
            is_cross = context is not None
            layer_fusion = controller.layer_fusion
            if not is_cross and layer_fusion is not None and (mask_time[0] < self.counter < mask_time[1]):
                b, i, j = x.shape
                H = W = int(math.sqrt(i))  # assumes a square token grid
                keep = None
                if layer_fusion.remove_mask is not None:
                    # Binarise a copy: mutating remove_mask would also change the
                    # mask LayerFusion uses for latent blending mid-run.
                    remove = layer_fusion.remove_mask.clone()
                    remove[remove > 0] = 1
                    keep = _keep_mask(remove, H, W, x.device)
                if (refine_time[0] < self.counter <= refine_time[1]) and layer_fusion.refine_mask is not None:
                    keep = _keep_mask(layer_fusion.new_mask, H, W, x.device)
                if keep is not None:
                    x = x.reshape(b, H, W, j)
                    row = b // 2 + 1  # inpaint index (background) in the second half of the batch
                    x[row] = x[row] * keep[0]
                    x = x.reshape(b, i, j)
            if is_cross:
                q = self.to_q(x)
                k = self.to_k(context)
                v = self.to_v(context)
            else:
                # Only the keys see the masked hidden states.
                q = self.to_q(hidden_states)
                k = self.to_k(x)
                v = self.to_v(hidden_states)
            q = self.head_to_batch_dim(q)
            k = self.head_to_batch_dim(k)
            v = self.head_to_batch_dim(v)
            if hasattr(controller, 'count_layers'):
                controller.count_layers(place_in_unet, is_cross)
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            attn = sim.softmax(dim=-1)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.batch_to_head_dim(out)
            self.counter += 1
            return to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Hook every ``Attention`` module below ``net_``; return the running count."""
        if net_.__class__.__name__ == 'Attention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, sub_net in model.unet.named_children():
        if "down" in name:
            cross_att_count += register_recr(sub_net, 0, "down")
        elif "up" in name:
            cross_att_count += register_recr(sub_net, 0, "up")
        elif "mid" in name:
            cross_att_count += register_recr(sub_net, 0, "mid")
    controller.num_att_layers = cross_att_count
class DesignEdit:
    """SDXL-based design editing: removal, zooming, panning, layer-wise editing, moving.

    Each ``run_*`` entry point follows the same steps:

    1. prepare a batch of 1024x1024 images plus masks;
    2. DDIM-invert the batch;
    3. hook a ``LayerFusion`` into the UNet attention and the per-step latent
       callback;
    4. regenerate the batch and resize the chosen output back to the input
       resolution.
    """

    def __init__(self, pretrained_model_path="/home/jyr/model/stable-diffusion-xl-base-1.0"):
        self.model_dtype = "fp16"
        self.pretrained_model_path = pretrained_model_path
        self.num_ddim_steps = 50
        self.mask_time = [0, 40]
        self.op_list = {}
        self.attend_scale = {}
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                  clip_sample=False, set_alpha_to_one=False)
        if self.model_dtype == "fp16":
            torch_dtype = torch.float16
        elif self.model_dtype == "fp32":
            torch_dtype = torch.float32
        else:
            # Previously this fell through and failed later with an unbound torch_dtype.
            raise ValueError(f"unsupported model_dtype: {self.model_dtype!r}")
        self.pipe = sdxl.from_pretrained(self.pretrained_model_path, torch_dtype=torch_dtype,
                                         use_safetensors=True, variant=self.model_dtype,
                                         scheduler=scheduler)

    def init_model(self, num_ddim_steps=50):
        """Move the pipeline to ``cuda:0`` and build a fresh ``Inversion`` for it.

        Returns:
            ``(pipe, inversion)``.
        """
        device = torch.device('cuda:0')
        self.pipe.to(device)
        inversion = Inversion(self.pipe, num_ddim_steps)
        return self.pipe, inversion

    @staticmethod
    def _prepare_batch(images):
        """Resize each image array to 1024x1024 via PIL and stack them into one array."""
        return np.stack([np.array(Image.fromarray(img).resize((1024, 1024))) for img in images])

    @staticmethod
    def _prepare_mask(mask):
        """Return a 1024x1024 uint8 zero mask for None, else ``utils.convert_and_resize_mask(mask)``."""
        if mask is None:
            return np.zeros((1024, 1024), dtype=np.uint8)
        return utils.convert_and_resize_mask(mask)

    @staticmethod
    def _restore_size(image, ori_shape):
        """Resize a generated image back to ``ori_shape`` (rows, cols, ...)."""
        return cv2.resize(image, (ori_shape[1], ori_shape[0]))

    def _edit(self, image_gt, prompt, sample_ref_match, remove_mask, mode,
              fg_mask_list=None, refine_mask=None, op_list=None, inv_batch_size=1):
        """Run inversion, hook LayerFusion and regenerate; shared by every ``run_*``.

        Args:
            image_gt: stacked 1024x1024 input batch.
            prompt: text prompt, repeated once per ``sample_ref_match`` entry.
            sample_ref_match: output-index -> reference-index map passed to the pipeline.
            remove_mask, fg_mask_list, refine_mask, mode, op_list: forwarded to ``LayerFusion``.
            inv_batch_size: batch size used by ``Inversion.invert``.

        Returns:
            The images produced by the pipeline call.
        """
        prompts = len(sample_ref_match) * [prompt]
        blend_time = [0, 41]
        refine_time = [0, 25]
        # 02: invert
        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(
            image_gt, prompts, inv_batch_size=inv_batch_size)
        # 03: init layer_fusion and controller
        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, refine_mask=refine_mask,
                         blend_time=blend_time, mode=mode, op_list=op_list)
        controller = Control(layer_fusion=lb)
        register_attention_control(model=self.ldm_model, controller=controller,
                                   mask_time=self.mask_time, refine_time=refine_time)
        # 04: generate images
        images = self.ldm_model(controller=controller, prompt=prompts,
                                latents=x_t, x_stars=x_stars,
                                negative_prompt_embeds=prompt_embeds,
                                negative_pooled_prompt_embeds=pooled_prompt_embeds,
                                sample_ref_match=sample_ref_match)
        utils.view_images(images, folder=None)
        return images

    def run_remove(self, original_image=None, mask_1=None, mask_2=None, mask_3=None, refine_mask=None,
                   ori_1=None, ori_2=None, ori_3=None,
                   prompt="", save_dir="./tmp", mode='removal',):
        """Remove the masked regions and return ``[edited_image]`` at the input size.

        ``original_image`` falls back to the first non-None of ``ori_1..ori_3``.
        ``save_dir`` is accepted for interface compatibility and unused.

        Raises:
            ValueError: if no image is supplied at all.
        """
        if original_image is None:
            original_image = next((img for img in (ori_1, ori_2, ori_3) if img is not None), None)
        if original_image is None:
            raise ValueError("run_remove needs original_image or one of ori_1/ori_2/ori_3")
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        attend_scale = 20
        sample_ref_match = {0: 0, 1: 0}
        ori_shape = original_image.shape
        image_gt = self._prepare_batch([original_image])
        remove_mask = utils.attend_mask(utils.add_masks_resized([mask_1, mask_2, mask_3]),
                                        attend_scale=attend_scale)  # numpy to tensor
        if refine_mask is not None:
            refine_mask = utils.attend_mask(utils.convert_and_resize_mask(refine_mask))
        images = self._edit(image_gt, prompt, sample_ref_match, remove_mask, mode,
                            refine_mask=refine_mask, op_list=None, inv_batch_size=1)
        return [self._restore_size(images[1], ori_shape)]

    def run_zooming(self, original_image, width_scale=1, height_scale=1, prompt="", save_dir="./tmp", mode='removal'):
        """Zoom out via ``utils.zooming`` and fill the new border.

        Returns:
            ``([edited_image], [zoomed_input], [fill_mask])``.
        """
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        op_list = {0: ['zooming', [height_scale, width_scale]]}
        ori_shape = original_image.shape
        attend_scale = 30
        sample_ref_match = {0: 0, 1: 0}
        img_new, mask = utils.zooming(original_image, [height_scale, width_scale])
        img_new_copy = img_new.copy()
        mask_copy = mask.copy()
        image_gt = self._prepare_batch([img_new])
        remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale)
        images = self._edit(image_gt, prompt, sample_ref_match, remove_mask, mode, op_list=op_list)
        return [self._restore_size(images[1], ori_shape)], [img_new_copy], [mask_copy]

    def run_panning(self, original_image, w_direction, w_scale, h_direction, h_scale, prompt="", save_dir="./tmp", mode='removal'):
        """Pan the canvas via ``utils.panning`` and fill the uncovered area.

        Returns:
            ``([edited_image], [panned_input], [fill_mask])``.
        """
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        ori_shape = original_image.shape
        attend_scale = 30
        sample_ref_match = {0: 0, 1: 0}
        op_list = [[w_direction, w_scale], [h_direction, h_scale]]
        img_new, mask = utils.panning(original_image, op_list=op_list)
        img_new_copy = img_new.copy()
        mask_copy = mask.copy()
        image_gt = self._prepare_batch([img_new])
        remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale)
        images = self._edit(image_gt, prompt, sample_ref_match, remove_mask, mode, op_list=op_list)
        return [self._restore_size(images[1], ori_shape)], [img_new_copy], [mask_copy]

    # layer-wise multi-object editing
    def process_layer_states(self, layer_states):
        """Turn per-layer UI states into images, masks, move ops and a ref map.

        Each state is ``[img, mask, dx, dy, resize, w_flip, h_flip]``; states
        with ``img is None`` are skipped. Positive ``dx`` moves right and
        positive ``dy`` moves up.

        Returns:
            ``(images, attended_masks, op_list, sample_ref_match)``. The ref map
            is truncated to ``len(images) + 3`` entries.
        """
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        image_paths = []
        mask_paths = []
        op_list = []
        for img, mask, dx, dy, resize, w_flip, h_flip in layer_states:
            if img is None:
                continue
            img = cv2.resize(img, (1024, 1024))
            mask = utils.convert_and_resize_mask(mask)
            dx_command = ['right', dx] if dx > 0 else ['left', -dx]
            dy_command = ['up', dy] if dy > 0 else ['down', -dy]
            # Flip codes appear to follow cv2.flip: -1 both axes, 1 horizontal, 0 vertical.
            flip_code = None
            if w_flip == "left/right" and h_flip == "down/up":
                flip_code = -1
            elif w_flip == "left/right":
                flip_code = 1  # or another default value, set as needed
            elif h_flip == "down/up":
                flip_code = 0
            op_list.append([dx_command, dy_command])
            img, mask, _ = utils.resize_image_with_mask(img, mask, resize)
            img, mask, _ = utils.flip_image_with_mask(img, mask, flip_code=flip_code)
            image_paths.append(img)
            mask_paths.append(utils.attend_mask(mask))
        sample_ref_match = {0: 0, 1: 0, 2: 0, 3: 1, 4: 2, 5: 3}
        required_length = len(image_paths) + 3
        truncated_sample_ref_match = {k: sample_ref_match[k]
                                      for k in sorted(sample_ref_match)[:required_length]}
        return image_paths, mask_paths, op_list, truncated_sample_ref_match

    def run_layer(self, bg_img, l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
                  l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
                  l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
                  bg_mask, l1_mask, l2_mask, l3_mask,
                  bg_ori=None, l1_ori=None, l2_ori=None, l3_ori=None,
                  prompt="", save_dir="./tmp", mode='layerwise'):
        """Inpaint the background and composite up to three moved layers on it.

        Falls back to plain removal when no layer image is given.

        Returns:
            ``[edited_image]`` at the background's size.
        """
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        # 00: prepare: layer-wise states
        bg_img = bg_ori if bg_ori is not None else bg_img
        l1_img = l1_ori if l1_ori is not None else l1_img
        l2_img = l2_ori if l2_ori is not None else l2_img
        l3_img = l3_ori if l3_ori is not None else l3_img
        # Normalise the removal mask the same way run_moving does; the layer
        # masks are normalised inside process_layer_states.
        bg_mask = self._prepare_mask(bg_mask)
        l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip]
        l2_state = [l2_img, l2_mask, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip]
        l3_state = [l3_img, l3_mask, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip]
        ori_shape = bg_img.shape
        image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states(
            [l1_state, l2_state, l3_state])
        if not image_paths:
            mode = "removal"
        # 01: prepare: image_gt, remove_mask
        attend_scale = 20
        image_gt = self._prepare_batch([bg_img] + image_paths)
        remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale)
        images = self._edit(image_gt, prompt, sample_ref_match, remove_mask, mode,
                            fg_mask_list=fg_mask_list, refine_mask=None, op_list=op_list,
                            inv_batch_size=len(image_gt))
        # Index 1 is the inpainted background, index 2 the composited canvas.
        result = images[1] if mode == 'removal' else images[2]
        return [self._restore_size(result, ori_shape)]

    def run_moving(self, bg_img, bg_ori, bg_mask, l1_dx, l1_dy, l1_resize,
                   l1_w_flip=None, l1_h_flip=None, selected_points=None,
                   prompt="", save_dir="./tmp", mode='layerwise'):
        """Move the masked object of ``bg_img``, inpainting where it used to be.

        ``selected_points`` and ``save_dir`` are accepted for interface
        compatibility and unused.

        Returns:
            ``[edited_image]`` at the input size.
        """
        self.ldm_model, self.inversion = self.init_model(num_ddim_steps=self.num_ddim_steps)
        # 00: prepare: layer-wise states (the object layer is the image itself)
        bg_img = bg_ori if bg_ori is not None else bg_img
        l1_img = bg_img
        bg_mask = self._prepare_mask(bg_mask)
        l1_mask = bg_mask
        l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip]
        ori_shape = bg_img.shape
        image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state])
        # 01: prepare: image_gt, remove_mask
        attend_scale = 20
        image_gt = self._prepare_batch([bg_img] + image_paths)
        remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale)
        images = self._edit(image_gt, prompt, sample_ref_match, remove_mask, mode,
                            fg_mask_list=fg_mask_list, refine_mask=None, op_list=op_list,
                            inv_batch_size=len(image_gt))
        return [self._restore_size(images[2], ori_shape)]

    # turn mask to 1024x1024 uint8
    def run_mask(self, mask_1, mask_2, mask_3, mask_4):
        """Combine up to four masks with ``utils.add_masks_resized`` and return the result."""
        mask_list = [mask_1, mask_2, mask_3, mask_4]
        final_mask = utils.add_masks_resized(mask_list)
        return final_mask