|
import copy |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
from torch.nn.modules.utils import _pair |
|
|
|
from ..log import log |
|
|
|
|
|
class MTB_VaeDecode:
    """Decode latents with the regular or the tiled VAE decoder.

    Optionally applies the Stable Diffusion "seamless" hack (circular padding on
    every ``Conv2d`` of the VAE) for the duration of a single decode. The hack is
    adapted from FlyingFireCo/tiled_ksampler. Seamless mode and the tiled decoder
    cannot be combined; when both are requested, tiling is skipped.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT",),
                "vae": ("VAE",),
                "seamless_model": ("BOOLEAN", {"default": False}),
                "use_tiling_decoder": ("BOOLEAN", {"default": True}),
                "tile_size": (
                    "INT",
                    {"default": 512, "min": 320, "max": 4096, "step": 64},
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "mtb/decode"

    def decode(
        self,
        vae,
        samples,
        seamless_model,
        use_tiling_decoder=True,
        tile_size=512,
    ):
        """Decode ``samples["samples"]`` with ``vae``.

        Args:
            vae: VAE object exposing ``decode``, ``decode_tiled`` and
                ``first_stage_model`` (a ``torch.nn.Module``).
            samples: latent dict; only its ``"samples"`` entry is read.
            seamless_model: when True, every ``Conv2d`` in
                ``vae.first_stage_model`` uses circular padding for this call
                only. The original padding modes are restored afterwards, even
                if the decoder raises.
            use_tiling_decoder: use ``vae.decode_tiled`` instead of
                ``vae.decode``. Forced off in seamless mode.
            tile_size: tile edge in pixels. It is divided by 8 before being
                passed as ``tile_x``/``tile_y`` (presumably the VAE's latent
                downscale factor; confirm for non-SD VAEs).

        Returns:
            A 1-tuple containing whatever the selected decoder returns.
        """
        if seamless_model and use_tiling_decoder:
            log.error(
                "You cannot use seamless mode with tiling decoder together, skipping tiling."
            )
            use_tiling_decoder = False

        # (layer, original padding_mode) pairs to put back once decoding is done.
        # The VAE is shared with other nodes, so the hack must not outlive this call.
        patched = []
        try:
            if seamless_model:
                for layer in vae.first_stage_model.modules():
                    if isinstance(layer, torch.nn.Conv2d):
                        # Record before mutating so a partial patch is still undone.
                        patched.append((layer, layer.padding_mode))
                        layer.padding_mode = "circular"

            if use_tiling_decoder:
                return (
                    vae.decode_tiled(
                        samples["samples"],
                        tile_x=tile_size // 8,
                        tile_y=tile_size // 8,
                    ),
                )
            return (vae.decode(samples["samples"]),)
        finally:
            for layer, original_mode in patched:
                layer.padding_mode = original_mode
|
|
|
|
|
def conv_forward(lyr, tensor, weight, bias):
    """Per-axis padding replacement for ``Conv2d._conv_forward``.

    The function works in three stages:

    1. It reads the layer's call counter ``lyr.timestep`` and checks it against
       ``lyr.paddingStartStep`` / ``lyr.paddingStopStep``. A negative bound
       disables that side of the window.
    2. Inside the window, ``tensor`` is padded horizontally with ``lyr.paddingX``
       in ``lyr.padding_modeX`` mode, then vertically with ``lyr.paddingY`` in
       ``lyr.padding_modeY`` mode. Outside the window, both pads use
       ``"constant"`` (zero) mode.
    3. It increments ``lyr.timestep`` and runs ``F.conv2d`` on the pre-padded
       input with zero built-in padding and the layer's stride, dilation and
       groups.

    NOTE(review): ``timestep`` counts calls to this layer rather than sampler
    steps, and nothing in this function resets it. Confirm that this matches
    the intended start/stop semantics.
    """
    current = lyr.timestep

    # Negation of "start <= current <= stop" (negative bound = unbounded side).
    outside_window = (
        lyr.paddingStartStep >= 0 and current < lyr.paddingStartStep
    ) or (lyr.paddingStopStep >= 0 and current > lyr.paddingStopStep)

    if outside_window:
        mode_x = mode_y = "constant"
    else:
        mode_x, mode_y = lyr.padding_modeX, lyr.padding_modeY

    # Pad width first, then height, so each axis can use its own mode.
    padded = F.pad(tensor, lyr.paddingX, mode=mode_x)
    padded = F.pad(padded, lyr.paddingY, mode=mode_y)

    lyr.timestep = current + 1

    return F.conv2d(
        padded,
        weight,
        bias,
        stride=lyr.stride,
        padding=_pair(0),
        dilation=lyr.dilation,
        groups=lyr.groups,
    )
|
|
|
|
|
class MTB_ModelPatchSeamless:
    """Return a copy of a model whose ``Conv2d`` layers pad circularly (experimental).

    This is the Stable Diffusion "seamless" hack. Every ``Conv2d`` in the copied
    model has its ``_conv_forward`` replaced by the module-level
    ``conv_forward``. On the enabled axes, that function wraps the input around
    instead of zero-padding it, while the layer's call counter lies between
    ``startStep`` and ``stopStep``. The incoming model is passed through
    unchanged as the first output.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "model": ("MODEL",),
            "startStep": ("INT", {"default": 0}),
            "stopStep": ("INT", {"default": 999}),
        }
        # One independent toggle per axis; each gets its own options dict.
        for axis_toggle in ("tilingX", "tilingY"):
            required[axis_toggle] = ("BOOLEAN", {"default": True})
        return {"required": required}

    RETURN_TYPES = ("MODEL", "MODEL")
    RETURN_NAMES = (
        "Original Model (passthrough)",
        "Patched Model",
    )
    FUNCTION = "hack"

    CATEGORY = "mtb/textures"

    def apply_circular(self, model, startStep, stopStep, x, y):
        """Install the per-axis padding hook on every ``Conv2d`` in ``model``.

        Args:
            model: ``torch.nn.Module`` to patch in place.
            startStep: first counter value (inclusive) at which the per-axis
                modes apply; negative means "from the start".
            stopStep: last counter value (inclusive) at which they apply;
                negative means "never stop".
            x: wrap horizontally (``"circular"``) when truthy, else zero-pad.
            y: wrap vertically (``"circular"``) when truthy, else zero-pad.

        Returns:
            The same ``model`` object, after patching.
        """
        horizontal_mode = "circular" if x else "constant"
        vertical_mode = "circular" if y else "constant"

        conv_layers = (
            module for module in model.modules() if isinstance(module, torch.nn.Conv2d)
        )
        for conv in conv_layers:
            # F.pad order: (left, right, top, bottom) for the last two dims.
            pads = conv._reversed_padding_repeated_twice
            conv.padding_modeX = horizontal_mode
            conv.padding_modeY = vertical_mode
            conv.paddingX = (*pads[:2], 0, 0)
            conv.paddingY = (0, 0, *pads[2:4])
            conv.paddingStartStep = startStep
            conv.paddingStopStep = stopStep
            conv.timestep = 0  # call counter consumed by conv_forward
            conv._conv_forward = conv_forward.__get__(conv, torch.nn.Conv2d)

        return model

    def hack(
        self,
        model,
        startStep,
        stopStep,
        tilingX,
        tilingY,
    ):
        """Return ``(model, patched_copy)``; the incoming ``model`` is not modified.

        ``patched_copy`` is a ``copy.deepcopy`` of ``model`` whose ``.model``
        module has been patched by ``apply_circular``.
        NOTE(review): a deepcopy presumably duplicates the model weights in
        memory; confirm that this is acceptable for large checkpoints.
        """
        patched_copy = copy.deepcopy(model)
        self.apply_circular(patched_copy.model, startStep, stopStep, tilingX, tilingY)
        return model, patched_copy
|
|
|
|
|
# Node classes defined in this module (presumably collected by the package's
# node loader; confirm against the loader).
__nodes__ = [
    MTB_ModelPatchSeamless,
    MTB_VaeDecode,
]
|
|