| import gc |
| import os |
| from typing import List, Literal, Optional, Tuple |
|
|
| import imageio |
| import lightning as L |
| import torch |
| from jaxtyping import Float |
| from torchvision.transforms import Resize |
| from tqdm import tqdm |
|
|
| from src.loss import AbstractLoss |
| from src.loss.vector_swd import VectorSWDLoss |
| from src.utils.asc_cdl import asc_cdl_forward, save_asc_cdl |
| from src.utils.color_space import rgb_to_lab |
| from src.utils.image import from_torch, read_img, to_torch, write_img |
|
|
|
|
class CDL(torch.nn.Module):
    """Per-image ASC CDL color grade with learnable parameters.

    Holds one set of CDL parameters (slope, offset, power, saturation) per
    batch element, initialized to the identity transform (slope=1, offset=0,
    power=1, saturation=1) so optimization starts from "no color change".
    """

    def __init__(self, batch_size: int):
        """Create identity CDL parameters for ``batch_size`` images."""
        super().__init__()
        # One (R, G, B) triple per batch element for each channel-wise term.
        self.cdl_slope = torch.nn.Parameter(torch.ones(batch_size, 3))
        self.cdl_offset = torch.nn.Parameter(torch.zeros(batch_size, 3))
        self.cdl_power = torch.nn.Parameter(torch.ones(batch_size, 3))
        # Saturation is a single scalar per batch element.
        self.cdl_saturation = torch.nn.Parameter(torch.ones(batch_size))

    def forward(
        self, x: Float[torch.Tensor, "*B C H W"]
    ) -> Float[torch.Tensor, "*B C H W"]:
        """Apply the CDL transform to ``x`` (delegates to ``asc_cdl_forward``)."""
        return asc_cdl_forward(
            x, self.cdl_slope, self.cdl_offset, self.cdl_power, self.cdl_saturation
        )

    def to_cdl_xml(self) -> List[str]:
        """Serialize each batch element's CDL parameters.

        Returns:
            One entry per batch element.  (The previous ``-> str`` annotation
            was wrong: this method builds and returns a list.)
        """
        ret = []
        for b in range(self.cdl_slope.shape[0]):
            ret.append(
                save_asc_cdl(
                    {
                        "slope": self.cdl_slope[b],
                        "offset": self.cdl_offset[b],
                        "power": self.cdl_power[b],
                        "saturation": self.cdl_saturation[b],
                    },
                    # None path: presumably makes save_asc_cdl return the
                    # serialized XML instead of writing a file — TODO confirm.
                    None,
                )
            )
        return ret

    def save(self, path: str):
        """Write one ``cdl_{b}.xml`` file per batch element into ``path``."""
        for b in range(self.cdl_slope.shape[0]):
            save_asc_cdl(
                {
                    "slope": self.cdl_slope[b],
                    "offset": self.cdl_offset[b],
                    "power": self.cdl_power[b],
                    "saturation": self.cdl_saturation[b],
                },
                os.path.join(path, f"cdl_{b}.xml"),
            )
