import os
from pathlib import Path
import warnings

from PIL import Image
from diffusers.pipelines.qwenimage.pipeline_qwenimage import QwenImagePipeline
import lpips
import torch
from safetensors.torch import load_file, save_model
import torch.nn.functional as F
import torchvision.transforms.v2.functional as TF
from einops import rearrange

from qwenimage.datamodels import QwenConfig, QwenInputs
from qwenimage.debug import clear_cuda_memory, ctimed, ftimed, print_gpu_memory, texam
from qwenimage.experiments.quantize_text_encoder_experiments import quantize_text_encoder_int4wo_linear
from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8darow_nolast
from qwenimage.loss import LossAccumulator
from qwenimage.models.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, QwenImageEditPlusPipeline, calculate_dimensions
from qwenimage.models.pipeline_qwenimage_edit_save_interm import QwenImageEditSaveIntermPipeline
from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.optimization import simple_quantize_model
from qwenimage.sampling import TimestepDistUtils

from wandml import WandModel
from wandml.core.logger import wand_logger
from wandml.trainers.experiment_trainer import ExperimentTrainer


class QwenImageFoundation(WandModel):
    SOURCE = "Qwen/Qwen-Image-Edit-2509"
    INPUT_MODEL = QwenInputs
    CACHE_DIR = "qwen_image_edit_2509"
    PIPELINE = QwenImageEditPlusPipeline
    serialize_modules = ["transformer"]

    def __init__(self, config: QwenConfig, device=None):
        super().__init__()
        self.config: QwenConfig = config
        self.dtype = torch.bfloat16
        if device is None:
            default_device = "cuda" if torch.cuda.is_available() else "cpu"
            self.device = default_device
        else:
            self.device = device
        print(f"{self.device=}")

        pipe = self.PIPELINE.from_pretrained(
            "Qwen/Qwen-Image-Edit-2509",
            transformer=QwenImageTransformer2DModel.from_pretrained(
                "Qwen/Qwen-Image-Edit-2509",
                subfolder='transformer',
                torch_dtype=self.dtype,
                device_map=self.device,
            ),
            torch_dtype=self.dtype,
        )
        pipe = pipe.to(device=self.device, dtype=self.dtype)

        if config.load_multi_view_lora:
            pipe.load_lora_weights(
                "dx8152/Qwen-Edit-2509-Multiple-angles",
                weight_name="镜头转换.safetensors", adapter_name="angles"
            )
            pipe.set_adapters(["angles"], adapter_weights=[1.])
            pipe.fuse_lora(adapter_names=["angles"], lora_scale=1.25)
            pipe.unload_lora_weights()
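            # Note (assumed intent): fuse_lora bakes the "angles" adapter into the base
            # weights at 1.25x strength, and unload_lora_weights then drops the adapter
            # modules so only the fused weights remain. The weight file name "镜头转换"
            # roughly translates to "camera angle change".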
        self.pipe = pipe
        self.vae = self.pipe.vae
        self.transformer = self.pipe.transformer
        self.text_encoder = self.pipe.text_encoder
        self.scheduler = self.pipe.scheduler

        self.vae.to(self.dtype)
        self.vae.eval()
        self.vae.requires_grad_(False)
        self.text_encoder.eval()
        self.text_encoder.requires_grad_(False)
        self.text_encoder_device = None
        self.transformer.eval()
        self.transformer.requires_grad_(False)

        if self.config.gradient_checkpointing:
            self.transformer.enable_gradient_checkpointing()
        if self.config.vae_tiling:
            self.vae.enable_tiling(576, 576, 512, 512)

        self.timestep_dist_utils = TimestepDistUtils(
            min_seq_len=self.scheduler.config.base_image_seq_len,
            max_seq_len=self.scheduler.config.max_image_seq_len,
            min_mu=self.scheduler.config.base_shift,
            max_mu=self.scheduler.config.max_shift,
            train_dist=self.config.train_dist,
            train_shift=self.config.train_shift,
            inference_dist=self.config.inference_dist,
            inference_shift=self.config.inference_shift,
            static_mu=self.config.static_mu,
            loss_weight_dist=self.config.loss_weight_dist,
        )
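        # TimestepDistUtils is built from the scheduler's shift config (base/max shift
        # over base/max image sequence length), presumably to interpolate mu by sequence
        # length in the usual Flux-style dynamic shifting, and it also provides the
        # train/inference timestep distributions and optional per-timestep loss weights.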
        if self.config.quantize_text_encoder:
            quantize_text_encoder_int4wo_linear(self.text_encoder)
        if self.config.quantize_transformer:
            quantize_transformer_fp8darow_nolast(self.transformer)

    def load(self, load_path):
        if not isinstance(load_path, Path):
            load_path = Path(load_path)
        if not load_path.is_dir():
            raise ValueError(f"Expected {load_path=} to be a directory")
        for module_name in self.serialize_modules:
            model_state_dict = load_file(load_path / f"{module_name}.safetensors")
            missing, unexpected = getattr(self, module_name).load_state_dict(model_state_dict, strict=False, assign=True)
            if missing:
                warnings.warn(f"{module_name} missing {missing}")
            if unexpected:
                warnings.warn(f"{module_name} unexpected {unexpected}")

    def save(self, save_path, skip=False):
        if skip:
            return
        if not isinstance(save_path, Path):
            save_path = Path(save_path)
        # Only reject paths that exist and are not directories; otherwise create the dir.
        if save_path.exists() and not save_path.is_dir():
            raise ValueError(f"Expected {save_path=} to be a directory")
        save_path.mkdir(parents=True, exist_ok=True)
        for module_name in self.serialize_modules:
            save_model(getattr(self, module_name), save_path / f"{module_name}.safetensors")
            print(f"Saved {module_name} to {save_path}")

    def get_train_params(self):
        return [{"params": [p for p in self.transformer.parameters() if p.requires_grad]}]

    def pil_to_latents(self, images):
        image = self.pipe.image_processor.preprocess(images)
        h, w = image.shape[-2:]
        h_r, w_r = calculate_dimensions(self.config.vae_image_size, h / w)
        image = TF.resize(image, (h_r, w_r))
        print("pil_to_latents.image")
        texam(image)
        image = image.unsqueeze(2)  # N, C, F=1, H, W
        image = image.to(device=self.device, dtype=self.dtype)
        latents = self.pipe.vae.encode(image).latent_dist.mode()  # argmax
        latents_mean = (
            torch.tensor(self.pipe.vae.config.latents_mean)
            .view(1, self.pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = (
            torch.tensor(self.pipe.vae.config.latents_std)
            .view(1, self.pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents = (latents - latents_mean) / latents_std
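        # Latents are standardized per channel using the VAE's configured
        # latents_mean/latents_std; latents_to_pil applies the inverse
        # (latents * std + mean) before decoding.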
        latents = latents.squeeze(2)
        print("pil_to_latents.latents")
        texam(latents)
        return latents.to(dtype=self.dtype)

    def latents_to_pil(self, latents, h=None, w=None, with_grad=False):
        if not with_grad:
            latents = latents.clone().detach()
        if latents.dim() == 3:  # 1d latent
            if h is None or w is None:
                raise ValueError(f"auto unpack needs h,w, got {h=}, {w=}")
            latents = self.unpack_latents(latents, h=h, w=w)
        latents = latents.unsqueeze(2)
        latents = latents.to(self.dtype)
        latents_mean = (
            torch.tensor(self.pipe.vae.config.latents_mean)
            .view(1, self.pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = (
            torch.tensor(self.pipe.vae.config.latents_std)
            .view(1, self.pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents = latents * latents_std + latents_mean
        latents = latents.to(device=self.device, dtype=self.dtype)
        image = self.pipe.vae.decode(latents, return_dict=False)[0][:, :, 0]  # F = 1
        if with_grad:
            texam(image, "latents_to_pil.image")
            return image
        image = self.pipe.image_processor.postprocess(image)
        return image
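
    # pack_latents/unpack_latents convert between the 2D latent grid (B, C, H, W)
    # and the token sequence the transformer consumes (B, H/2 * W/2, C * 4) by
    # folding each 2x2 spatial patch into the channel dimension.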
    @staticmethod
    def pack_latents(latents):
        packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        return packed

    @staticmethod
    def unpack_latents(packed, h, w):
        latents = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=h, w=w)
        return latents

    def offload_text_encoder(self, device: str | torch.device):
        if self.text_encoder_device == device:
            return
        print(f"Moving text encoder to {device}")
        self.text_encoder_device = device
        self.text_encoder.to(device)
        if device == "cpu" or device == torch.device("cpu"):
            print_gpu_memory(clear_mem="pre")

    def preprocess_batch(self, batch):
        prompts = batch["text"]
        references = batch["reference"]
        h, w = references.shape[-2:]
        h_r, w_r = calculate_dimensions(CONDITION_IMAGE_SIZE, h / w)
        references = TF.resize(references, (h_r, w_r))
        print("preprocess_batch.references")
        texam(references)
        self.offload_text_encoder("cuda")
        with torch.no_grad():
            prompt_embeds, prompt_embeds_mask = self.pipe.encode_prompt(
                prompts,
                references.mul(255),  # scaled to RGB
                device="cuda",
                max_sequence_length=self.config.train_max_sequence_length,
            )
        prompt_embeds = prompt_embeds.cpu().clone().detach()
        prompt_embeds_mask = prompt_embeds_mask.cpu().clone().detach()
        batch["prompt_embeds"] = (prompt_embeds, prompt_embeds_mask)
        batch["reference"] = batch["reference"].cpu()
        batch["image"] = batch["image"].cpu()
        return batch

    def single_step(self, batch) -> torch.Tensor:
        self.offload_text_encoder("cpu")
        if "prompt_embeds" not in batch:
            batch = self.preprocess_batch(batch)
        prompt_embeds, prompt_embeds_mask = batch["prompt_embeds"]
        prompt_embeds = prompt_embeds.to(device=self.device)
        prompt_embeds_mask = prompt_embeds_mask.to(device=self.device)
        images = batch["image"].to(device=self.device, dtype=self.dtype)
        x_0 = self.pil_to_latents(images).to(device=self.device, dtype=self.dtype)
        x_1 = torch.randn_like(x_0).to(device=self.device, dtype=self.dtype)
        seq_len = self.timestep_dist_utils.get_seq_len(x_0)
        batch_size = x_0.shape[0]
        t = self.timestep_dist_utils.get_train_t([batch_size], seq_len=seq_len).to(device=self.device, dtype=self.dtype)
        x_t = (1.0 - t) * x_0 + t * x_1
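        # Rectified-flow interpolation: x_t = (1 - t) * x_0 + t * x_1, so the target
        # velocity is v = x_1 - x_0 (see v_gt_2d below). t is expected to broadcast
        # over the latent dims; this codepath asserts a reference batch of 1.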
        x_t_1d = self.pack_latents(x_t)
        references = batch["reference"].to(device=self.device, dtype=self.dtype)
        print("references")
        texam(references)
        assert references.shape[0] == 1
        refs = self.pil_to_latents(references).to(device=self.device, dtype=self.dtype)
        refs_1d = self.pack_latents(refs)
        print("refs refs_1d")
        texam(refs)
        texam(refs_1d)
        inp_1d = torch.cat([x_t_1d, refs_1d], dim=1)
        print("inp_1d")
        texam(inp_1d)
        l_height, l_width = x_0.shape[-2:]
        ref_height, ref_width = refs.shape[-2:]
        img_shapes = [
            [
                (1, l_height // 2, l_width // 2),
                (1, ref_height // 2, ref_width // 2),
            ]
        ] * batch_size
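        # img_shapes lists, per sample, the token-grid shape of each image stream:
        # (frames=1, H/2, W/2) for the noisy target latents and for the reference
        # latents, matching the 2x2 packing above; the transformer's rotary embedding
        # consumes these together with the text sequence lengths.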
        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
        image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=x_0.device)
        v_pred_1d = self.transformer(
            hidden_states=inp_1d,
            encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_mask=prompt_embeds_mask,
            timestep=t,
            image_rotary_emb=image_rotary_emb,
            return_dict=False,
        )[0]
        v_pred_1d = v_pred_1d[:, : x_t_1d.size(1)]
        v_pred_2d = self.unpack_latents(v_pred_1d, h=l_height // 2, w=l_width // 2)
        v_gt_2d = x_1 - x_0
        if self.config.loss_weight_dist is not None:
            loss = F.mse_loss(v_pred_2d, v_gt_2d, reduction="none").mean(dim=[1, 2, 3])
            weights = self.timestep_dist_utils.get_loss_weighting(t)
            loss = torch.mean(loss * weights)
        else:
            loss = F.mse_loss(v_pred_2d, v_gt_2d, reduction="mean")
        return loss

    def base_pipe(self, inputs: QwenInputs) -> list[Image.Image]:
        print(inputs)
        self.offload_text_encoder("cuda")
        if inputs.vae_image_override is None:
            inputs.vae_image_override = self.config.vae_image_size
        if inputs.latent_size_override is None:
            inputs.latent_size_override = self.config.vae_image_size
        return self.pipe(**inputs.model_dump()).images


class QwenImageFoundationSaveInterm(QwenImageFoundation):
    PIPELINE = QwenImageEditSaveIntermPipeline


class QwenImageRegressionFoundation(QwenImageFoundation):
    def __init__(self, config: QwenConfig, device=None):
        super().__init__(config, device=device)
        self.lpips_fn = lpips.LPIPS(net='vgg').to(device=self.device)

    def preprocess_batch(self, batch):
        return batch

    def single_step(self, batch) -> torch.Tensor:
        self.offload_text_encoder("cpu")
        out_dict = batch["data"]
        assert len(out_dict) == 1
        out_dict = out_dict[0]
        prompt_embeds = out_dict["prompt_embeds"]
        prompt_embeds_mask = out_dict["prompt_embeds_mask"]
        prompt_embeds = prompt_embeds.to(device=self.device, dtype=self.dtype)
        prompt_embeds_mask = prompt_embeds_mask.to(device=self.device, dtype=self.dtype)
        h_f16 = out_dict["height"] // 16
        w_f16 = out_dict["width"] // 16
        refs_1d = out_dict["image_latents"].to(device=self.device, dtype=self.dtype)
        t = out_dict["t"].to(device=self.device, dtype=self.dtype)
        x_0_1d = out_dict["output"].to(device=self.device, dtype=self.dtype)
        x_t_1d = out_dict["latents_start"].to(device=self.device, dtype=self.dtype)
        v_neg_1d = out_dict["noise_pred"].to(device=self.device, dtype=self.dtype)
        v_gt_1d = (x_t_1d - x_0_1d) / t
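        # Since x_t = x_0 + t * v under the flow parameterization, the ground-truth
        # velocity recovered from the cached rollout is v_gt = (x_t - x_0) / t;
        # v_neg_1d is the base model's own prediction, used as a negative below.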
        inp_1d = torch.cat([x_t_1d, refs_1d], dim=1)
        print("inp_1d")
        texam(inp_1d)
        img_shapes = out_dict["img_shapes"]
        txt_seq_lens = out_dict["txt_seq_lens"]
        image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=self.device)
        v_pred_1d = self.transformer(
            hidden_states=inp_1d,
            encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_mask=prompt_embeds_mask,
            timestep=t,
            image_rotary_emb=image_rotary_emb,
            return_dict=False,
        )[0]
        v_pred_1d = v_pred_1d[:, : x_t_1d.size(1)]
        split = batch["split"]
        step = batch["step"]
        if split == "train":
            loss_terms = self.config.train_loss_terms
        elif split == "validation":
            loss_terms = self.config.validation_loss_terms
        else:
            raise ValueError(f"Unknown {split=}")
        loss_accumulator = LossAccumulator(
            terms=loss_terms.model_dump(),
            step=step,
            split=split,
            term_groups={"pixel": loss_terms.pixel_terms},
        )

        if loss_accumulator.has("mse"):
            if self.config.loss_weight_dist is not None:
                # v_pred_1d/v_gt_1d are packed (B, L, C) tensors, so reduce over the
                # token and channel dims before applying the per-sample timestep weights.
                mse_loss = F.mse_loss(v_pred_1d, v_gt_1d, reduction="none").mean(dim=[1, 2])
                weights = self.timestep_dist_utils.get_loss_weighting(t)
                mse_loss = torch.mean(mse_loss * weights)
            else:
                mse_loss = F.mse_loss(v_pred_1d, v_gt_1d, reduction="mean")
            loss_accumulator.accum("mse", mse_loss)

        if loss_accumulator.has("triplet"):
            # 1d, B,L,C
            margin = loss_terms.triplet_margin
            triplet_min_abs_diff = loss_terms.triplet_min_abs_diff
            print(f"{triplet_min_abs_diff=}")
            v_gt_neg_diff = (v_gt_1d - v_neg_1d).abs().mean(dim=2)
            zero_weight = torch.zeros_like(v_gt_neg_diff)
            v_weight = torch.where(v_gt_neg_diff > triplet_min_abs_diff, v_gt_neg_diff, zero_weight)
            ones = torch.ones_like(v_gt_neg_diff)
            filtered_nums = torch.sum(torch.where(v_gt_neg_diff > triplet_min_abs_diff, ones, zero_weight))
            wand_logger.log({
                "filtered_nums": filtered_nums,
            }, commit=False)
            diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2)
            diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2)
            per_tok_diff = (diffv_gt_pred - diffv_neg_pred).sum(dim=2)
            triplet_loss = torch.mean(F.relu((per_tok_diff + margin) * v_weight))
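            # Triplet-style hinge: per token, penalize whenever the prediction is not
            # closer to v_gt than to v_neg by at least `margin`, weighted by the
            # per-token gap |v_gt - v_neg| and zeroed where that gap falls below
            # triplet_min_abs_diff.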
            ones = torch.ones_like(per_tok_diff)
            zeros = torch.zeros_like(per_tok_diff)
            loss_nonzero_nums = torch.sum(torch.where(((per_tok_diff + margin) * v_weight) > 0, ones, zeros))
            wand_logger.log({
                "loss_nonzero_nums": loss_nonzero_nums,
            }, commit=False)
            loss_accumulator.accum("triplet", triplet_loss)
            texam(v_gt_neg_diff, "v_gt_neg_diff")
            texam(v_weight, "v_weight")
            texam(diffv_gt_pred, "diffv_gt_pred")
            texam(diffv_neg_pred, "diffv_neg_pred")
            texam(per_tok_diff, "per_tok_diff")

        if loss_accumulator.has("negative_mse"):
            neg_mse_loss = -F.mse_loss(v_pred_1d, v_neg_1d, reduction="mean")
            loss_accumulator.accum("negative_mse", neg_mse_loss)

        if loss_accumulator.has("distribution_matching"):
            dm_v = (v_pred_1d - v_neg_1d + v_gt_1d).detach()
            dm_mse = F.mse_loss(v_pred_1d, dm_v, reduction="mean")
            loss_accumulator.accum("distribution_matching", dm_mse)

        if loss_accumulator.has("negative_exponential"):
            raise NotImplementedError()

        if loss_accumulator.has_group("pixel"):
            x_0_pred = x_t_1d - t * v_pred_1d
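            # Invert the flow step to get a pixel-space comparison: x_0 = x_t - t * v.
            # The ground-truth decode runs without grad (and is detached), while the
            # predicted decode keeps grad so the pixel losses backprop into the transformer.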
            with torch.no_grad():
                pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
            pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)

            if loss_accumulator.has("pixel_lpips"):
                lpips_loss = self.lpips_fn(pixel_values_x0_gt, pixel_values_x0_pred)
                texam(lpips_loss, "lpips_loss")
                lpips_loss = lpips_loss.mean()
                texam(lpips_loss, "lpips_loss")
                loss_accumulator.accum("pixel_lpips", lpips_loss)
            if loss_accumulator.has("pixel_mse"):
                pixel_mse_loss = F.mse_loss(pixel_values_x0_pred, pixel_values_x0_gt, reduction="mean")
                loss_accumulator.accum("pixel_mse", pixel_mse_loss)
            if loss_accumulator.has("pixel_triplet"):
                raise NotImplementedError()
                loss_accumulator.accum("pixel_triplet", pixel_triplet_loss)
            if loss_accumulator.has("pixel_distribution_matching"):
                raise NotImplementedError()
                loss_accumulator.accum("pixel_distribution_matching", pixel_distribution_matching_loss)

        if loss_accumulator.has("adversarial"):
            raise NotImplementedError()

        loss = loss_accumulator.total
        logs = loss_accumulator.logs()
        wand_logger.log(logs, step=step, commit=False)
        wand_logger.log({
            "t": t.float().cpu().item()
        }, step=step, commit=False)

        if self.should_log_training(step):
            self.log_single_step_images(
                h_f16,
                w_f16,
                t,
                x_0_1d,
                x_t_1d,
                v_gt_1d,
                v_neg_1d,
                v_pred_1d,
                visualize_velocities=False,
            )
        return loss

    def should_log_training(self, step) -> bool:
        return (
            self.training  # don't log when validating
            and ExperimentTrainer._is_step_trigger(step, self.config.log_batch_steps)
        )

    def log_single_step_images(
        self,
        h_f16,
        w_f16,
        t,
        x_0_1d,
        x_t_1d,
        v_gt_1d,
        v_neg_1d,
        v_pred_1d,
        visualize_velocities=False,
    ):
        t_float = t.float().cpu().item()
        x_0_pred = x_t_1d - t * v_pred_1d
        x_0_neg = x_t_1d - t * v_neg_1d
        x_0_recon = x_t_1d - t * v_gt_1d
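        # All three are x_0 estimates obtained by inverting x_t = x_0 + t * v with the
        # predicted, base-model ("negative"), and recovered ground-truth velocities,
        # decoded below purely for logging.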
        log_pils = {
            f"x_{t_float}_1d": self.latents_to_pil(x_t_1d, h=h_f16, w=w_f16),
            "x_0": self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16),
            "x_0_recon": self.latents_to_pil(x_0_recon, h=h_f16, w=w_f16),
            "x_0_pred": self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16),
            "x_0_neg": self.latents_to_pil(x_0_neg, h=h_f16, w=w_f16),
        }
        if visualize_velocities:  # naively visualizing through vae (works with flux)
            log_pils.update({
                "v_gt_1d": self.latents_to_pil(v_gt_1d, h=h_f16, w=w_f16),
                "v_pred_1d": self.latents_to_pil(v_pred_1d, h=h_f16, w=w_f16),
                "v_neg_1d": self.latents_to_pil(v_neg_1d, h=h_f16, w=w_f16),
            })

        # create gt-neg difference maps
        v_pred_2d = self.unpack_latents(v_pred_1d, h_f16, w_f16)
        v_gt_2d = self.unpack_latents(v_gt_1d, h_f16, w_f16)
        v_neg_2d = self.unpack_latents(v_neg_1d, h_f16, w_f16)
        gt_neg_diff_map_2d = (v_gt_2d - v_neg_2d).pow(2).mean(dim=1, keepdim=True)
        gt_pred_diff_map_2d = (v_gt_2d - v_pred_2d).pow(2).mean(dim=1, keepdim=True)
        neg_pred_diff_map_2d = (v_neg_2d - v_pred_2d).pow(2).mean(dim=1, keepdim=True)
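        # Squared per-channel differences are averaged to single-channel heat maps,
        # then jointly min-max normalized below so the three maps share one intensity scale.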
        diff_max = torch.max(torch.stack([gt_neg_diff_map_2d, gt_pred_diff_map_2d, neg_pred_diff_map_2d]))
        diff_min = torch.min(torch.stack([gt_neg_diff_map_2d, gt_pred_diff_map_2d, neg_pred_diff_map_2d]))
        print(f"{diff_min}, {diff_max}")
        # norms to 0-1
        diff_span = diff_max - diff_min
        gt_neg_diff_map_2d = (gt_neg_diff_map_2d - diff_min) / diff_span
        gt_pred_diff_map_2d = (gt_pred_diff_map_2d - diff_min) / diff_span
        neg_pred_diff_map_2d = (neg_pred_diff_map_2d - diff_min) / diff_span
        log_pils.update({
            "gt-neg": gt_neg_diff_map_2d.float().cpu(),
            "gt-pred": gt_pred_diff_map_2d.float().cpu(),
            "neg-pred": neg_pred_diff_map_2d.float().cpu(),
        })
        wand_logger.log({
            "train_images": log_pils,
        }, commit=False)

    def base_pipe(self, inputs: QwenInputs) -> list[Image.Image]:
        inputs.num_inference_steps = self.config.regression_base_pipe_steps
        return super().base_pipe(inputs)