| import torch | |
| import comfy.model_base | |
def sdxl_encode_adm_patched(self, **kwargs):
    """Concatenate the pooled CLIP output with embedded size values.

    Six scalars are each passed through ``self.embedder`` in the order
    height, width, crop_h, crop_w, target_height, target_width. The
    results are flattened into one row and appended to ``pooled_output``
    along ``dim=1``.

    Keyword Args:
        pooled_output: Required 2-D tensor (indexed as ``(batch, features)``
            by the final concatenation).
        width, height: Source size; each defaults to 768.
        crop_w, crop_h: Crop offsets; each defaults to 0.
        target_width, target_height: Default to the *unscaled* ``width`` /
            ``height``.
        prompt_type: ``"negative"`` scales width/height by 0.8,
            ``"positive"`` scales them by 1.5; any other value leaves
            them unchanged.

    Returns:
        Tensor with ``pooled_output.shape[0]`` rows: the pooled output
        followed by the flattened size embeddings.

    Raises:
        KeyError: If ``pooled_output`` is not supplied.
    """
    clip_pooled = kwargs["pooled_output"]
    width = kwargs.get("width", 768)
    height = kwargs.get("height", 768)
    crop_w = kwargs.get("crop_w", 0)
    crop_h = kwargs.get("crop_h", 0)
    # Target defaults are captured BEFORE the prompt-type scaling below, so
    # the scaling only affects the source-size entries.
    target_width = kwargs.get("target_width", width)
    target_height = kwargs.get("target_height", height)

    prompt_type = kwargs.get("prompt_type", "")
    if prompt_type == "negative":
        width *= 0.8
        height *= 0.8
    elif prompt_type == "positive":
        width *= 1.5
        height *= 1.5

    size_values = (height, width, crop_h, crop_w, target_height, target_width)
    out = [self.embedder(torch.Tensor([value])) for value in size_values]

    # Build a single row, then repeat it for every row of the pooled output
    # so the concatenation also works when the batch size is greater than 1.
    flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
    flat = flat.repeat(clip_pooled.shape[0], 1)
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
def patch_negative_adm():
    """Replace ``comfy.model_base.SDXL.encode_adm`` with the patched encoder.

    Rebinds the class attribute to ``sdxl_encode_adm_patched``, so the
    change applies to the class itself rather than to one instance.
    """
    sdxl_cls = comfy.model_base.SDXL
    sdxl_cls.encode_adm = sdxl_encode_adm_patched