import os

import lpips
import torch
import torch.nn.functional as F
from diffusers import AutoencoderTiny
from safetensors.torch import load_file
from torch import nn
from transformers import SiglipImageProcessor, SiglipVisionModel

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

class ResBlock(nn.Module):
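    """Residual conv block: two 3x3 convs with GroupNorm + SiLU and a 1x1 skip
    projection when the channel count changes."""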
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.skip = nn.Conv2d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        identity = self.skip(x)
        x = self.conv1(x)
        x = self.norm1(x)
        x = F.silu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = F.silu(x + identity)
        return x

class DiffusionFeatureExtractor2(nn.Module):
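    """Multi-scale feature pyramid over latent inputs (version 2).

    Five parallel conv paths produce feature maps at different scales relative
    to the input: 8x (64 ch), 4x (128 ch), 2x (256 ch), 1x (512 ch), and
    0.5x (512 ch).
    """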
    def __init__(self, in_channels=32):
        super().__init__()
        self.version = 2

        # 8x upsampling path, 64 channels
        self.up_path = nn.ModuleList([
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ResBlock(64, 64),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ResBlock(64, 64),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ResBlock(64, 64),
            nn.Conv2d(64, 64, 3, padding=1),
        ])

        # 4x upsampling path, 128 channels
        self.path2 = nn.ModuleList([
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ResBlock(128, 128),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ResBlock(128, 128),
            nn.Conv2d(128, 128, 3, padding=1),
        ])

        # 2x upsampling path, 256 channels
        self.path3 = nn.ModuleList([
            nn.Conv2d(in_channels, 256, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ResBlock(256, 256),
            nn.Conv2d(256, 256, 3, padding=1),
        ])

        # full (latent) resolution path, 512 channels
        self.path4 = nn.ModuleList([
            nn.Conv2d(in_channels, 512, 3, padding=1),
            ResBlock(512, 512),
            ResBlock(512, 512),
            nn.Conv2d(512, 512, 3, padding=1),
        ])

        # 0.5x downsampling path, 512 channels
        self.path5 = nn.ModuleList([
            nn.Conv2d(in_channels, 512, 3, padding=1),
            ResBlock(512, 512),
            nn.AvgPool2d(2),
            ResBlock(512, 512),
            nn.Conv2d(512, 512, 3, padding=1),
        ])

    def forward(self, x):
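        # run the input through each path independently and collect one
        # feature map per scale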
        outputs = []

        x1 = x
        for layer in self.up_path:
            x1 = layer(x1)
        outputs.append(x1)

        x2 = x
        for layer in self.path2:
            x2 = layer(x2)
        outputs.append(x2)

        x3 = x
        for layer in self.path3:
            x3 = layer(x3)
        outputs.append(x3)

        x4 = x
        for layer in self.path4:
            x4 = layer(x4)
        outputs.append(x4)

        x5 = x
        for layer in self.path5:
            x5 = layer(x5)
        outputs.append(x5)

        return outputs


class DFEBlock(nn.Module):
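    """Residual block used by the version 1 extractor: two 3x3 convs, GELU,
    and an additive skip connection."""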
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.GELU()

    def forward(self, x):
        x_in = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.act(x)
        x = x + x_in
        return x


class DiffusionFeatureExtractor(nn.Module):
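    """Version 1 feature extractor: a 1x1 conv in, a stack of DFEBlocks, and a
    1x1 conv out, all at 512 channels."""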
    def __init__(self, in_channels=32):
        super().__init__()
        self.version = 1
        num_blocks = 6
        self.conv_in = nn.Conv2d(in_channels, 512, 1)
        self.blocks = nn.ModuleList([DFEBlock(512) for _ in range(num_blocks)])
        self.conv_out = nn.Conv2d(512, 512, 1)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = self.conv_out(x)
        return x


class DiffusionFeatureExtractor3(nn.Module):
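    """Version 3 "extractor": a perceptual loss module.

    Decodes the predicted latents with a tiny VAE (taef1 by default), then
    compares the result against the target image using LPIPS (VGG) features
    and SigLIP embeddings.
    """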
    def __init__(self, device=torch.device("cuda"), dtype=torch.bfloat16, vae=None):
        super().__init__()
        self.version = 3
        if vae is None:
            vae = AutoencoderTiny.from_pretrained(
                "madebyollin/taef1", torch_dtype=torch.bfloat16)
        self.vae = vae

        image_encoder_path = "google/siglip2-so400m-patch16-512"
        try:
            self.image_processor = SiglipImageProcessor.from_pretrained(
                image_encoder_path)
        except EnvironmentError:
            self.image_processor = SiglipImageProcessor()
        self.vision_encoder = SiglipVisionModel.from_pretrained(
            image_encoder_path,
            ignore_mismatched_sizes=True,
        ).to(device, dtype=dtype)

        # LPIPS (VGG) perceptual network, kept in float32
        self.lpips_model = lpips.LPIPS(net='vgg').to(device, dtype=torch.float32)

        # running loss sums, averaged and printed every `log_every` steps
        self.losses = {}
        self.log_every = 100
        self.step = 0

    def get_siglip_features(self, tensors_0_1):
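        """Return the SigLIP last_hidden_state for a batch of images in [0, 1]."""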
        dtype = torch.bfloat16
        device = self.vae.device

        # resize to the SigLIP input resolution
        if 'height' in self.image_processor.size:
            size = self.image_processor.size['height']
        else:
            size = self.image_processor.crop_size['height']
        images = F.interpolate(tensors_0_1, size=(size, size),
                               mode='bicubic', align_corners=False)

        # normalize manually with the processor's mean/std (the processor
        # itself operates on PIL/numpy images and is not differentiable)
        mean = torch.tensor(self.image_processor.image_mean).to(
            device, dtype=dtype
        ).detach()
        std = torch.tensor(self.image_processor.image_std).to(
            device, dtype=dtype
        ).detach()

        clip_image = (images - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
        id_embeds = self.vision_encoder(
            clip_image,
            output_hidden_states=True,
        )

        last_hidden_state = id_embeds['last_hidden_state']
        return last_hidden_state

    def get_lpips_features(self, tensors_0_1):
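        """Return the list of normalized VGG feature maps that LPIPS compares,
        for a batch of images in [0, 1]."""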
        device = self.vae.device
        tensors_n1p1 = (tensors_0_1 * 2) - 1  # LPIPS expects inputs in [-1, 1]

        def extract(img):
            in0_input = self.lpips_model.scaling_layer(img)
            outs0 = self.lpips_model.net.forward(in0_input)

            feats_list = []
            for kk in range(self.lpips_model.L):
                feats_list.append(lpips.normalize_tensor(outs0[kk]))

            return feats_list

        lpips_feat_list = extract(tensors_n1p1.to(device, dtype=torch.float32))

        return lpips_feat_list

    def forward(
        self,
        noise,
        noise_pred,
        noisy_latents,
        timesteps,
        batch: DataLoaderBatchDTO,
        scheduler: CustomFlowMatchEulerDiscreteScheduler,
        lpips_weight=10.0,
        clip_weight=0.1,
        pixel_weight=0.1,
        model=None,
    ):
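        """Compute the perceptual loss (LPIPS + SigLIP) between the decoded
        one-step prediction and the target image in `batch`.

        Note: `pixel_weight` is currently unused.
        """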
        dtype = torch.bfloat16
        device = self.vae.device

        if model is not None and hasattr(model, 'get_stepped_pred'):
            # let the model wrapper compute the stepped latents if it knows how
            stepped_latents = model.get_stepped_pred(noise_pred, noise)
        else:
            # manually take one Euler step per sample all the way to the final
            # sigma, giving the model's one-step estimate of the clean latents
            bs = noise_pred.shape[0]
            noise_pred_chunks = torch.chunk(noise_pred, bs)
            timestep_chunks = torch.chunk(timesteps, bs)
            noisy_latent_chunks = torch.chunk(noisy_latents, bs)
            stepped_chunks = []
            for idx in range(bs):
                model_output = noise_pred_chunks[idx]
                timestep = timestep_chunks[idx]
                scheduler._step_index = None
                scheduler._init_step_index(timestep)
                sample = noisy_latent_chunks[idx].to(torch.float32)

                sigma = scheduler.sigmas[scheduler.step_index]
                sigma_next = scheduler.sigmas[-1]  # final sigma (~0)
                prev_sample = sample + (sigma_next - sigma) * model_output
                stepped_chunks.append(prev_sample)

            stepped_latents = torch.cat(stepped_chunks, dim=0)

        latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)

        # undo the latent scaling/shift before decoding
        latents = (
            latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
        tensors_n1p1 = self.vae.decode(latents).sample

        # map decoder output from [-1, 1] to [0, 1]
        pred_images = (tensors_n1p1 + 1) / 2

        lpips_feat_list_pred = self.get_lpips_features(pred_images.float())

        total_loss = 0

        with torch.no_grad():
            target_img = batch.tensor.to(device, dtype=dtype)

            # batch images are in [-1, 1]; map to [0, 1]
            target_img = (target_img + 1) / 2
            lpips_feat_list_target = self.get_lpips_features(target_img.float())
            if clip_weight > 0:
                target_clip_output = self.get_siglip_features(target_img).detach()

        if clip_weight > 0:
            pred_clip_output = self.get_siglip_features(pred_images)
            clip_loss = torch.nn.functional.mse_loss(
                pred_clip_output.float(), target_clip_output.float()
            ) * clip_weight

            if 'clip_loss' not in self.losses:
                self.losses['clip_loss'] = clip_loss.item()
            else:
                self.losses['clip_loss'] += clip_loss.item()

            total_loss += clip_loss

        skip_lpips_layers = []

        lpips_loss = 0
        for idx, lpips_feat in enumerate(lpips_feat_list_pred):
            if idx in skip_lpips_layers:
                continue
            layer_loss = torch.nn.functional.mse_loss(
                lpips_feat.float(), lpips_feat_list_target[idx].float()
            ) * lpips_weight
            lpips_loss += layer_loss

            # log the per-layer loss
            if f'lpips_loss_{idx}' not in self.losses:
                self.losses[f'lpips_loss_{idx}'] = layer_loss.item()
            else:
                self.losses[f'lpips_loss_{idx}'] += layer_loss.item()

        total_loss += lpips_loss

        # periodically print and reset the running loss averages
        if self.step % self.log_every == 0 and self.step > 0:
            print("DFE losses:")
            for key in self.losses:
                self.losses[key] /= self.log_every
                print(f" - {key}: {self.losses[key]:.3e}")
                self.losses[key] = 0.0

        self.step += 1

        return total_loss


def load_dfe(model_path, vae=None) -> DiffusionFeatureExtractor:
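    """Load a feature extractor.

    Passing "v3" builds a DiffusionFeatureExtractor3 (no weights file needed);
    otherwise the checkpoint at `model_path` is loaded into a version 1 or
    version 2 extractor, chosen from the keys of the state dict.
    """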
    if model_path == "v3":
        dfe = DiffusionFeatureExtractor3(vae=vae)
        dfe.eval()
        return dfe
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    if model_path.endswith('.safetensors'):
        state_dict = load_file(model_path)
    else:
        state_dict = torch.load(model_path, weights_only=True)
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

    # v1 checkpoints have a `conv_in` layer; v2 checkpoints do not
    if 'conv_in.weight' in state_dict:
        dfe = DiffusionFeatureExtractor()
    else:
        dfe = DiffusionFeatureExtractor2()

    dfe.load_state_dict(state_dict)
    dfe.eval()
    return dfe
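

# Example usage (a minimal sketch; assumes a flow-matching training loop that
# already has `noise`, `noise_pred`, `noisy_latents`, `timesteps`, a
# DataLoaderBatchDTO `batch`, and a CustomFlowMatchEulerDiscreteScheduler
# `scheduler` in scope):
#
#   dfe = load_dfe("v3")  # SigLIP + LPIPS perceptual loss, no checkpoint needed
#   loss = dfe(
#       noise, noise_pred, noisy_latents, timesteps,
#       batch=batch, scheduler=scheduler,
#       lpips_weight=10.0, clip_weight=0.1,
#   )
#   loss.backward()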