import os
import gradio as gr
import numpy as np
import random
import spaces
import torch
import json
import logging
from diffusers import DiffusionPipeline
from huggingface_hub import login
import time
from datetime import datetime
from io import BytesIO
from diffusers.models.attention_processor import AttnProcessor2_0
import torch.nn.functional as F
import re

# Log in to the Hugging Face Hub
HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)

import diffusers
print(diffusers.__version__)

# Initialization
dtype = torch.float16  # adjust the dtype as needed
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"  # replace with your model

# Load the pipeline
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

MAX_SEED = 2**32 - 1


class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


# Mappings for locations, offsets, and areas
valid_locations = {  # x, y on a 90*90 grid
    'in the center': (45, 45),
    'on the left': (15, 45),
    'on the right': (75, 45),
    'on the top': (45, 15),
    'on the bottom': (45, 75),
    'on the top-left': (15, 15),
    'on the top-right': (75, 15),
    'on the bottom-left': (15, 75),
    'on the bottom-right': (75, 75)
}

valid_offsets = {  # x, y on a 90*90 grid
    'no offset': (0, 0),
    'slightly to the left': (-10, 0),
    'slightly to the right': (10, 0),
    'slightly to the upper': (0, -10),
    'slightly to the lower': (0, 10),
    'slightly to the upper-left': (-10, -10),
    'slightly to the upper-right': (10, -10),
    'slightly to the lower-left': (-10, 10),
    'slightly to the lower-right': (10, 10)
}

valid_areas = {  # w, h on a 90*90 grid
    "a small square area": (50, 50),
    "a small vertical area": (40, 60),
    "a small horizontal area": (60, 40),
    "a medium-sized square area": (60, 60),
    "a medium-sized vertical area": (50, 80),
    "a medium-sized horizontal area": (80, 50),
    "a large square area": (70, 70),
    "a large vertical area": (60, 90),
    "a large horizontal area": (90, 60)
}


# Parse a character-position description into location, offset, and area
def parse_character_position(character_position):
    # Build regex patterns from the valid keys
    location_pattern = '|'.join(re.escape(key) for key in valid_locations.keys())
    offset_pattern = '|'.join(re.escape(key) for key in valid_offsets.keys())
    area_pattern = '|'.join(re.escape(key) for key in valid_areas.keys())

    # Extract the location
    location_match = re.search(location_pattern, character_position, re.IGNORECASE)
    location = location_match.group(0) if location_match else 'in the center'

    # Extract the offset
    offset_match = re.search(offset_pattern, character_position, re.IGNORECASE)
    offset = offset_match.group(0) if offset_match else 'no offset'

    # Extract the area
    area_match = re.search(area_pattern, character_position, re.IGNORECASE)
    area = area_match.group(0) if area_match else 'a medium-sized square area'

    return {
        'location': location,
        'offset': offset,
        'area': area
    }


# Build the attention mask for one character
def create_attention_mask(image_width, image_height, location, offset, area):
    # Positions are defined on a 90x90 base grid and then scaled to the actual image size
    base_size = 90

    # Look up the location coordinates
    loc_x, loc_y = valid_locations.get(location, (45, 45))

    # Look up the offset
    offset_x, offset_y = valid_offsets.get(offset, (0, 0))

    # Look up the area size
    area_width, area_height = valid_areas.get(area, (60, 60))

    # Compute the final position
    final_x = loc_x + offset_x
    final_y = loc_y + offset_y
    # Map coordinates and sizes to the actual image size
    scale_x = image_width / base_size
    scale_y = image_height / base_size
    center_x = final_x * scale_x
    center_y = final_y * scale_y
    width = area_width * scale_x
    height = area_height * scale_y

    # Compute the top-left and bottom-right corners
    x_start = int(max(center_x - width / 2, 0))
    y_start = int(max(center_y - height / 2, 0))
    x_end = int(min(center_x + width / 2, image_width))
    y_end = int(min(center_y + height / 2, image_height))

    # Create the mask
    mask = torch.zeros((image_height, image_width), dtype=torch.float32, device=device)
    mask[y_start:y_end, x_start:x_end] = 1.0

    # Flatten to one dimension, shape: (image_height * image_width,)
    mask_flat = mask.view(-1)
    return mask_flat


