def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global generator
    and Python's ``random`` module, then asks cuDNN for deterministic
    behaviour.

    Args:
        seed (`int`): The seed value. ``np.random.seed`` only accepts values
            in ``[0, 2**32 - 1]``.
    """
    torch.manual_seed(seed)
    # Safe without CUDA: the call is silently ignored when no GPU is present.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may pick a different convolution
    # algorithm on each run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

def set_text_unfinished():
    """Return a Gradio update that shows the "still generating" status text.

    The update makes the target component visible and sets its value to the
    in-progress message.
    """
    # The line breaks are written as explicit escapes: a plain double-quoted
    # literal cannot contain raw newlines.
    return gr.update(
        visible=True,
        value="\n\n(Not Finished) Generating ··· The intermediate results will be shown.\n\n",
    )

def set_text_finished():
    """Return a Gradio update that shows the "generation finished" status text.

    The update makes the target component visible and sets its value to the
    completion message.
    """
    # The line breaks are written as explicit escapes: a plain double-quoted
    # literal cannot contain raw newlines.
    return gr.update(visible=True, value="\n\nGeneration Finished\n\n")


#################################################
def get_image_path_list(folder_name):
    """Return the sorted paths of every entry directly inside ``folder_name``.

    Args:
        folder_name (`str`): Directory to list.

    Returns:
        `list[str]`: ``folder_name`` joined with each entry name, in sorted
        order. Entries are not filtered, so sub-directories and non-image
        files are included.

    Raises:
        OSError: Propagated from ``os.listdir`` (for example when the
            directory does not exist).
    """
    return sorted(
        os.path.join(folder_name, basename) for basename in os.listdir(folder_name)
    )
class SpatialAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for PyTorch 2.0 that banks the hidden states of the
    first `id_length` batch entries and lets later calls attend to them.

    The processor is driven by module-level globals that the surrounding
    script must define before the first call: `write`, `cur_step`,
    `attn_count`, `total_count`, `mask1024`, `mask4096`, `sa32`, `sa64`,
    `height` and `width`.

    Args:
        hidden_size (`int`, *optional*):
            Stored on the instance; not read by the attention code in this class.
        cross_attention_dim (`int`, *optional*):
            Stored on the instance; not read by the attention code in this class.
        id_length (`int`, defaults to 4):
            Number of leading batch entries whose hidden states are banked
            separately in write mode.
        device (defaults to `"cuda"`):
            Device the banked states are moved to and the attention masks are
            created on.
        dtype (defaults to `torch.float16`):
            dtype passed to `cal_attn_mask_xl` when masks are regenerated.
    """

    def __init__(self, hidden_size=None, cross_attention_dim=None, id_length=4, device="cuda", dtype=torch.float16):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.device = device
        self.dtype = dtype
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.total_length = id_length + 1
        self.id_length = id_length
        # Maps a denoising step index to [first id_length states, remaining states].
        self.id_bank = {}

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run one attention call and advance the global step bookkeeping.

        In write mode (`write` is truthy) clones of `hidden_states` are stored
        in `self.id_bank[cur_step]`. Otherwise the banked states for
        `cur_step` are interleaved with the current batch and used as the
        key/value context; a `KeyError` is raised if that step was never
        written.

        Returns:
            The attention output, same shape as `hidden_states`.
        """
        global total_count, attn_count, cur_step, mask1024, mask4096
        global sa32, sa64
        global write
        global height, width
        if write:
            self.id_bank[cur_step] = [
                hidden_states[:self.id_length].clone(),
                hidden_states[self.id_length:].clone(),
            ]
        else:
            encoder_hidden_states = torch.cat(
                (
                    self.id_bank[cur_step][0].to(self.device),
                    hidden_states[:1],
                    self.id_bank[cur_step][1].to(self.device),
                    hidden_states[1:],
                )
            )
        if cur_step < 1:
            # The very first step always uses plain self-attention.
            hidden_states = self.__call2__(attn, hidden_states, None, attention_mask, temb)
        else:
            # Draw a number in [0, 1): the masked path is taken when it
            # exceeds 0.3 during steps 1-19 and 0.1 from step 20 onwards.
            random_number = random.random()
            rand_num = 0.3 if cur_step < 20 else 0.1
            if random_number > rand_num:
                # Pick the mask by token count: mask1024 when it equals
                # (height // 32) * (width // 32), mask4096 otherwise.
                if hidden_states.shape[1] == (height // 32) * (width // 32):
                    mask = mask1024
                else:
                    mask = mask4096
                split = mask.shape[0] // self.total_length * self.id_length
                if not write:
                    # Read mode keeps the rows from `split` onwards.
                    attention_mask = mask[split:]
                else:
                    # Write mode keeps the top-left `split` x `split` square.
                    attention_mask = mask[:split, :split]
                hidden_states = self.__call1__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
            else:
                hidden_states = self.__call2__(attn, hidden_states, None, attention_mask, temb)
        attn_count += 1
        if attn_count == total_count:
            # All processors have run for this step: move to the next step
            # and regenerate both masks.
            attn_count = 0
            cur_step += 1
            mask1024, mask4096 = cal_attn_mask_xl(
                self.total_length, self.id_length, sa32, sa64, height, width,
                device=self.device, dtype=self.dtype,
            )

        return hidden_states

    def _attend(self, attn, hidden_states, encoder_hidden_states, attention_mask, out_batch_size):
        """Project q/k/v, run scaled-dot-product attention and the output layers.

        The heads are split using the batch size of `hidden_states`; the
        merged result is reshaped so that its leading dimension is
        `out_batch_size`.
        """
        batch_size = hidden_states.shape[0]
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(out_batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

    @staticmethod
    def _finish_output(attn, hidden_states, residual, spatial_shape):
        """Restore a 4-D layout if needed, then apply residual and rescaling.

        `spatial_shape` is `(batch, channel, height, width)` when the input
        was 4-D and `None` otherwise.
        """
        if spatial_shape is not None:
            hidden_states = hidden_states.transpose(-1, -2).reshape(*spatial_shape)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states / attn.rescale_output_factor

    def __call1__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Attention over the token sequences of half the batch at a time.

        The batch is viewed as groups of `total_batch_size // 2` entries and
        each group's tokens are concatenated into one long sequence before
        attention, so entries in a group attend to each other. The output is
        reshaped back to the original batch layout. `attention_mask` is
        passed to `F.scaled_dot_product_attention` unchanged.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        spatial_shape = None
        if hidden_states.ndim == 4:
            total_batch_size, channel, height, width = hidden_states.shape
            spatial_shape = (total_batch_size, channel, height, width)
            hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)
        total_batch_size, nums_token, channel = hidden_states.shape
        img_nums = total_batch_size // 2
        hidden_states = hidden_states.view(-1, img_nums, nums_token, channel).reshape(-1, img_nums * nums_token, channel)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(
                -1, self.id_length + 1, nums_token, channel
            ).reshape(-1, (self.id_length + 1) * nums_token, channel)

        hidden_states = self._attend(attn, hidden_states, encoder_hidden_states, attention_mask, total_batch_size)
        return self._finish_output(attn, hidden_states, residual, spatial_shape)

    def __call2__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Per-entry attention without folding batch entries together.

        When `attention_mask` is given it is first passed through
        `attn.prepare_attention_mask` and reshaped to
        `(batch, heads, source_length, target_length)`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        spatial_shape = None
        if hidden_states.ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            spatial_shape = (batch_size, channel, height, width)
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, channel = hidden_states.shape
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # B, N, C
        else:
            encoder_hidden_states = encoder_hidden_states.view(
                -1, self.id_length + 1, sequence_length, channel
            ).reshape(-1, (self.id_length + 1) * sequence_length, channel)

        hidden_states = self._attend(attn, hidden_states, encoder_hidden_states, attention_mask, batch_size)
        return self._finish_output(attn, hidden_states, residual, spatial_shape)
def _attn_hidden_size(unet, name):
    """Return the `block_out_channels` entry for the block owning processor `name`."""
    channels = unet.config.block_out_channels
    if name.startswith("mid_block"):
        return channels[-1]
    if name.startswith("up_blocks"):
        # Processor names look like "up_blocks.<index>.…"; up blocks walk the
        # channel list in reverse.
        return list(reversed(channels))[int(name.split(".")[1])]
    if name.startswith("down_blocks"):
        return channels[int(name.split(".")[1])]
    raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")


def set_attention_processor(unet, id_length, is_ipadapter=False):
    """Install attention processors on every attention layer of `unet`.

    Layers whose `cross_attention_dim` resolves to `None` (names ending in
    `"attn1.processor"`, or any layer when the UNet config has no cross
    attention dim) get a `SpatialAttnProcessor2_0` if they live in
    `up_blocks`, and a plain `AttnProcessor` otherwise. The remaining layers
    get an `IPAttnProcessor2_0` when `is_ipadapter` is true, else a plain
    `AttnProcessor`.

    The chosen processors are also stored in the module-level global
    `attn_procs`; the UNet receives a deep copy of that dict.

    Args:
        unet: Model exposing `attn_processors`, `config`, `device` and
            `set_attn_processor`.
        id_length (`int`): Forwarded to `SpatialAttnProcessor2_0`.
        is_ipadapter (`bool`, defaults to `False`): Use IP-Adapter processors
            for the cross-attention layers.

    Raises:
        ValueError: If `is_ipadapter` is true and a cross-attention processor
            name starts with none of `mid_block`, `up_blocks`, `down_blocks`.
    """
    global attn_procs
    attn_procs = {}
    for name in unet.attn_processors:
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if cross_attention_dim is None:
            if name.startswith("up_blocks"):
                attn_procs[name] = SpatialAttnProcessor2_0(id_length=id_length)
            else:
                attn_procs[name] = AttnProcessor()
        elif is_ipadapter:
            # NOTE(review): IPAttnProcessor2_0 is not imported in the visible
            # part of this file; confirm it is in scope before enabling
            # is_ipadapter.
            attn_procs[name] = IPAttnProcessor2_0(
                # Resolved per name so an unrecognised prefix fails loudly
                # instead of reusing the previous layer's size.
                hidden_size=_attn_hidden_size(unet, name),
                cross_attention_dim=cross_attention_dim,
                scale=1,
                num_tokens=4,
            ).to(unet.device, dtype=torch.float16)
        else:
            attn_procs[name] = AttnProcessor()

    unet.set_attn_processor(copy.deepcopy(attn_procs))
" load_js = """ async () => { const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js" fetch(url) .then(res => res.text()) .then(text => { const script = document.createElement('script'); script.type = "module" script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' })); document.head.appendChild(script); }); } """ get_js_colors = """ async (canvasData) => { const canvasEl = document.getElementById("canvas-root"); return [canvasEl._data] } """ css = ''' #color-bg{display:flex;justify-content: center;align-items: center;} .color-bg-item{width: 100%; height: 32px} #main_button{width:100%}