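"""LLM text-encoder adapter.

Wires a causal language model (Qwen2 or Llama) into a Flux or Lumina2
diffusion transformer as the prompt encoder: the LLM's last hidden states are
projected into the transformer's context dimension, and (for Flux) the first
few transformer blocks can optionally be cloned and rerouted while the
adapter is active.
"""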
from functools import partial
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import weakref
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
from transformers import AutoModel, AutoTokenizer, Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer

from toolkit import train_tools
from toolkit.prompt_utils import PromptEmbeds
from diffusers import Transformer2DModel
from toolkit.dequantize import patch_dequantization_on_save

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
    from toolkit.custom_adapter import CustomAdapter

LLM = Union[Qwen2Model, LlamaModel]
LLMTokenizer = Union[Qwen2Tokenizer, LlamaTokenizer]
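
# Replacement forward for the transformer's context_embedder: routes through the
# adapter's projection when the adapter is active, otherwise uses the original forward.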
def new_context_embedder_forward(self, x):
    if self._adapter_ref().is_active:
        x = self._context_embedder_ref()(x)
    else:
        x = self._orig_forward(x)
    return x
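
# Replacement forward for a patched FluxTransformerBlock: runs the adapter's cloned
# block when the adapter is active, otherwise the original block's forward.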
def new_block_forward(
    self: FluxTransformerBlock,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if self._adapter_ref().is_active:
        return self._new_block_ref()(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)
    else:
        return self._orig_forward(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)

class LLMAdapter(torch.nn.Module):
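    """Uses a causal LLM (Qwen2 or Llama) as the text encoder for a Flux or
    Lumina2 model: projects the LLM's hidden states into the transformer's
    context dimension and, for Flux, can clone the first few transformer
    blocks and route through the clones while the adapter is active."""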
    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
        llm: LLM,
        tokenizer: LLMTokenizer,
        num_cloned_blocks: int = 0,
    ):
        super(LLMAdapter, self).__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.llm_ref: weakref.ref = weakref.ref(llm)
        self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
        self.num_cloned_blocks = num_cloned_blocks
        self.apply_embedding_mask = False

        # make sure we can pad
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        self.system_prompt = ""
        # self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "

        # determine the token length of the system prompt so it can be stripped from the embeddings later
        sys_prompt_tokenized = tokenizer(
            [self.system_prompt],
            padding="longest",
            return_tensors="pt",
        )
        sys_prompt_tokenized_ids = sys_prompt_tokenized.input_ids[0]
        self.system_prompt_length = sys_prompt_tokenized_ids.shape[0]
        print(f"System prompt length: {self.system_prompt_length}")

        self.hidden_size = llm.config.hidden_size

        blocks = []
        if sd.is_flux:
            self.apply_embedding_mask = True
            # project LLM hidden states into the Flux transformer's inner dimension
            self.context_embedder = nn.Linear(
                self.hidden_size, sd.unet.inner_dim)
            self.sequence_length = 512

            # monkey patch the context embedder so the adapter's projection is used when active
            sd.unet.context_embedder._orig_forward = sd.unet.context_embedder.forward
            sd.unet.context_embedder.forward = partial(
                new_context_embedder_forward, sd.unet.context_embedder)
            sd.unet.context_embedder._context_embedder_ref = weakref.ref(
                self.context_embedder)
            # give the context embedder a reference to the adapter so it can check is_active
            sd.unet.context_embedder._adapter_ref = self.adapter_ref
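            # copy the weights of the first `num_cloned_blocks` Flux blocks into fresh
            # blocks and reroute the originals through the clones while the adapter is active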
            for idx in range(self.num_cloned_blocks):
                block = FluxTransformerBlock(
                    dim=sd.unet.inner_dim,
                    num_attention_heads=24,
                    attention_head_dim=128,
                )
                # patch it in case it is quantized
                patch_dequantization_on_save(sd.unet.transformer_blocks[idx])
                state_dict = sd.unet.transformer_blocks[idx].state_dict()
                for key, value in state_dict.items():
                    block.state_dict()[key].copy_(value)
                blocks.append(block)

                orig_block = sd.unet.transformer_blocks[idx]
                orig_block._orig_forward = orig_block.forward
                orig_block.forward = partial(
                    new_block_forward, orig_block)
                orig_block._new_block_ref = weakref.ref(block)
                orig_block._adapter_ref = self.adapter_ref
        elif sd.is_lumina2:
            self.context_embedder = nn.Linear(
                self.hidden_size, sd.unet.hidden_size)
            self.sequence_length = 256
        else:
            raise ValueError(
                "llm adapter currently only supports flux or lumina2")

        self.blocks = nn.ModuleList(blocks)
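
    # Tokenize the prompts (system prompt already prepended by encode_text), run them
    # through the LLM, and return the last hidden states with the system-prompt tokens stripped.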
    def _get_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        max_sequence_length: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tokenizer = self.tokenizer_ref()
        text_encoder = self.llm_ref()
        device = text_encoder.device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length + self.system_prompt_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = text_encoder(
            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        )
        prompt_embeds = prompt_embeds.hidden_states[-1]

        # remove the system prompt tokens from the embeddings and the attention mask
        prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]

        dtype = text_encoder.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds, prompt_attention_mask
    # getter to see if the parent adapter is active
    @property
    def is_active(self):
        return self.adapter_ref().is_active
    def encode_text(self, prompt):
        prompt = prompt if isinstance(prompt, list) else [prompt]
        # prepend the system prompt; its tokens are stripped back out in _get_prompt_embeds
        prompt = [self.system_prompt + p for p in prompt]
        prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
            prompt=prompt,
            max_sequence_length=self.sequence_length,
        )
        prompt_embeds = PromptEmbeds(
            prompt_embeds,
            attention_mask=prompt_attention_mask,
        ).detach()
        return prompt_embeds
    # no-op forward; this module's work happens in encode_text and the patched transformer modules
    def forward(self, input):
        return input
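
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The `custom_adapter` / `sd_model` objects
# come from the surrounding toolkit; the checkpoint name and the
# num_cloned_blocks value below are assumptions, not requirements.
#
#   from transformers import AutoModel, AutoTokenizer
#
#   llm = AutoModel.from_pretrained("Qwen/Qwen2-0.5B")        # assumed LLM checkpoint
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
#   llm_adapter = LLMAdapter(
#       adapter=custom_adapter,   # toolkit CustomAdapter instance
#       sd=sd_model,              # toolkit StableDiffusion wrapper (Flux or Lumina2)
#       llm=llm,
#       tokenizer=tokenizer,
#       num_cloned_blocks=2,      # clone the first 2 Flux transformer blocks
#   )
#   prompt_embeds = llm_adapter.encode_text("a photo of a cat")
# ---------------------------------------------------------------------------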