from email.policy import default import gradio as gr import numpy as np import torch import requests import random import os import sys import pickle from PIL import Image from tqdm.auto import tqdm from datetime import datetime from utils.gradio_utils import is_torch2_available if is_torch2_available(): from utils.gradio_utils import \ AttnProcessor2_0 as AttnProcessor # from utils.gradio_utils import SpatialAttnProcessor2_0 else: from utils.gradio_utils import AttnProcessor import diffusers from diffusers import StableDiffusionXLPipeline from utils import PhotoMakerStableDiffusionXLPipeline from diffusers import DDIMScheduler import torch.nn.functional as F from utils.gradio_utils import cal_attn_mask_xl import copy import os from diffusers.utils import load_image from utils.utils import get_comic from utils.style_template import styles image_encoder_path = "./data/models/ip_adapter/sdxl_models/image_encoder" ip_ckpt = "./data/models/ip_adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin" os.environ["no_proxy"] = "localhost,127.0.0.1,::1" STYLE_NAMES = list(styles.keys()) DEFAULT_STYLE_NAME = "Japanese Anime" global models_dict use_va = False models_dict = { # "Juggernaut": "RunDiffusion/Juggernaut-XL-v8", # "RealVision": "SG161222/RealVisXL_V4.0" , # "SDXL":"stabilityai/stable-diffusion-xl-base-1.0" , "Unstable": "/mnt/bn/yupengdata2/projects/PhotoMaker/sdxl-unstable-diffusers-y" } photomaker_path = "/mnt/bn/yupengdata2/projects/PhotoMaker/photomaker-v1.bin" MAX_SEED = np.iinfo(np.int32).max def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True def set_text_unfinished(): return gr.update(visible=True, value="

(Not Finished) Generating ··· The intermediate results will be shown.

") def set_text_finished(): return gr.update(visible=True, value="

Generation Finished

") ################################################# def get_image_path_list(folder_name): image_basename_list = os.listdir(folder_name) image_path_list = sorted([os.path.join(folder_name, basename) for basename in image_basename_list]) return image_path_list ################################################# class SpatialAttnProcessor2_0(torch.nn.Module): r""" Attention processor for IP-Adapater for PyTorch 2.0. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. text_context_len (`int`, defaults to 77): The context length of the text features. scale (`float`, defaults to 1.0): the weight scale of image prompt. """ def __init__(self, hidden_size = None, cross_attention_dim=None,id_length = 4,device = "cuda",dtype = torch.float16): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.device = device self.dtype = dtype self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.total_length = id_length + 1 self.id_length = id_length self.id_bank = {} def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): # un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2) # un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb) # 生成一个0到1之间的随机数 global total_count,attn_count,cur_step,mask1024,mask4096 global sa32, sa64 global write global height,width if write: # print(f"white:{cur_step}") self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]] else: encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:])) # 判断随机数是否大于0.5 if cur_step <5: hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb) else: # 256 1024 4096 random_number = random.random() if cur_step <20: rand_num = 0.3 else: rand_num = 0.1 # print(f"hidden state shape {hidden_states.shape[1]}") if random_number > rand_num: # print("mask shape",mask1024.shape,mask4096.shape) if not write: if hidden_states.shape[1] == (height//32) * (width//32): attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:] else: attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:] else: # print(self.total_length,self.id_length,hidden_states.shape,(height//32) * (width//32)) if hidden_states.shape[1] == (height//32) * (width//32): attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,:mask1024.shape[0] // self.total_length * self.id_length] else: attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,:mask4096.shape[0] // self.total_length * self.id_length] # print(attention_mask.shape) # print("before attention",hidden_states.shape,attention_mask.shape,encoder_hidden_states.shape if encoder_hidden_states is not None else "None") hidden_states = self.__call1__(attn, hidden_states,encoder_hidden_states,attention_mask,temb) else: hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb) attn_count +=1 if attn_count == total_count: attn_count = 0 cur_step += 1 mask1024,mask4096 = cal_attn_mask_xl(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype) return hidden_states def __call1__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): # print("hidden state shape",hidden_states.shape,self.id_length) residual = hidden_states # if encoder_hidden_states is not None: # raise Exception("not implement") if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: total_batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2) total_batch_size,nums_token,channel = hidden_states.shape img_nums = total_batch_size//2 hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel) batch_size, sequence_length, _ = hidden_states.shape if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states # B, N, C else: encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # print(key.shape,value.shape,query.shape,attention_mask.shape) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 #print(query.shape,key.shape,value.shape,attention_mask.shape) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) # if input_ndim == 4: # tile_hidden_states = tile_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) # if attn.residual_connection: # tile_hidden_states = tile_hidden_states + residual if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor # print(hidden_states.shape) return hidden_states def __call2__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, channel = ( hidden_states.shape ) # print(hidden_states.shape) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states # B, N, C else: encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states def set_attention_processor(unet,id_length,is_ipadapter = False): global total_count total_count = 0 attn_procs = {} for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None: if name.startswith("up_blocks") : attn_procs[name] = SpatialAttnProcessor2_0(id_length = id_length) total_count +=1 else: attn_procs[name] = AttnProcessor() else: if is_ipadapter: attn_procs[name] = IPAttnProcessor2_0( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1, num_tokens=4, ).to(unet.device, dtype=torch.float16) else: attn_procs[name] = AttnProcessor() unet.set_attn_processor(copy.deepcopy(attn_procs)) print("successsfully load paired self-attention") print(f"number of the processor : {total_count}") ################################################# ################################################# canvas_html = "
" load_js = """ async () => { const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js" fetch(url) .then(res => res.text()) .then(text => { const script = document.createElement('script'); script.type = "module" script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' })); document.head.appendChild(script); }); } """ get_js_colors = """ async (canvasData) => { const canvasEl = document.getElementById("canvas-root"); return [canvasEl._data] } """ css = ''' #color-bg{display:flex;justify-content: center;align-items: center;} .color-bg-item{width: 100%; height: 32px} #main_button{width:100%}