Spaces:
Running
on
Zero
Running
on
Zero
| import random | |
| import torch | |
| import copy | |
| import timm | |
| import torchvision.transforms.v2.functional | |
| from torch.nn import Parameter | |
| from src.utils.no_grad import no_grad | |
| from typing import Callable, Iterator, Tuple | |
| from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
| from torchvision.transforms import Normalize | |
| from src.diffusion.base.training import * | |
| from src.diffusion.base.scheduling import BaseScheduler | |
| from src.diffusion.base.sampling import BaseSampler | |
def inverse_sigma(alpha, sigma):
    """Return the loss weight ``1 / sigma**2``.

    Args:
        alpha: Accepted so every weighting function shares the same
            ``(alpha, sigma)`` signature; not used here.
        sigma: Value (number or tensor) whose inverse square is returned.
    """
    sigma_squared = sigma ** 2
    return 1 / sigma_squared
def snr(alpha, sigma):
    """Return the loss weight ``alpha / sigma``.

    Args:
        alpha: Numerator of the ratio (number or tensor).
        sigma: Denominator of the ratio (number or tensor).
    """
    ratio = alpha / sigma
    return ratio
def minsnr(alpha, sigma, threshold=5):
    """Return ``alpha / sigma`` with every value below ``threshold`` raised to it.

    Args:
        alpha: Tensor numerator of the ratio.
        sigma: Tensor denominator of the ratio.
        threshold: Lower bound applied element-wise to the ratio.

    NOTE(review): despite the name, ``threshold`` acts as a floor here, not
    a cap -- confirm this is the intended weighting.
    """
    ratio = alpha / sigma
    return torch.clamp(ratio, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    """Return ``alpha / sigma`` with every value above ``threshold`` lowered to it.

    Args:
        alpha: Tensor numerator of the ratio.
        sigma: Tensor denominator of the ratio.
        threshold: Upper bound applied element-wise to the ratio.

    NOTE(review): despite the name, ``threshold`` acts as a cap here, not
    a floor -- confirm this is the intended weighting.
    """
    ratio = alpha / sigma
    return torch.clamp(ratio, max=threshold)
def constant(alpha, sigma):
    """Return the uniform loss weight ``1``.

    Both arguments are ignored; they exist so this function is
    interchangeable with the other ``(alpha, sigma)`` weighting functions.
    """
    uniform_weight = 1
    return uniform_weight
| from PIL import Image | |
| import numpy as np | |
def time_shift_fn(t, timeshift=1.0):
    """Warp a time value with ``t / (t + (1 - t) * timeshift)``.

    With ``timeshift=1.0`` the denominator is ``1`` and ``t`` is returned
    unchanged.

    Args:
        t: Time value(s) to warp (number or tensor).
        timeshift: Shift factor applied to the ``(1 - t)`` term.
    """
    denominator = t + (1 - t) * timeshift
    return t / denominator
def random_crop(images, resize, crop_size):
    """Resize a batch of images, then cut the same random window from all of them.

    Args:
        images: Image batch indexed as ``[:, :, row, col]`` (four dimensions).
        resize: Forwarded as ``size`` to torchvision's ``resize`` (antialiased).
        crop_size: ``(height, width)`` of the window to extract.

    Returns:
        The cropped batch; one random offset pair is shared by every image.

    Raises:
        ValueError: From ``random.randint`` when the resized images are
            smaller than ``crop_size`` along either axis.
    """
    resized = torchvision.transforms.v2.functional.resize(images, size=resize, antialias=True)
    crop_h, crop_w = crop_size
    # Draw the row offset before the column offset (keeps the RNG sequence stable).
    top = random.randint(0, resized.shape[2] - crop_h)
    left = random.randint(0, resized.shape[3] - crop_w)
    rows = slice(top, top + crop_h)
    cols = slice(left, left + crop_w)
    return resized[:, :, rows, cols]
| # class EulerSolver: | |
| # def __init__( | |
| # self, | |
| # num_steps: int, | |
| # *args, | |
| # **kwargs | |
| # ): | |
| # super().__init__(*args, **kwargs) | |
| # self.num_steps = num_steps | |
| # self.timesteps = torch.linspace(0.0, 1, self.num_steps+1, dtype=torch.float32) | |
| # | |
| # def __call__(self, net, noise, timeshift, condition): | |
| # steps = time_shift_fn(self.timesteps[:, None], timeshift[None, :]).to(noise.device, noise.dtype) | |
| # x = noise | |
| # trajs = [x, ] | |
| # for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): | |
| # dt = t_next - t_cur | |
| # v = net(x, t_cur, condition) | |
| # x = x + v*dt[:, None, None, None] | |
| # x = x.to(noise.dtype) | |
| # trajs.append(x) | |
| # return trajs | |
| # | |
| # class NeuralSolver(nn.Module): | |
| # def __init__( | |
| # self, | |
| # num_steps: int, | |
| # *args, | |
| # **kwargs | |
| # ): | |
| # super().__init__(*args, **kwargs) | |
| # self.num_steps = num_steps | |
| # self.timedeltas = torch.nn.Parameter(torch.ones((num_steps))/num_steps, requires_grad=True) | |
| # self.coeffs = torch.nn.Parameter(torch.zeros((num_steps, num_steps)), requires_grad=True) | |
| # # self.golden_noise = torch.nn.Parameter(torch.randn((1, 3, 1024, 1024))*0.01, requires_grad=True) | |
| # | |
| # def forward(self, net, noise, timeshift, condition): | |
| # batch_size, c, height, width = noise.shape | |
| # # golden_noise = torch.nn.functional.interpolate(self.golden_noise, size=(height, width), mode='bicubic', align_corners=False) | |
| # x = noise # + golden_noise.repeat(batch_size, 1, 1, 1) | |
| # x_trajs = [x, ] | |
| # v_trajs = [] | |
| # dts = self.timedeltas.softmax(dim=0) | |
| # print(dts) | |
| # coeffs = self.coeffs | |
| # t_cur = torch.zeros((batch_size,), dtype=noise.dtype, device=noise.device) | |
| # for i, dt in enumerate(dts): | |
| # pred_v = net(x, t_cur, condition) | |
| # v = torch.zeros_like(pred_v) | |
| # v_trajs.append(pred_v) | |
| # acc_coeffs = 0.0 | |
| # for j in range(i): | |
| # acc_coeffs = acc_coeffs + coeffs[i, j] | |
| # v = v + coeffs[i, j]*v_trajs[j] | |
| # v = v + (1-acc_coeffs)*v_trajs[i] | |
| # x = x + v*dt | |
| # x = x.to(noise.dtype) | |
| # x_trajs.append(x) | |
| # t_cur = t_cur + dt | |
| # return x_trajs | |
| import re | |
| import os | |
| import unicodedata | |
def clean_filename(s, max_length=200):
    """Turn arbitrary text into a lowercase ASCII file-name stem.

    Steps, in order: trim surrounding whitespace and dots, fold to ASCII
    (NFKD, dropping anything that has no ASCII form), replace ``/`` and
    control characters with ``_``, collapse runs of ``_``, lowercase,
    truncate, and trim once more.

    Args:
        s: Text to sanitise, e.g. a prompt.
        max_length: Maximum number of characters kept in the result.

    Returns:
        The sanitised stem, or ``'untitled'`` when nothing usable remains.
    """
    # Trim surrounding whitespace first, then surrounding dots.
    s = s.strip().strip('.')
    # Fold Unicode to ASCII; characters without an ASCII form are dropped.
    s = unicodedata.normalize('NFKD', s).encode('ASCII', 'ignore').decode('ASCII')
    # Path separators and control characters (NUL, newline, ...) are unsafe
    # in a file name, so replace each with an underscore.
    s = re.sub(r'[/\x00-\x1f\x7f]', '_', s)
    # Collapse consecutive underscores.
    s = re.sub(r'_{2,}', '_', s)
    s = s.lower()
    # Cap the length of the stem.
    s = s[:max_length]
    # ASCII folding and truncation can expose new leading/trailing blanks or
    # dots (e.g. "…" decomposes to "..."), so trim again before the check.
    s = s.strip().strip('.')
    return s or 'untitled'
def prompt_augment(prompts, random_prompts, replace_prob=0.5, front_append_prob=0.5, back_append_prob=0.5, delete_prob=0.5,):
    """Randomly perturb each prompt using a distractor drawn from a pool.

    One distractor is sampled (with replacement) per prompt. Each of the
    following edits is then applied independently with its own probability:
    replace the prompt with the distractor, prepend the distractor, append
    the distractor, and keep only a random non-empty prefix of the
    comma-separated segments.

    Args:
        prompts: Sequence of prompt strings to perturb.
        random_prompts: Non-empty pool of strings to draw distractors from.
        replace_prob: Probability of swapping the prompt for the distractor.
        front_append_prob: Probability of prepending ``distractor + ", "``.
        back_append_prob: Probability of appending ``", " + distractor``.
        delete_prob: Probability of truncating to a random segment prefix.

    Returns:
        A new list of perturbed prompts, one per input prompt.
    """
    distractors = random.choices(random_prompts, k=len(prompts))
    augmented = []
    for prompt, distractor in zip(prompts, distractors):
        candidate = distractor if random.random() < replace_prob else prompt
        if random.random() < front_append_prob:
            candidate = distractor + ", " + candidate
        if random.random() < back_append_prob:
            candidate = candidate + ", " + distractor
        if random.random() < delete_prob:
            # Strip each segment: splitting on "," leaves the space that
            # followed the comma, and re-joining with ", " would double it.
            segments = [segment.strip() for segment in candidate.split(",")]
            keep = random.randint(1, len(segments))
            candidate = ", ".join(segments[:keep])
        augmented.append(candidate)
    return augmented
class AdvODETrainer(BaseTrainer):
    """Trainer mixing trajectory matching against an EMA model with a GAN-style loss.

    Each step runs the solver with both the trained network and the EMA
    network from the same noise, penalises the L1 gap between their
    trajectories, and scores the final sample with ``adv_head`` on features
    from ``im_encoder`` / ``mm_encoder``.

    Args:
        scheduler: Stored on the instance; not read by the code in this class.
        loss_weight_fn: Stored on the instance; not read by the code in this class.
        adv_loss_weight: Weight of the generator-side adversarial term.
        gan_loss_weight: Weight of the discriminator (``adv_head``) term.
        im_encoder: Image feature extractor, called as ``im_encoder(images, resize=False)``.
        mm_encoder: Image/text feature extractor, called as
            ``mm_encoder(images, prompts, resize=True)``.
        adv_head: Discriminator head, called as ``adv_head(im_features, mm_features)``.
            Its output is fed to ``torch.log`` directly, so it is presumably a
            probability in (0, 1) -- confirm against the head's definition.
        random_crop_size: Side length of the square crop fed to the encoders.
        max_image_size: ``size`` used when resizing images before cropping.
    """

    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            adv_loss_weight: float = 0.5,
            gan_loss_weight: float = 0.5,
            im_encoder: nn.Module = None,
            mm_encoder: nn.Module = None,
            adv_head: nn.Module = None,
            random_crop_size=448,
            max_image_size=512,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.adv_loss_weight = adv_loss_weight
        self.gan_loss_weight = gan_loss_weight

        self.im_encoder = im_encoder
        self.mm_encoder = mm_encoder
        self.adv_head = adv_head

        # Recent (image features, condition features) pairs replayed to the
        # discriminator; trimmed to the newest 10 entries every step.
        self.real_buffer = []
        self.fake_buffer = []

        self.random_crop_size = random_crop_size
        self.max_image_size = max_image_size

        # Project helper; presumably freezes the encoders -- confirm in
        # src.utils.no_grad.
        no_grad(self.im_encoder)
        no_grad(self.mm_encoder)

        # Pool of recently seen prompts used to build mismatched captions.
        self.random_prompts = ["hahahaha", ]
        # Preview files written so far, oldest first.
        self.saved_filenames = []

    def preproprocess(self, x, condition, uncondition, metadata):
        """Remember ``uncondition`` for the train step, then defer to the base class."""
        self.uncondition = uncondition
        return super().preproprocess(x, condition, uncondition, metadata)

    def _save_preview(self, image, caption):
        """Write one generated image as PNG in the working directory and prune old previews.

        ``image`` is permuted with ``(1, 2, 0)`` and scaled by 255 before the
        uint8 cast, so it is expected to be channel-first with values in [0, 1].
        """
        filename = clean_filename(caption) + ".png"
        pixels = (image.permute(1, 2, 0).detach().cpu().float() * 255).to(torch.uint8).numpy()
        Image.fromarray(pixels).save(filename)

        # A repeated caption overwrites its earlier file. Keep a single entry
        # per file so pruning never tries to delete the same path twice.
        if filename in self.saved_filenames:
            self.saved_filenames.remove(filename)
        self.saved_filenames.append(filename)

        if len(self.saved_filenames) > 100:
            stale = self.saved_filenames.pop(0)
            try:
                os.remove(stale)
            except FileNotFoundError:
                # Already removed by something else; nothing left to clean up.
                pass

    def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None):
        """Run one training step and return the loss terms.

        Args:
            net: Network being trained.
            ema_net: Reference network; its trajectory is computed without gradients.
            solver: Callable invoked as ``solver(model, noise, y, uncondition,
                return_x_trajs=True, return_v_trajs=False)``; the second
                element of its result is used as the list of trajectory states.
            x: 4-D tensor; only its shape, dtype and device are used (to draw noise).
            y: Condition forwarded to the solver.
            metadata: Mapping with ``"prompt"`` (indexable collection of
                strings, presumably one per sample) and ``"raw_image"``
                (real images). Required despite the ``None`` default.

        Returns:
            dict with ``trajs_loss``, ``adv_loss``, ``gan_loss``, ``acc_real``,
            ``acc_fake`` and the combined ``loss``.
        """
        # The values are unused; the unpacking documents/enforces a 4-D input.
        batch_size, c, height, width = x.shape
        noise = torch.randn_like(x)
        _, trajs = solver(net, noise, y, self.uncondition, return_x_trajs=True, return_v_trajs=False)
        with torch.no_grad():
            _, ref_trajs = solver(ema_net, noise, y, self.uncondition, return_x_trajs=True, return_v_trajs=False)

        # Map the final state from [-1, 1] to [0, 1].
        fake_x0 = (trajs[-1] + 1) / 2
        fake_x0 = fake_x0.clamp(0, 1)

        prompt = metadata["prompt"]
        self.random_prompts.extend(prompt)
        self.random_prompts = self.random_prompts[-50:]

        # Dump the first generated sample of the batch for visual inspection.
        self._save_preview(fake_x0[0], prompt[0])

        real_x0 = metadata["raw_image"]
        crop_hw = (self.random_crop_size, self.random_crop_size)
        fake_x0 = random_crop(fake_x0, resize=self.max_image_size, crop_size=crop_hw)
        real_x0 = random_crop(real_x0, resize=self.max_image_size, crop_size=crop_hw)

        fake_im_features = self.im_encoder(fake_x0, resize=False)
        fake_mm_features = self.mm_encoder(fake_x0, prompt, resize=True)
        fake_im_features_detach = fake_im_features.detach()
        fake_mm_features_detach = fake_mm_features.detach()

        with torch.no_grad():
            real_im_features = self.im_encoder(real_x0, resize=False)
            real_mm_features = self.mm_encoder(real_x0, prompt, resize=True)
            # Real images paired with perturbed captions act as extra negatives.
            not_match_prompt = prompt_augment(prompt, self.random_prompts)
            real_not_match_mm_features = self.mm_encoder(real_x0, not_match_prompt, resize=True)

        self.real_buffer.append((real_im_features, real_mm_features))
        self.fake_buffer.append((fake_im_features_detach, fake_mm_features_detach))
        self.fake_buffer.append((real_im_features, real_not_match_mm_features))
        # Keep only the 10 newest entries (no-op while a buffer is shorter).
        del self.real_buffer[:-10]
        del self.fake_buffer[:-10]

        real_features_gan = torch.cat([pair[0] for pair in self.real_buffer], dim=0)
        real_conditions_gan = torch.cat([pair[1] for pair in self.real_buffer], dim=0)
        fake_features_gan = torch.cat([pair[0] for pair in self.fake_buffer], dim=0)
        fake_conditions_gan = torch.cat([pair[1] for pair in self.fake_buffer], dim=0)

        real_score_gan = self.adv_head(real_features_gan, real_conditions_gan)
        fake_score_gan = self.adv_head(fake_features_gan, fake_conditions_gan)

        fake_score_adv = self.adv_head(fake_im_features, fake_mm_features)
        fake_score_detach_adv = self.adv_head(fake_im_features_detach, fake_mm_features_detach)

        # Discriminator objective on the (detached) buffered features.
        loss_gan = -torch.log(1 - fake_score_gan).mean() - torch.log(real_score_gan).mean()
        acc_real = (real_score_gan > 0.5).float()
        acc_fake = (fake_score_gan < 0.5).float()

        # Generator objective. The "hack" term evaluates the same head on the
        # detached features with the opposite sign: assuming adv_head is
        # deterministic, it cancels loss_adv in value and cancels its gradient
        # w.r.t. the head's parameters, so only the feature path (and thus the
        # generator) receives gradient from the adversarial term.
        loss_adv = -torch.log(fake_score_adv)
        loss_adv_hack = torch.log(fake_score_detach_adv)

        # Mean L1 distance between the trained and EMA trajectories.
        trajs_loss = 0.0
        for x_t, ref_x_t in zip(trajs, ref_trajs):
            trajs_loss = trajs_loss + torch.abs(x_t - ref_x_t).mean()
        trajs_loss = trajs_loss / len(trajs)

        out = dict(
            trajs_loss=trajs_loss.mean(),
            adv_loss=loss_adv.mean(),
            gan_loss=loss_gan.mean(),
            acc_real=acc_real.mean(),
            acc_fake=acc_fake.mean(),
            loss=trajs_loss.mean() + self.adv_loss_weight*(loss_adv.mean() + loss_adv_hack.mean())+self.gan_loss_weight*loss_gan.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Serialise only ``adv_head``, under the ``"adv_head."`` key prefix.

        Returns the dictionary produced by ``adv_head.state_dict`` so that a
        call without ``destination`` yields the entries instead of ``None``.
        """
        return self.adv_head.state_dict(
            destination=destination,
            prefix=prefix + "adv_head.",
            keep_vars=keep_vars)