import copy
from pathlib import Path
from typing import List, Literal, Optional, Union

import torch
import yaml
from pydantic import BaseModel, root_validator

ACTION_TYPES = Literal[
    "erase",
    "enhance",
]


# SDXL needs two kinds of embeddings, so it gets a dedicated container.
class PromptEmbedsXL:
    """Holds the two SDXL prompt embeddings: text embeds and pooled embeds."""

    text_embeds: torch.FloatTensor
    pooled_embeds: torch.FloatTensor

    def __init__(self, *args) -> None:
        # Positional only: args[0] -> text_embeds, args[1] -> pooled_embeds.
        # Any further positional arguments are ignored.
        self.text_embeds = args[0]
        self.pooled_embeds = args[1]


# SD v1.x / v2.x use a plain FloatTensor; SDXL uses PromptEmbedsXL.
PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]


class PromptEmbedsCache:
    """Maps prompt strings to their embeddings so they can be reused."""

    prompts: dict[str, PROMPT_EMBEDDING]

    def __init__(self) -> None:
        # Per-instance storage. A class-level ``{}`` default would be one dict
        # shared by every PromptEmbedsCache instance.
        self.prompts = {}

    def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
        """Store the embedding for a prompt string."""
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
        """Return the cached embedding for a prompt, or ``None`` if absent."""
        return self.prompts.get(__name)


class PromptSettings(BaseModel):
    """Per-prompt training settings (one entry of the prompts YAML file)."""

    target: str
    positive: Optional[str] = None  # if omitted, target is used
    unconditional: str = ""  # default is ""
    neutral: Optional[str] = None  # if omitted, unconditional is used
    action: ACTION_TYPES = "erase"  # default is "erase"
    guidance_scale: float = 1.0  # default is 1.0
    resolution: int = 512  # default is 512
    dynamic_resolution: bool = False  # default is False
    batch_size: int = 1  # default is 1
    dynamic_crops: bool = False  # default is False; per original note, only used when model is XL

    @root_validator(pre=True)
    def fill_prompts(cls, values):
        """Fill omitted prompt fields from related ones before field validation.

        Only keys that are absent from the input are filled; a key that is
        explicitly present (even as ``None``) is left untouched.

        Raises:
            ValueError: if ``target`` is not given.
        """
        keys = values.keys()
        if "target" not in keys:
            raise ValueError("target must be specified")
        if "positive" not in keys:
            values["positive"] = values["target"]
        if "unconditional" not in keys:
            values["unconditional"] = ""
        if "neutral" not in keys:
            # Falls back to the (possibly just defaulted) unconditional prompt.
            values["neutral"] = values["unconditional"]
        return values


class PromptEmbedsPair:
    """Prompt embeddings plus the settings and loss used for one training pair."""

    target: PROMPT_EMBEDDING  # concept the model should not generate
    positive: PROMPT_EMBEDDING  # concept to generate
    unconditional: PROMPT_EMBEDDING  # unconditional prompt (default should be empty)
    neutral: PROMPT_EMBEDDING  # base condition (default should be empty)

    guidance_scale: float
    resolution: int
    dynamic_resolution: bool
    batch_size: int
    dynamic_crops: bool

    loss_fn: torch.nn.Module
    action: ACTION_TYPES

    def __init__(
        self,
        loss_fn: torch.nn.Module,
        target: PROMPT_EMBEDDING,
        positive: PROMPT_EMBEDDING,
        unconditional: PROMPT_EMBEDDING,
        neutral: PROMPT_EMBEDDING,
        settings: PromptSettings,
    ) -> None:
        self.loss_fn = loss_fn
        self.target = target
        self.positive = positive
        self.unconditional = unconditional
        self.neutral = neutral

        # Copy the scalar settings so the pair is self-contained.
        self.guidance_scale = settings.guidance_scale
        self.resolution = settings.resolution
        self.dynamic_resolution = settings.dynamic_resolution
        self.batch_size = settings.batch_size
        self.dynamic_crops = settings.dynamic_crops
        self.action = settings.action

    def _erase(
        self,
        target_latents: torch.FloatTensor,  # "van gogh"
        positive_latents: torch.FloatTensor,  # "van gogh"
        unconditional_latents: torch.FloatTensor,  # ""
        neutral_latents: torch.FloatTensor,  # ""
    ) -> torch.FloatTensor:
        """Loss that steers the target latents away from the positive concept.

        Regression target: ``neutral - guidance_scale * (positive - unconditional)``.
        """
        return self.loss_fn(
            target_latents,
            neutral_latents
            - self.guidance_scale * (positive_latents - unconditional_latents),
        )

    def _enhance(
        self,
        target_latents: torch.FloatTensor,  # "van gogh"
        positive_latents: torch.FloatTensor,  # "van gogh"
        unconditional_latents: torch.FloatTensor,  # ""
        neutral_latents: torch.FloatTensor,  # ""
    ) -> torch.FloatTensor:
        """Loss that steers the target latents toward the positive concept.

        Regression target: ``neutral + guidance_scale * (positive - unconditional)``.
        """
        return self.loss_fn(
            target_latents,
            neutral_latents
            + self.guidance_scale * (positive_latents - unconditional_latents),
        )

    def loss(
        self,
        **kwargs,
    ):
        """Compute the loss for ``self.action``.

        Expects the keyword arguments ``target_latents``, ``positive_latents``,
        ``unconditional_latents`` and ``neutral_latents``.

        Raises:
            ValueError: if ``self.action`` is neither "erase" nor "enhance".
        """
        if self.action == "erase":
            return self._erase(**kwargs)
        elif self.action == "enhance":
            return self._enhance(**kwargs)
        else:
            raise ValueError("action must be erase or enhance")


def load_prompts_from_yaml(path, target, positive, negative, attributes=None):
    """Load prompt settings from a YAML list and override their prompt texts.

    Every entry of the YAML list gets ``target``/``positive``/``neutral``/
    ``unconditional`` overwritten with ``target``, ``positive``, ``target`` and
    ``negative`` respectively. If ``attributes`` is non-empty, each entry is
    then expanded once per attribute, with ``"<attribute> "`` prefixed to all
    four prompt texts.

    Args:
        path: Path to a YAML file containing a list of mappings.
        target: Text used for both the ``target`` and ``neutral`` prompts.
        positive: Text used for the ``positive`` prompt.
        negative: Text used for the ``unconditional`` prompt.
        attributes: Optional sequence of attribute strings to expand over.

    Returns:
        A list of ``PromptSettings``.

    Raises:
        ValueError: if the YAML file is empty.
    """
    with open(path, "r", encoding="utf-8") as f:
        prompts = yaml.safe_load(f)

    # An empty YAML file loads as None; check before iterating so the caller
    # gets this ValueError rather than a TypeError from the loop below.
    if not prompts:
        raise ValueError("prompts file is empty")

    new = []
    for prompt in prompts:
        copy_ = copy.deepcopy(prompt)
        copy_["target"] = target
        copy_["positive"] = positive
        copy_["neutral"] = target
        copy_["unconditional"] = negative
        new.append(copy_)
    prompts = new
    print(prompts)

    if attributes:
        newprompts = []
        for prompt in prompts:
            for att in attributes:
                copy_ = copy.deepcopy(prompt)
                for key in ("target", "positive", "neutral", "unconditional"):
                    copy_[key] = att + " " + copy_[key]
                newprompts.append(copy_)
    else:
        newprompts = copy.deepcopy(prompts)
    print(newprompts)
    print(len(prompts), len(newprompts))

    prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]

    return prompt_settings