|
|
|
|
def train(
    criteria: AbstractLoss,
    source_img: Float[torch.Tensor, "B C H W"],
    target_img: Float[torch.Tensor, "B C H W"],
    num_steps: int,
    lr: float,
    match_resolution: int,
    silent: bool = False,
    write_video_animation_path: Optional[str] = None,
) -> Tuple[Float[torch.Tensor, "*B C H W"], CDL, List[float]]:
    """Optimize per-image CDL parameters so graded sources match the targets.

    The loss is computed in CIELAB at ``match_resolution``; the learned CDL
    is then applied to the full-resolution source images.  Requires CUDA.

    Args:
        criteria: Loss module comparing flattened source/target Lab tensors.
        source_img: Batch of source images to be graded.
        target_img: Batch of reference images (same batch size as sources).
        num_steps: Number of Adam optimization steps.
        lr: Adam learning rate.
        match_resolution: Resolution (shorter side) used for loss computation.
        silent: Suppress the tqdm progress bar.
        write_video_animation_path: If set, write one JPEG of the graded
            source per step into this directory (for a progress animation).

    Returns:
        Tuple of (full-resolution graded sources, trained CDL module,
        per-step loss history).
    """
    criteria = criteria.cuda()

    # Downsample both images for the loss; full resolution is only used at
    # the end (and for the optional animation frames).
    source_max_res = Resize(match_resolution, antialias=True)(source_img)
    target_max_res = Resize(match_resolution, antialias=True)(target_img)

    # NOTE(review): permute(0, 3, 1, 2) followed by permute(0, 2, 3, 1) is the
    # identity, so the target Lab is effectively rgb_to_lab(target_max_res) in
    # the original layout.  The source path below permutes *before* calling
    # rgb_to_lab — the two paths feed rgb_to_lab different layouts; confirm
    # which layout rgb_to_lab expects.
    target_cielab = (
        rgb_to_lab(target_max_res).cuda().permute(0, 3, 1, 2)
        .permute(0, 2, 3, 1)
        .contiguous()
    )

    source_max_res = source_max_res.cuda()
    source_img = source_img.cuda()

    batch_size = source_img.shape[0]
    # One learnable CDL (slope/offset/power/saturation) per batch element.
    cdl = CDL(batch_size).cuda()

    optim = torch.optim.Adam(cdl.parameters(), lr=lr)

    lossses = []  # NOTE(review): typo for "losses"; kept as-is here.
    for i in tqdm(range(num_steps), disable=silent):
        optim.zero_grad(set_to_none=True)

        # Grade the low-res source and convert to Lab for the loss.
        cdl_source = cdl(source_max_res)
        source_cielab = (
            rgb_to_lab(cdl_source.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous()
        )

        # Flatten spatial dims; step index i is passed so the loss can
        # refresh its projections on a schedule.
        loss = criteria(
            source_cielab.view(source_cielab.shape[0], source_cielab.shape[1], -1),
            target_cielab.view(target_cielab.shape[0], target_cielab.shape[1], -1),
            i,
        )

        loss.backward()
        optim.step()

        lossses.append(loss.item())

        if write_video_animation_path is not None:
            # NOTE(review): squeeze(0) assumes batch size 1 — frames will be
            # wrong/raise for larger batches.  *2-1 maps [0,1] back to the
            # [-1,1] range write_img presumably expects — confirm.
            write_img(
                os.path.join(write_video_animation_path, f"{i:05d}.jpg"),
                from_torch(cdl(source_img).squeeze(0) * 2 - 1),
            )

    # Apply the trained grade at full resolution.
    source_full_res_cdl = cdl(source_img)

    # Release intermediate GPU memory before returning.
    gc.collect()
    torch.cuda.empty_cache()

    return source_full_res_cdl, cdl, lossses
|
|
|
|
def run(
    save_dir: str,
    source_img: List[str],
    target_img: List[str],
    matching_resolution: int,
    precision: Literal["32-true", "16-mixed"] = "16-mixed",
    num_projections: int = 64,
    lr: float = 0.01,
    steps: int = 300,
    use_ucv: bool = False,
    use_lcv: bool = False,
    distance: Literal["l1", "l2"] = "l1",
    refresh_projections_every_n_steps: int = 1,
    num_new_candidates: int = 32,
    sampling_mode: Literal["gaussian", "qmc"] = "gaussian",
    write_video: bool = False,
    **kwargs,
):
    """Color-match source images to target images via an optimized ASC CDL.

    Loads the images, trains per-image CDL parameters against a vector
    sliced-Wasserstein loss, then writes the CDL XML files and the graded
    PNGs into ``save_dir``.  Optionally encodes an MP4 of the optimization
    progress from per-step JPEG frames.

    Args:
        save_dir: Output directory (created if missing).
        source_img: Paths of the images to be graded.
        target_img: Paths of the reference images.
        matching_resolution: Resolution the matching loss is computed at.
        precision: NOTE(review): currently unused in this function — kept
            for interface compatibility; confirm whether it should configure
            a Lightning/autocast precision setting.
        num_projections: Number of SWD projections.
        lr: Adam learning rate.
        steps: Number of optimization steps.
        use_ucv / use_lcv: Loss configuration flags forwarded to the loss.
        distance: Point-wise distance used by the loss.
        refresh_projections_every_n_steps: Projection refresh schedule.
        num_new_candidates: Candidate projections per refresh.
        sampling_mode: Projection sampling strategy.
        write_video: If True, save per-step frames and encode animation.mp4.
        **kwargs: Ignored; absorbs extra CLI arguments.

    Returns:
        Tuple of (full-resolution graded images, trained CDL module,
        per-step loss history).
    """
    # read_img output is rescaled by *0.5+0.5, and outputs are written with
    # the inverse *2-1 — presumably mapping a [-1, 1] file domain to the
    # [0, 1] domain the CDL math operates in; confirm against read_img.
    # NOTE(review): torch.stack requires all images in each list to share
    # one resolution.
    source_imgs = torch.stack(
        [to_torch(read_img(s)) * 0.5 + 0.5 for s in source_img], dim=0
    )
    target_imgs = torch.stack(
        [to_torch(read_img(t)) * 0.5 + 0.5 for t in target_img], dim=0
    )

    criteria = VectorSWDLoss(
        num_proj=num_projections,
        distance=distance,
        use_ucv=use_ucv,
        use_lcv=use_lcv,
        refresh_projections_every_n_steps=refresh_projections_every_n_steps,
        num_new_candidates=num_new_candidates,
        sampling_mode=sampling_mode,
    )

    os.makedirs(save_dir, exist_ok=True)
    animation_dir = os.path.join(save_dir, "animation")

    if write_video:
        os.makedirs(animation_dir, exist_ok=True)

    # Fixed the "lossses" typo from the original local name; the positional
    # return contract is unchanged.
    source_full_res_cdl, cdl, losses = train(
        criteria,
        source_imgs,
        target_imgs,
        steps,
        lr,
        matching_resolution,
        write_video_animation_path=animation_dir if write_video else None,
    )

    # One cdl_{b}.xml per image.
    cdl.save(save_dir)

    for i, img in enumerate(source_full_res_cdl):
        write_img(
            os.path.join(save_dir, f"color_matched_{i}.png"),
            from_torch(img * 2 - 1),
        )

    if write_video:
        # Frames are named {step:05d}.jpg; sort numerically by step index.
        image_files = [f for f in os.listdir(animation_dir) if f.endswith(".jpg")]
        image_files.sort(key=lambda x: int(x.split(".")[0]))

        with imageio.get_writer(
            os.path.join(save_dir, "animation.mp4"), fps=30, codec="libx264"
        ) as writer:
            for image_file in image_files:
                image = imageio.imread(os.path.join(animation_dir, image_file))
                writer.append_data(image)

    return source_full_res_cdl, cdl, losses
|
|