# StupidGame's picture
# Upload 1941 files
# baa8e90
# raw
# history blame contribute delete
# No virus
# 4.39 kB
import torch
class CDTuner:
    """Node that wraps a model so a few UNet tensors are retuned per call.

    ``apply`` installs a wrapper around the model's UNet call. Inside the
    configured timestep window the wrapper rewrites the tensors listed in
    ``ADJUSTS``, runs the UNet, and then puts the original tensors back.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the specification of the inputs this node accepts."""
        return {
            "required": {
                "model": ("MODEL", ),
                "detail_1": ("FLOAT", {
                    "default": 0,
                    "min": -10,
                    "max": 10,
                    "step": 0.1
                }),
                "detail_2": ("FLOAT", {
                    "default": 0,
                    "min": -10,
                    "max": 10,
                    "step": 0.1
                }),
                "contrast_1": ("FLOAT", {
                    "default": 0,
                    "min": -20,
                    "max": 20,
                    "step": 0.1
                }),
                "start": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
                "end": ("INT", {
                    "default": 1000,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, detail_1, detail_2, contrast_1, start, end):
        """Clone ``model`` and install the tensor-retuning UNet wrapper.

        Args:
            model: Object providing ``clone()``; the clone must provide
                ``set_model_unet_function_wrapper``.
            detail_1: Author's note: lowers the weight and raises the bias of
                the first conv layer, which (maybe) increases detail.
            detail_2: Same idea, applied to the GroupNorm in front of the
                last conv layer.
            contrast_1: Author's note: raises channel 0 of the last conv
                layer's bias, which (maybe) increases contrast.
            start: Lower bound of the active range; the wrapper is active
                while ``timestep[0] <= 1000 - start``.
            end: Upper bound of the active range; the wrapper is active
                while ``timestep[0] >= 1000 - end``.

        Returns:
            A one-element tuple holding the patched clone.
        """
        new_model = model.clone()
        ratios = fineman([detail_1, detail_2, contrast_1])

        # These attributes are still published for backward compatibility, but
        # the wrapper below only uses values captured at this call. Reading
        # them back through ``self`` would let a later ``apply`` on the same
        # node instance change the window of an already-patched model.
        storedweights = self.storedweights = {}
        self.start = start
        self.end = end
        lowest_timestep = 1000 - end
        highest_timestep = 1000 - start

        # Patch that runs around each UNet evaluation.
        def apply_cdtuner(model_function, kwargs):
            timestep = kwargs["timestep"]
            if timestep[0] < lowest_timestep or timestep[0] > highest_timestep:
                # Outside the active window: run the UNet untouched.
                return model_function(kwargs["input"], timestep, **kwargs["c"])

            # Drop backups from a previous call so only tensors touched in
            # this call are restored below.
            storedweights.clear()
            try:
                for i, name in enumerate(ADJUSTS):
                    # Back up the original tensor.
                    original = getset_nested_module_tensor(True, new_model, name).clone()
                    storedweights[name] = original
                    if i < 4:
                        # First four entries are scaled by a scalar ratio.
                        new_weight = original * ratios[i]
                    else:
                        # Last entry receives an additive offset.
                        offset = torch.tensor(
                            ratios[i],
                            device=original.device,
                            dtype=original.dtype,
                        )
                        new_weight = original + offset
                    # Overwrite the tensor in the model.
                    getset_nested_module_tensor(False, new_model, name, new_tensor=new_weight)
                return model_function(kwargs["input"], timestep, **kwargs["c"])
            finally:
                # Restore the originals even if the UNet call (or a rewrite)
                # raised; otherwise the model would stay modified.
                for name, original in storedweights.items():
                    getset_nested_module_tensor(False, new_model, name, new_tensor=original)

        new_model.set_model_unet_function_wrapper(apply_cdtuner)
        return (new_model, )
def getset_nested_module_tensor(clone, model, tensor_path, new_tensor=None):
    """Read or replace the object addressed by a dotted path under ``model``.

    Each path segment is resolved in turn: an all-digit segment is used as an
    integer index (``obj[int(segment)]``), any other segment as an attribute
    name (``getattr(obj, segment)``).

    Args:
        clone: Mode flag. When true, the whole path is resolved and the
            resulting object is returned as-is (despite the name, no copy is
            made here). When false, the parent of the last segment is
            resolved and its last-segment attribute is replaced.
        model: Root object the path is resolved from.
        tensor_path: Dotted path such as ``"a.b.0.weight"``.
        new_tensor: Tensor to install in set mode; wrapped in a fresh
            ``torch.nn.Parameter``. Required when ``clone`` is false.

    Returns:
        The resolved object in get mode, ``None`` in set mode.

    Raises:
        ValueError: In set mode when ``new_tensor`` is ``None``. Without this
            check ``torch.nn.Parameter(None)`` would silently install an
            empty parameter in place of the real tensor.
    """
    if not clone and new_tensor is None:
        raise ValueError(
            "new_tensor is required when setting '%s'" % tensor_path
        )

    segments = tensor_path.split('.')
    # Get mode walks the full path; set mode stops at the parent of the leaf.
    to_walk = segments if clone else segments[:-1]

    target = model
    for segment in to_walk:
        if segment.isdigit():
            target = target[int(segment)]
        else:
            target = getattr(target, segment)

    if clone:
        return target

    setattr(target, segments[-1], torch.nn.Parameter(new_tensor))
# Why is this called fineman?
def fineman(fine):
    """Turn the three tuning knobs into per-tensor adjustment values.

    Args:
        fine: Indexable holding ``[detail_1, detail_2, contrast_1]``.

    Returns:
        A five-element list: four scalar factors (two derived from
        ``fine[0]``, two from ``fine[1]``) followed by a four-element list
        whose first entry is derived from ``fine[2]`` and whose remaining
        entries are zero.
    """
    first_detail = fine[0]
    second_detail = fine[1]
    contrast = fine[2]

    factors = []
    for knob in (first_detail, second_detail):
        # Each detail knob yields a "shrink" factor and a "grow" factor.
        factors.append(1 - knob * 0.01)
        factors.append(1 + knob * 0.02)

    factors.append([contrast * 0.02, 0, 0, 0])
    return factors
# Dotted paths of the tensors that CDTuner rewrites around each UNet call.
# They are resolved with getset_nested_module_tensor, starting from the
# patched model object.
# Order matters: entry i is paired with element i of fineman()'s result.
# CDTuner multiplies the first four entries by their factor and adds an
# offset to the fifth.
ADJUSTS = [
# Paired with the two factors derived from detail_1.
"model.diffusion_model.input_blocks.0.0.weight",
"model.diffusion_model.input_blocks.0.0.bias",
# Paired with the two factors derived from detail_2.
"model.diffusion_model.out.0.weight",
"model.diffusion_model.out.0.bias",
# Paired with the additive offset list derived from contrast_1.
"model.diffusion_model.out.2.bias",
]
# Maps each node identifier to the class that implements it.
# NOTE(review): presumably read by the host application's custom-node loader
# (ComfyUI convention) — confirm against the loader.
NODE_CLASS_MAPPINGS = {
"CDTuner": CDTuner,
}
# Maps each node identifier to a human-readable display title.
NODE_DISPLAY_NAME_MAPPINGS = {
"CDTuner": "Apply CDTuner",
}
# Only the two mapping dicts are exported by a star-import of this module.
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]