# Custom attention processor
class CustomCrossAttentionProcessor(AttnProcessor2_0):
    def __init__(self, masks, adapter_names):
        super().__init__()
        self.masks = masks  # list of per-character masks (shape: [key_length])
        self.adapter_names = adapter_names  # list of per-character LoRA adapter names

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        **kwargs,
    ):
        """
        Custom attention processor that applies the character masks inside the attention computation.

        Args:
            attn: the attention module instance.
            hidden_states: input hidden states (query).
            encoder_hidden_states: encoder hidden states (key/value).
            attention_mask: attention mask.
            temb: time embedding (may be unused).
            **kwargs: additional arguments.

        Returns:
            The processed hidden states.
        """
        # Get the current adapter_name
        adapter_name = getattr(attn, 'adapter_name', None)
        if adapter_name is None or adapter_name not in self.adapter_names:
            # No adapter_name, or not one of ours: fall back to the parent implementation
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)

        # Look up the index of this adapter_name and fetch its mask (shape: [key_length])
        idx = self.adapter_names.index(adapter_name)
        mask = self.masks[idx]

        # Below is the AttnProcessor2_0 implementation with the custom mask logic inserted
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        else:
            batch_size, sequence_length, _ = hidden_states.shape

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        # Key length of the encoder hidden states
        key_length = encoder_hidden_states.shape[1]

        if attention_mask is not None:
            # Prepare the provided attention_mask
            attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
            # attention_mask should have shape (batch_size, attn.heads, query_length, key_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
        else:
            # Otherwise create an all-zero mask
            attention_mask = torch.zeros(
                batch_size, attn.heads, 1, key_length,
                device=hidden_states.device, dtype=hidden_states.dtype
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # The character mask must be applied before the attention scores are computed.
        # Since PyTorch's scaled_dot_product_attention accepts an attn_mask argument,
        # we fold our mask into it. The mask has shape [key_length] and is reshaped to
        # (batch_size, attn.heads, 1, key_length).
        custom_attention_mask = mask.view(1, 1, 1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
        # Valid positions become 0, masked positions become a large negative value
        # (-65504 for float16, otherwise -1e9)
        mask_value = -65504.0 if hidden_states.dtype == torch.float16 else -1e9
        custom_attention_mask = (1.0 - custom_attention_mask) * mask_value

        # Add the custom mask to the attention_mask
        attention_mask = attention_mask + custom_attention_mask

        # Compute attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states


# Replace the attention processors in the pipeline
def replace_attention_processors(pipe, masks, adapter_names):
    custom_processor = CustomCrossAttentionProcessor(masks, adapter_names)
    for name, module in pipe.transformer.named_modules():
        if hasattr(module, 'attn'):
            module.attn.adapter_name = getattr(module, 'adapter_name', None)
            module.attn.processor = custom_processor
        if hasattr(module, 'cross_attn'):
            module.cross_attn.adapter_name = getattr(module, 'adapter_name', None)
            module.cross_attn.processor = custom_processor


# Generate an image from precomputed embeddings
def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress):
    pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(seed)

    with calculateDuration("Generating image"):
        # Generate image
        generated_image = pipe(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

    progress(99, "Generate success!")
    return generated_image


# Main function
@spaces.GPU
@torch.inference_mode()
def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):

    # Parse the character prompts, positions, and LoRA strings
    try:
        character_prompts = json.loads(character_prompts_json)
        character_positions = json.loads(character_positions_json)
        lora_strings = json.loads(lora_strings_json)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON input: {e}")

    # Make sure the numbers of prompts, positions, and LoRA strings match
    if len(character_prompts) != len(character_positions) or len(character_prompts) != len(lora_strings):
        raise ValueError("The number of character prompts, positions, and LoRA strings must be the same.")

    # Number of characters
    num_characters = len(character_prompts)

    # Load LoRA weights
    with calculateDuration("Loading LoRA weights"):
        pipe.unload_lora_weights()
        adapter_names = []
        for lora_info in lora_strings:
            lora_repo = lora_info.get("repo")
            weights = lora_info.get("weights")
            adapter_name = lora_info.get("adapter_name")
            if lora_repo and weights and adapter_name:
                # Load the LoRA weights via pipe.load_lora_weights()
                pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
                adapter_names.append(adapter_name)
                # Store the adapter_name as an attribute of the transformer
                setattr(pipe.transformer, 'adapter_name', adapter_name)
            else:
                raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
        adapter_weights = [lora_scale] * len(adapter_names)
        # Set the adapters and their weights on the pipeline
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

    # Make sure the number of LoRA adapters matches the number of characters
    if len(adapter_names) != num_characters:
        raise ValueError("The number of LoRA adapters must match the number of characters.")

    # Set random seed for reproducibility
    if randomize_seed:
        with calculateDuration("Set random seed"):
            seed = random.randint(0, MAX_SEED)

    # Encode the prompts
    with calculateDuration("Encoding prompts"):
        # Encode the background prompt
        bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
        bg_prompt_embeds = pipe.text_encoder_2(bg_text_input.input_ids.to(device))[0]
        bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output

        # Encode the character prompts
        character_prompt_embeds = []
        character_pooled_embeds = []
        for prompt in character_prompts:
            char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
            char_prompt_embeds = pipe.text_encoder_2(char_text_input.input_ids.to(device))[0]
            char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
            character_prompt_embeds.append(char_prompt_embeds)
            character_pooled_embeds.append(char_pooled_embeds)

        # Encode the interaction-details prompt
        details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
        details_prompt_embeds = pipe.text_encoder_2(details_text_input.input_ids.to(device))[0]
        details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output

        # Merge the background and interaction-details embeddings
        prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
        pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=1)

    # Parse the character positions
    character_infos = []
    for position_str in character_positions:
        info = parse_character_position(position_str)
        character_infos.append(info)

    # Create the character masks
    masks = []
    for info in character_infos:
        mask = create_attention_mask(width, height, info['location'], info['offset'], info['area'])
        masks.append(mask)

    # Replace the attention processors
    replace_attention_processors(pipe, masks, adapter_names)

    # Generate image
    final_image = generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress)

    # Code for uploading the image could be added here
    result = {"status": "success", "message": "Image generated"}
    progress(100, "Completed!")

    return final_image, seed, json.dumps(result)


# Gradio UI
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("Flux with LoRA")
    with gr.Row():
        with gr.Column():
            prompt_bg = gr.Text(label="Background Prompt", placeholder="Enter background/scene prompt", lines=2)
            character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
            character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
            lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
            prompt_details = gr.Text(label="Interaction Details", placeholder="Enter interaction details between characters", lines=2)
            run_button = gr.Button("Run", scale=0)

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
gr.Checkbox(label="Randomize seed", value=True) lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.5) with gr.Row(): width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=512) height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=512) with gr.Row(): cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5) steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28) upload_to_r2 = gr.Checkbox(label="Upload to R2", value=False) account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id") access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here") secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here") bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here") with gr.Column(): result = gr.Image(label="Result", show_label=False) seed_output = gr.Text(label="Seed") json_text = gr.Text(label="Result JSON") inputs = [ prompt_bg, character_prompts, character_positions, lora_strings_json, prompt_details, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket ] outputs = [result, seed_output, json_text] run_button.click( fn=run_lora, inputs=inputs, outputs=outputs ) demo.queue().launch()