diff --git a/app.py b/app.py index 05736cdd3e0b6fa96d47356442dc0be3b7de9262..243078fee26fc7708d998045dbb18e20ec0f5e81 100644 --- a/app.py +++ b/app.py @@ -1,161 +1,74 @@ -import spaces -import os -import requests -import yaml -import torch import gradio as gr -from PIL import Image -import sys -sys.path.append(os.path.abspath('./')) -from inference.utils import * -from core.utils import load_or_fail -from train import WurstCoreB -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from train import WurstCore_t2i as WurstCoreC -import torch.nn.functional as F -from core.utils import load_or_fail -import numpy as np -import random -import math -from einops import rearrange - -def download_file(url, folder_path, filename): - if not os.path.exists(folder_path): - os.makedirs(folder_path) - file_path = os.path.join(folder_path, filename) - - if os.path.isfile(file_path): - print(f"File already exists: {file_path}") - else: - response = requests.get(url, stream=True) - if response.status_code == 200: - with open(file_path, 'wb') as file: - for chunk in response.iter_content(chunk_size=1024): - file.write(chunk) - print(f"File successfully downloaded and saved: {file_path}") - else: - print(f"Error downloading the file. Status code: {response.status_code}") - -def download_models(): - models = { - "STABLEWURST_A": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors?download=true", "models/", "stage_a.safetensors"), - "STABLEWURST_PREVIEWER": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors?download=true", "models/", "previewer.safetensors"), - "STABLEWURST_EFFNET": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors?download=true", "models/", "effnet_encoder.safetensors"), - "STABLEWURST_B_LITE": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors?download=true", "models/", "stage_b_lite_bf16.safetensors"), - "STABLEWURST_C": ("https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors?download=true", "models/", "stage_c_bf16.safetensors"), - "ULTRAPIXEL_T2I": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/ultrapixel_t2i.safetensors?download=true", "models/", "ultrapixel_t2i.safetensors"), - "ULTRAPIXEL_LORA_CAT": ("https://huggingface.co/roubaofeipi/UltraPixel/resolve/main/lora_cat.safetensors?download=true", "models/", "lora_cat.safetensors"), - } - - for model, (url, folder, filename) in models.items(): - download_file(url, folder, filename) - -download_models() - -# Global variables -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -dtype = torch.bfloat16 - -# Load configs and setup models -with open("configs/training/t2i.yaml", "r", encoding="utf-8") as file: - config_c = yaml.safe_load(file) - -with open("configs/inference/stage_b_1b.yaml", "r", encoding="utf-8") as file: - config_b = yaml.safe_load(file) - -core = WurstCoreC(config_dict=config_c, device=device, training=False) -core_b = WurstCoreB(config_dict=config_b, device=device, training=False) - -extras = core.setup_extras_pre() -models = core.setup_models(extras) -models.generator.eval().requires_grad_(False) - -extras_b = core_b.setup_extras_pre() -models_b = core_b.setup_models(extras_b, skip_clip=True) -models_b = WurstCoreB.Models( - **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} -) -models_b.generator.bfloat16().eval().requires_grad_(False) - -# Load pretrained model -pretrained_path = "models/ultrapixel_t2i.safetensors" -sdd = torch.load(pretrained_path, map_location='cpu') -collect_sd = {k[7:]: v for k, v in sdd.items()} -models.train_norm.load_state_dict(collect_sd) -models.generator.eval() -models.train_norm.eval() - -# Set up sampling configurations -extras.sampling_configs.update({ - 'cfg': 4, - 'shift': 1, - 'timesteps': 20, - 't_start': 1.0, - 'sampler': DDPMSampler(extras.gdf) -}) - -extras_b.sampling_configs.update({ - 'cfg': 1.1, - 'shift': 1, - 'timesteps': 10, - 't_start': 1.0 -}) - -@spaces.GPU(duration=180) -def generate_images(prompt, height, width, seed, num_images): - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - - batch_size = num_images - height_lr, width_lr = get_target_lr_size(height / width, std_size=32) - stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) - stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) - - batch = {'captions': [prompt] * batch_size} - conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) - unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - - with torch.no_grad(): - models.generator.cuda() - with torch.cuda.amp.autocast(dtype=dtype): - sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) - - models.generator.cpu() - torch.cuda.empty_cache() - - conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) - unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - conditions_b['effnet'] = sampled_c - unconditions_b['effnet'] = torch.zeros_like(sampled_c) - - with torch.cuda.amp.autocast(dtype=dtype): - sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=True) - - torch.cuda.empty_cache() - imgs = show_images(sampled) - return imgs - -iface = gr.Interface( - fn=generate_images, - inputs=[ - gr.Textbox(label="Prompt"), - gr.Slider(minimum=256, maximum=2560, step=32, label="Height", value=1024), - gr.Slider(minimum=256, maximum=5120, step=32, label="Width", value=1024), - gr.Number(label="Seed", value=42), - gr.Slider(minimum=1, maximum=10, step=1, label="Number of Images", value=1) - ], - outputs=gr.Gallery(label="Generated Images", columns=5, rows=2), - title="UltraPixel Image Generation", - description="Generate high-resolution images using UltraPixel model.", - theme='bethecloud/storj_theme', - examples=[ - ["The image features a snow-covered mountain range with a large, snow-covered mountain in the background. The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the winter season, with snow covering the ground and the trees.", 1024, 1024, 42, 1] - ], - cache_examples=True -) - -iface.launch() \ No newline at end of file +from transformers import AutoProcessor, AutoModelForCausalLM +import spaces +from PIL import Image + +import subprocess +subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) + +models = { + 'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True).eval(), + 'gokaygokay/Florence-2-Flux': AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True).eval(), +} + +processors = { + 'gokaygokay/Florence-2-Flux-Large': AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux-Large', trust_remote_code=True), + 'gokaygokay/Florence-2-Flux': AutoProcessor.from_pretrained('gokaygokay/Florence-2-Flux', trust_remote_code=True), +} + +title = """

Florence-2 Captioner for Flux Prompts

+

+[Florence-2 Flux Large] +[Florence-2 Flux Base] +

+""" + +@spaces.GPU +def run_example(image, model_name='gokaygokay/Florence-2-Flux-Large'): + image = Image.fromarray(image) + task_prompt = "" + prompt = task_prompt + "Describe this image in great detail." + + if image.mode != "RGB": + image = image.convert("RGB") + + model = models[model_name] + processor = processors[model_name] + + inputs = processor(text=prompt, images=image, return_tensors="pt") + generated_ids = model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + num_beams=3, + repetition_penalty=1.10, + ) + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] + parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height)) + return parsed_answer[""] + +with gr.Blocks(theme='bethecloud/storj_theme') as demo: + gr.HTML(title) + + with gr.Row(): + with gr.Column(): + input_img = gr.Image(label="Input Picture") + model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='gokaygokay/Florence-2-Flux-Large') + submit_btn = gr.Button(value="Submit") + with gr.Column(): + output_text = gr.Textbox(label="Output Text") + + gr.Examples( + [["image1.jpg"], + ["image2.jpg"], + ["image3.png"], + ["image5.jpg"]], + inputs=[input_img, model_selector], + outputs=[output_text], + fn=run_example, + label='Try captioning on below examples' + ) + + submit_btn.click(run_example, [input_img, model_selector], [output_text]) + +demo.launch(debug=True) \ No newline at end of file diff --git a/configs/inference/controlnet_c_3b_canny.yaml b/configs/inference/controlnet_c_3b_canny.yaml deleted file mode 100644 index 286d7a6c8017e922a020d6ae5633cc3e27f9b702..0000000000000000000000000000000000000000 --- a/configs/inference/controlnet_c_3b_canny.yaml +++ /dev/null @@ -1,14 +0,0 @@ -# GLOBAL STUFF -model_version: 3.6B -dtype: bfloat16 - -# ControlNet specific -controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] -controlnet_filter: CannyFilter -controlnet_filter_params: - resize: 224 - -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors -controlnet_checkpoint_path: models/canny.safetensors diff --git a/configs/inference/controlnet_c_3b_identity.yaml b/configs/inference/controlnet_c_3b_identity.yaml deleted file mode 100644 index 8a20fa860fed5f6eea1d33113535c2633205e327..0000000000000000000000000000000000000000 --- a/configs/inference/controlnet_c_3b_identity.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# GLOBAL STUFF -model_version: 3.6B -dtype: bfloat16 - -# ControlNet specific -controlnet_bottleneck_mode: 'simple' -controlnet_blocks: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] -controlnet_filter: IdentityFilter -controlnet_filter_params: - max_faces: 4 - p_drop: 0.00 - p_full: 0.0 - -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors -controlnet_checkpoint_path: diff --git a/configs/inference/controlnet_c_3b_inpainting.yaml b/configs/inference/controlnet_c_3b_inpainting.yaml deleted file mode 100644 index a94bd7953dfa407184d9094b481a56cdbbb73549..0000000000000000000000000000000000000000 --- a/configs/inference/controlnet_c_3b_inpainting.yaml +++ /dev/null @@ -1,15 +0,0 @@ -# GLOBAL STUFF -model_version: 3.6B -dtype: bfloat16 - -# ControlNet specific -controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] -controlnet_filter: InpaintFilter -controlnet_filter_params: - thresold: [0.04, 0.4] - p_outpaint: 0.4 - -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors -controlnet_checkpoint_path: models/inpainting.safetensors diff --git a/configs/inference/controlnet_c_3b_sr.yaml b/configs/inference/controlnet_c_3b_sr.yaml deleted file mode 100644 index 13c4a2cd2dcd2a3cf87fb32bd6e34269e796a747..0000000000000000000000000000000000000000 --- a/configs/inference/controlnet_c_3b_sr.yaml +++ /dev/null @@ -1,15 +0,0 @@ -# GLOBAL STUFF -model_version: 3.6B -dtype: bfloat16 - -# ControlNet specific -controlnet_bottleneck_mode: 'large' -controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] -controlnet_filter: SREffnetFilter -controlnet_filter_params: - scale_factor: 0.5 - -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors -controlnet_checkpoint_path: models/super_resolution.safetensors diff --git a/configs/inference/lora_c_3b.yaml b/configs/inference/lora_c_3b.yaml deleted file mode 100644 index 7468078c657c1f569c6c052a14b265d69082ab25..0000000000000000000000000000000000000000 --- a/configs/inference/lora_c_3b.yaml +++ /dev/null @@ -1,15 +0,0 @@ -# GLOBAL STUFF -model_version: 3.6B -dtype: bfloat16 - -# LoRA specific -module_filters: ['.attn'] -rank: 4 -train_tokens: - # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized - - ['[fernando]', '^dog'] # custom token [snail], initialize as avg of snail & snails - -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors -lora_checkpoint_path: models/lora_fernando_10k.safetensors diff --git a/configs/inference/stage_b_1b.yaml b/configs/inference/stage_b_1b.yaml deleted file mode 100644 index 0811cae75622614e91de6532262acb2c062bf344..0000000000000000000000000000000000000000 --- a/configs/inference/stage_b_1b.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# GLOBAL STUFF -model_version: 700M -dtype: bfloat16 - -# For demonstration purposes in reconstruct_images.ipynb -webdataset_path: path to your dataset -batch_size: 1 -image_size: 2048 -grad_accum_steps: 1 - -effnet_checkpoint_path: models/effnet_encoder.safetensors -stage_a_checkpoint_path: models/stage_a.safetensors -generator_checkpoint_path: models/stage_b_lite_bf16.safetensors \ No newline at end of file diff --git a/configs/inference/stage_b_3b.yaml b/configs/inference/stage_b_3b.yaml deleted file mode 100644 index 840268980103e0c629599b966705043d6a616578..0000000000000000000000000000000000000000 --- a/configs/inference/stage_b_3b.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# GLOBAL STUFF -model_version: 3B -dtype: bfloat16 - -# For demonstration purposes in reconstruct_images.ipynb -webdataset_path: path to your dataset -batch_size: 4 -image_size: 1024 -grad_accum_steps: 1 - -effnet_checkpoint_path: models/effnet_encoder.safetensors -stage_a_checkpoint_path: models/stage_a.safetensors -generator_checkpoint_path: models/stage_b_lite_bf16.safetensors \ No newline at end of file diff --git a/configs/inference/stage_c_1b.yaml b/configs/inference/stage_c_1b.yaml deleted file mode 100644 index 781886e515d80e7870abb89bf8fd0ce7c7c8d4b6..0000000000000000000000000000000000000000 --- a/configs/inference/stage_c_1b.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# GLOBAL STUFF -model_version: 1B -dtype: bfloat16 - -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_lite_bf16.safetensors \ No newline at end of file diff --git a/configs/inference/stage_c_3b.yaml b/configs/inference/stage_c_3b.yaml deleted file mode 100644 index b22897e71996ad78f3832af78f5bc44ca06d206d..0000000000000000000000000000000000000000 --- a/configs/inference/stage_c_3b.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# GLOBAL STUFF -model_version: 3.6B -dtype: bfloat16 - -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/configs/training/cfg_control_lr.yaml b/configs/training/cfg_control_lr.yaml deleted file mode 100644 index 2955b6a925504525b981e7004b65a33573c08aef..0000000000000000000000000000000000000000 --- a/configs/training/cfg_control_lr.yaml +++ /dev/null @@ -1,47 +0,0 @@ -# GLOBAL STUFF -experiment_id: Ultrapixel_controlnet - -checkpoint_path: checkpoint output path -output_path: visual results output path -model_version: 3.6B -dtype: float32 -# # WandB -# wandb_project: StableCascade -# wandb_entity: wandb_username -#module_filters: ['.depthwise', '.mapper', '.attn', '.channelwise' ] -#rank: 32 -# TRAINING PARAMS -lr: 1.0e-4 -batch_size: 12 -#image_size: [1536, 2048, 2560, 3072, 4096] -image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] -#image_size: [ 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096, 4608] -#image_size: [ 1024, 1280] -multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] -grad_accum_steps: 2 -updates: 40000 -backup_every: 5000 -save_every: 256 -warmup_updates: 1 -use_fsdp: True - -# ControlNet specific -controlnet_blocks: [0, 4, 8, 12, 51, 55, 59, 63] -controlnet_filter: CannyFilter -controlnet_filter_params: - resize: 224 -# offset_noise: 0.1 - -# GDF -adaptive_loss_weight: True - -ema_start_iters: 10 -ema_iters: 50 -ema_beta: 0.9 - -webdataset_path: path to your training dataset -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors -controlnet_checkpoint_path: pretrained controlnet path - diff --git a/configs/training/lora_personalization.yaml b/configs/training/lora_personalization.yaml deleted file mode 100644 index 857795e6d37e9cb61bd76aa588f432978ed90ad2..0000000000000000000000000000000000000000 --- a/configs/training/lora_personalization.yaml +++ /dev/null @@ -1,37 +0,0 @@ -# GLOBAL STUFF -experiment_id: roubao_cat_personalized - -checkpoint_path: checkpoint output path -output_path: visual results output path -model_version: 3.6B -dtype: float32 - -module_filters: [ '.attn'] -rank: 4 -train_tokens: - # - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized - - ['[roubaobao]', '^cat'] # custom token [snail], initialize as avg of snail & snails -# TRAINING PARAMS -lr: 1.0e-4 -batch_size: 4 - -image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] -multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] -grad_accum_steps: 2 -updates: 40000 -backup_every: 5000 -save_every: 512 -warmup_updates: 1 -use_ddp: True - -# GDF -adaptive_loss_weight: True - - -tmp_prompt: a photo of a cat [roubaobao] -webdataset_path: path to your personalized training dataset -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors -ultrapixel_path: models/ultrapixel_t2i.safetensors - diff --git a/configs/training/t2i.yaml b/configs/training/t2i.yaml deleted file mode 100644 index 8a0ceaca0ad8813e3c9b998661ac3e9b3c0937fd..0000000000000000000000000000000000000000 --- a/configs/training/t2i.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# GLOBAL STUFF -experiment_id: ultrapixel_t2i -#strc_fixlrt_norm3_lite_1024_hrft_newdata -checkpoint_path: checkpoint output path #output model directory -output_path: visual results output path #experiment output directory -model_version: 3.6B # finetune large stage c model of stablecascade -dtype: float32 - - -# TRAINING PARAMS -lr: 1.0e-4 -batch_size: 4 # gpu_number * num_per_gpu * grad_accum_steps -image_size: [1024, 2048, 2560, 3072, 3584, 3840, 4096, 4608] # possible image resolution -multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16] -grad_accum_steps: 2 -updates: 40000 -backup_every: 5000 -save_every: 256 -warmup_updates: 1 -use_ddp: True - -# GDF -adaptive_loss_weight: True - - -webdataset_path: path to your personalized training dataset -effnet_checkpoint_path: models/effnet_encoder.safetensors -previewer_checkpoint_path: models/previewer.safetensors -generator_checkpoint_path: models/stage_c_bf16.safetensors \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py deleted file mode 100644 index ed382f1907ddc86c7e9a9618c21441755a6221a9..0000000000000000000000000000000000000000 --- a/core/__init__.py +++ /dev/null @@ -1,372 +0,0 @@ -import os -import yaml -import torch -from torch import nn -import wandb -import json -from abc import ABC, abstractmethod -from dataclasses import dataclass -from torch.utils.data import Dataset, DataLoader - -from torch.distributed import init_process_group, destroy_process_group, barrier -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - FullStateDictConfig, - MixedPrecision, - ShardingStrategy, - StateDictType -) - -from .utils import Base, EXPECTED, EXPECTED_TRAIN -from .utils import create_folder_if_necessary, safe_save, load_or_fail - -# pylint: disable=unused-argument -class WarpCore(ABC): - @dataclass(frozen=True) - class Config(Base): - experiment_id: str = EXPECTED_TRAIN - checkpoint_path: str = EXPECTED_TRAIN - output_path: str = EXPECTED_TRAIN - checkpoint_extension: str = "safetensors" - dist_file_subfolder: str = "" - allow_tf32: bool = True - - wandb_project: str = None - wandb_entity: str = None - - @dataclass() # not frozen, means that fields are mutable - class Info(): # not inheriting from Base, because we don't want to enforce the default fields - wandb_run_id: str = None - total_steps: int = 0 - iter: int = 0 - - @dataclass(frozen=True) - class Data(Base): - dataset: Dataset = EXPECTED - dataloader: DataLoader = EXPECTED - iterator: any = EXPECTED - - @dataclass(frozen=True) - class Models(Base): - pass - - @dataclass(frozen=True) - class Optimizers(Base): - pass - - @dataclass(frozen=True) - class Schedulers(Base): - pass - - @dataclass(frozen=True) - class Extras(Base): - pass - # --------------------------------------- - info: Info - config: Config - - # FSDP stuff - fsdp_defaults = { - "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP, - "cpu_offload": None, - "mixed_precision": MixedPrecision( - param_dtype=torch.bfloat16, - reduce_dtype=torch.bfloat16, - buffer_dtype=torch.bfloat16, - ), - "limit_all_gathers": True, - } - fsdp_fullstate_save_policy = FullStateDictConfig( - offload_to_cpu=True, rank0_only=True - ) - # ------------ - - # OVERRIDEABLE METHODS - - # [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup - def setup_extras_pre(self) -> Extras: - return self.Extras() - - # setup dataset & dataloader, return a dict contained dataser, dataloader and/or iterator - @abstractmethod - def setup_data(self, extras: Extras) -> Data: - raise NotImplementedError("This method needs to be overriden") - - # return a dict with all models that are going to be used in the training - @abstractmethod - def setup_models(self, extras: Extras) -> Models: - raise NotImplementedError("This method needs to be overriden") - - # return a dict with all optimizers that are going to be used in the training - @abstractmethod - def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: - raise NotImplementedError("This method needs to be overriden") - - # [optionally] return a dict with all schedulers that are going to be used in the training - def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers: - return self.Schedulers() - - # [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup - def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras: - return self.Extras.from_dict(extras.to_dict()) - - # perform the training here - @abstractmethod - def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): - raise NotImplementedError("This method needs to be overriden") - # ------------ - - def setup_info(self, full_path=None) -> Info: - if full_path is None: - full_path = (f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json") - info_dict = load_or_fail(full_path, wandb_run_id=None) or {} - info_dto = self.Info(**info_dict) - if info_dto.total_steps > 0 and self.is_main_node: - print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps) - return info_dto - - def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config: - if config_file_path is not None: - if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"): - with open(config_file_path, "r", encoding="utf-8") as file: - loaded_config = yaml.safe_load(file) - elif config_file_path.endswith(".json"): - with open(config_file_path, "r", encoding="utf-8") as file: - loaded_config = json.load(file) - else: - raise ValueError("Config file must be either a .yml|.yaml or .json file") - return self.Config.from_dict({**loaded_config, 'training': training}) - if config_dict is not None: - return self.Config.from_dict({**config_dict, 'training': training}) - return self.Config(training=training) - - def setup_ddp(self, experiment_id, single_gpu=False): - if not single_gpu: - local_rank = int(os.environ.get("SLURM_LOCALID")) - process_id = int(os.environ.get("SLURM_PROCID")) - world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count() - - self.process_id = process_id - self.is_main_node = process_id == 0 - self.device = torch.device(local_rank) - self.world_size = world_size - - dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}" - # if os.path.exists(dist_file_path) and self.is_main_node: - # os.remove(dist_file_path) - - torch.cuda.set_device(local_rank) - init_process_group( - backend="nccl", - rank=process_id, - world_size=world_size, - init_method=f"file://{dist_file_path}", - ) - print(f"[GPU {process_id}] READY") - else: - print("Running in single thread, DDP not enabled.") - - def setup_wandb(self): - if self.is_main_node and self.config.wandb_project is not None: - self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id() - wandb.init(project=self.config.wandb_project, entity=self.config.wandb_entity, name=self.config.experiment_id, id=self.info.wandb_run_id, resume="allow", config=self.config.to_dict()) - - if self.info.total_steps > 0: - wandb.alert(title=f"Training {self.info.wandb_run_id} resumed", text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}") - else: - wandb.alert(title=f"Training {self.info.wandb_run_id} started", text=f"Training {self.info.wandb_run_id} started") - - # LOAD UTILITIES ---------- - def load_model(self, model, model_id=None, full_path=None, strict=True): - print('in line 181 load model', type(model), model_id, full_path, strict) - if model_id is not None and full_path is None: - full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}" - elif full_path is None and model_id is None: - raise ValueError( - "This method expects either 'model_id' or 'full_path' to be defined" - ) - - checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None) - if checkpoint is not None: - model.load_state_dict(checkpoint, strict=strict) - del checkpoint - - return model - - def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None): - if optim_id is not None and full_path is None: - full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt" - elif full_path is None and optim_id is None: - raise ValueError( - "This method expects either 'optim_id' or 'full_path' to be defined" - ) - - checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None) - if checkpoint is not None: - try: - if fsdp_model is not None: - sharded_optimizer_state_dict = ( - FSDP.scatter_full_optim_state_dict( # <---- FSDP - checkpoint - if ( - self.is_main_node - or self.fsdp_defaults["sharding_strategy"] - == ShardingStrategy.NO_SHARD - ) - else None, - fsdp_model, - ) - ) - optim.load_state_dict(sharded_optimizer_state_dict) - del checkpoint, sharded_optimizer_state_dict - else: - optim.load_state_dict(checkpoint) - # pylint: disable=broad-except - except Exception as e: - print("!!! Failed loading optimizer, skipping... Exception:", e) - - return optim - - # SAVE UTILITIES ---------- - def save_info(self, info, suffix=""): - full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json" - create_folder_if_necessary(full_path) - if self.is_main_node: - safe_save(vars(self.info), full_path) - - def save_model(self, model, model_id=None, full_path=None, is_fsdp=False): - if model_id is not None and full_path is None: - full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}" - elif full_path is None and model_id is None: - raise ValueError( - "This method expects either 'model_id' or 'full_path' to be defined" - ) - create_folder_if_necessary(full_path) - if is_fsdp: - with FSDP.summon_full_params(model): - pass - with FSDP.state_dict_type( - model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy - ): - checkpoint = model.state_dict() - if self.is_main_node: - safe_save(checkpoint, full_path) - del checkpoint - else: - if self.is_main_node: - checkpoint = model.state_dict() - safe_save(checkpoint, full_path) - del checkpoint - - def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None): - if optim_id is not None and full_path is None: - full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt" - elif full_path is None and optim_id is None: - raise ValueError( - "This method expects either 'optim_id' or 'full_path' to be defined" - ) - create_folder_if_necessary(full_path) - if fsdp_model is not None: - optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) - if self.is_main_node: - safe_save(optim_statedict, full_path) - del optim_statedict - else: - if self.is_main_node: - checkpoint = optim.state_dict() - safe_save(checkpoint, full_path) - del checkpoint - # ----- - - def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True): - # Temporary setup, will be overriden by setup_ddp if required - self.device = device - self.process_id = 0 - self.is_main_node = True - self.world_size = 1 - # ---- - - self.config: self.Config = self.setup_config(config_file_path, config_dict, training) - self.info: self.Info = self.setup_info() - - def __call__(self, single_gpu=False): - self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu) # this will change the device to the CUDA rank - self.setup_wandb() - if self.config.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - if self.is_main_node: - print() - print("**STARTIG JOB WITH CONFIG:**") - print(yaml.dump(self.config.to_dict(), default_flow_style=False)) - print("------------------------------------") - print() - print("**INFO:**") - print(yaml.dump(vars(self.info), default_flow_style=False)) - print("------------------------------------") - print() - - # SETUP STUFF - extras = self.setup_extras_pre() - assert extras is not None, "setup_extras_pre() must return a DTO" - - data = self.setup_data(extras) - assert data is not None, "setup_data() must return a DTO" - if self.is_main_node: - print("**DATA:**") - print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - models = self.setup_models(extras) - assert models is not None, "setup_models() must return a DTO" - if self.is_main_node: - print("**MODELS:**") - print(yaml.dump({ - k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() - }, default_flow_style=False)) - print("------------------------------------") - print() - - optimizers = self.setup_optimizers(extras, models) - assert optimizers is not None, "setup_optimizers() must return a DTO" - if self.is_main_node: - print("**OPTIMIZERS:**") - print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - schedulers = self.setup_schedulers(extras, models, optimizers) - assert schedulers is not None, "setup_schedulers() must return a DTO" - if self.is_main_node: - print("**SCHEDULERS:**") - print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) - assert post_extras is not None, "setup_extras_post() must return a DTO" - extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) - if self.is_main_node: - print("**EXTRAS:**") - print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - # ------- - - # TRAIN - if self.is_main_node: - print("**TRAINING STARTING...**") - self.train(data, extras, models, optimizers, schedulers) - - if single_gpu is False: - barrier() - destroy_process_group() - if self.is_main_node: - print() - print("------------------------------------") - print() - print("**TRAINING COMPLETE**") - if self.config.wandb_project is not None: - wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") diff --git a/core/data/__init__.py b/core/data/__init__.py deleted file mode 100644 index b687719914b2e303909f7c280347e4bdee607d13..0000000000000000000000000000000000000000 --- a/core/data/__init__.py +++ /dev/null @@ -1,69 +0,0 @@ -import json -import subprocess -import yaml -import os -from .bucketeer import Bucketeer - -class MultiFilter(): - def __init__(self, rules, default=False): - self.rules = rules - self.default = default - - def __call__(self, x): - try: - x_json = x['json'] - if isinstance(x_json, bytes): - x_json = json.loads(x_json) - validations = [] - for k, r in self.rules.items(): - if isinstance(k, tuple): - v = r(*[x_json[kv] for kv in k]) - else: - v = r(x_json[k]) - validations.append(v) - return all(validations) - except Exception: - return False - -class MultiGetter(): - def __init__(self, rules): - self.rules = rules - - def __call__(self, x_json): - if isinstance(x_json, bytes): - x_json = json.loads(x_json) - outputs = [] - for k, r in self.rules.items(): - if isinstance(k, tuple): - v = r(*[x_json[kv] for kv in k]) - else: - v = r(x_json[k]) - outputs.append(v) - if len(outputs) == 1: - outputs = outputs[0] - return outputs - -def setup_webdataset_path(paths, cache_path=None): - if cache_path is None or not os.path.exists(cache_path): - tar_paths = [] - if isinstance(paths, str): - paths = [paths] - for path in paths: - if path.strip().endswith(".tar"): - # Avoid looking up s3 if we already have a tar file - tar_paths.append(path) - continue - bucket = "/".join(path.split("/")[:3]) - result = subprocess.run([f"aws s3 ls {path} --recursive | awk '{{print $4}}'"], stdout=subprocess.PIPE, shell=True, check=True) - files = result.stdout.decode('utf-8').split() - files = [f"{bucket}/{f}" for f in files if f.endswith(".tar")] - tar_paths += files - - with open(cache_path, 'w', encoding='utf-8') as outfile: - yaml.dump(tar_paths, outfile, default_flow_style=False) - else: - with open(cache_path, 'r', encoding='utf-8') as file: - tar_paths = yaml.safe_load(file) - - tar_paths_str = ",".join([f"{p}" for p in tar_paths]) - return f"pipe:aws s3 cp {{ {tar_paths_str} }} -" diff --git a/core/data/bucketeer.py b/core/data/bucketeer.py deleted file mode 100644 index 131e6ba4293bd7c00399f08609aba184b712d5e8..0000000000000000000000000000000000000000 --- a/core/data/bucketeer.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -import torchvision -import numpy as np -from torchtools.transforms import SmartCrop -import math - -class Bucketeer(): - def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): - assert crop_mode in ['center', 'random', 'smart'] - self.crop_mode = crop_mode - self.ratios = ratios - if reverse_list: - for r in list(ratios): - if 1/r not in self.ratios: - self.ratios.append(1/r) - self.sizes = {} - for dd in density: - self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios] - - self.batch_size = dataloader.batch_size - self.iterator = iter(dataloader) - all_sizes = [] - for k, vs in self.sizes.items(): - all_sizes += vs - self.buckets = {s: [] for s in all_sizes} - self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None - self.p_random_ratio = p_random_ratio - self.interpolate_nearest = interpolate_nearest - - def get_available_batch(self): - for b in self.buckets: - if len(self.buckets[b]) >= self.batch_size: - batch = self.buckets[b][:self.batch_size] - self.buckets[b] = self.buckets[b][self.batch_size:] - return batch - return None - - def get_closest_size(self, x): - w, h = x.size(-1), x.size(-2) - - - best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) - find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()} - min_ = find_dict[list(find_dict.keys())[0]] - find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx] - for dd, val in find_dict.items(): - if val < min_: - min_ = val - find_size = self.sizes[dd][best_size_idx] - - return find_size - - def get_resize_size(self, orig_size, tgt_size): - if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: - alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) - resize_size = max(alt_min, min(tgt_size)) - else: - alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) - resize_size = max(alt_max, max(tgt_size)) - - return resize_size - - def __next__(self): - batch = self.get_available_batch() - while batch is None: - elements = next(self.iterator) - for dct in elements: - img = dct['images'] - size = self.get_closest_size(img) - resize_size = self.get_resize_size(img.shape[-2:], size) - - if self.interpolate_nearest: - img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) - else: - img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) - if self.crop_mode == 'center': - img = torchvision.transforms.functional.center_crop(img, size) - elif self.crop_mode == 'random': - img = torchvision.transforms.RandomCrop(size)(img) - elif self.crop_mode == 'smart': - self.smartcrop.output_size = size - img = self.smartcrop(img) - - self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) - batch = self.get_available_batch() - - out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} - return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} diff --git a/core/data/bucketeer_deg.py b/core/data/bucketeer_deg.py deleted file mode 100644 index 7206ccf08932f617abb811221cc7bbe1d126f184..0000000000000000000000000000000000000000 --- a/core/data/bucketeer_deg.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torchvision -import numpy as np -from torchtools.transforms import SmartCrop -import math - -class Bucketeer(): - def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False): - assert crop_mode in ['center', 'random', 'smart'] - self.crop_mode = crop_mode - self.ratios = ratios - if reverse_list: - for r in list(ratios): - if 1/r not in self.ratios: - self.ratios.append(1/r) - self.sizes = {} - for dd in density: - self.sizes[dd]= [(int(((dd/r)**0.5//factor)*factor), int(((dd*r)**0.5//factor)*factor)) for r in ratios] - print('in line 17 buckteer', self.sizes) - self.batch_size = dataloader.batch_size - self.iterator = iter(dataloader) - all_sizes = [] - for k, vs in self.sizes.items(): - all_sizes += vs - self.buckets = {s: [] for s in all_sizes} - self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None - self.p_random_ratio = p_random_ratio - self.interpolate_nearest = interpolate_nearest - - def get_available_batch(self): - for b in self.buckets: - if len(self.buckets[b]) >= self.batch_size: - batch = self.buckets[b][:self.batch_size] - self.buckets[b] = self.buckets[b][self.batch_size:] - return batch - return None - - def get_closest_size(self, x): - w, h = x.size(-1), x.size(-2) - #if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio: - # best_size_idx = np.random.randint(len(self.ratios)) - #print('in line 41 get closes size', best_size_idx, x.shape, self.p_random_ratio) - #else: - - best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios]) - find_dict = {dd : abs(w*h - self.sizes[dd][best_size_idx][0]*self.sizes[dd][best_size_idx][1]) for dd, vv in self.sizes.items()} - min_ = find_dict[list(find_dict.keys())[0]] - find_size = self.sizes[list(find_dict.keys())[0]][best_size_idx] - for dd, val in find_dict.items(): - if val < min_: - min_ = val - find_size = self.sizes[dd][best_size_idx] - - return find_size - - def get_resize_size(self, orig_size, tgt_size): - if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0: - alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size))) - resize_size = max(alt_min, min(tgt_size)) - else: - alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size))) - resize_size = max(alt_max, max(tgt_size)) - #print('in line 50', orig_size, tgt_size, resize_size) - return resize_size - - def __next__(self): - batch = self.get_available_batch() - while batch is None: - elements = next(self.iterator) - for dct in elements: - img = dct['images'] - size = self.get_closest_size(img) - resize_size = self.get_resize_size(img.shape[-2:], size) - #print('in line 74', img.size(), resize_size) - if self.interpolate_nearest: - img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST) - else: - img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True) - if self.crop_mode == 'center': - img = torchvision.transforms.functional.center_crop(img, size) - elif self.crop_mode == 'random': - img = torchvision.transforms.RandomCrop(size)(img) - elif self.crop_mode == 'smart': - self.smartcrop.output_size = size - img = self.smartcrop(img) - print('in line 86 bucketeer', type(img), img.shape, torch.max(img), torch.min(img)) - self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}}) - batch = self.get_available_batch() - - out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]} - return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()} diff --git a/core/data/deg_kair_utils/utils_alignfaces.py b/core/data/deg_kair_utils/utils_alignfaces.py deleted file mode 100644 index fa74e8a2e8984f5075d0cbd06afd494c9661a015..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_alignfaces.py +++ /dev/null @@ -1,263 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Mon Apr 24 15:43:29 2017 -@author: zhaoy -""" -import cv2 -import numpy as np -from skimage import transform as trans - -# reference facial points, a list of coordinates (x,y) -REFERENCE_FACIAL_POINTS = [ - [30.29459953, 51.69630051], - [65.53179932, 51.50139999], - [48.02519989, 71.73660278], - [33.54930115, 92.3655014], - [62.72990036, 92.20410156] -] - -DEFAULT_CROP_SIZE = (96, 112) - - -def _umeyama(src, dst, estimate_scale=True, scale=1.0): - """Estimate N-D similarity transformation with or without scaling. - Parameters - ---------- - src : (M, N) array - Source coordinates. - dst : (M, N) array - Destination coordinates. - estimate_scale : bool - Whether to estimate scaling factor. - Returns - ------- - T : (N + 1, N + 1) - The homogeneous similarity transformation matrix. The matrix contains - NaN values only if the problem is not well-conditioned. - References - ---------- - .. [1] "Least-squares estimation of transformation parameters between two - point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` - """ - - num = src.shape[0] - dim = src.shape[1] - - # Compute mean of src and dst. - src_mean = src.mean(axis=0) - dst_mean = dst.mean(axis=0) - - # Subtract mean from src and dst. - src_demean = src - src_mean - dst_demean = dst - dst_mean - - # Eq. (38). - A = dst_demean.T @ src_demean / num - - # Eq. (39). - d = np.ones((dim,), dtype=np.double) - if np.linalg.det(A) < 0: - d[dim - 1] = -1 - - T = np.eye(dim + 1, dtype=np.double) - - U, S, V = np.linalg.svd(A) - - # Eq. (40) and (43). - rank = np.linalg.matrix_rank(A) - if rank == 0: - return np.nan * T - elif rank == dim - 1: - if np.linalg.det(U) * np.linalg.det(V) > 0: - T[:dim, :dim] = U @ V - else: - s = d[dim - 1] - d[dim - 1] = -1 - T[:dim, :dim] = U @ np.diag(d) @ V - d[dim - 1] = s - else: - T[:dim, :dim] = U @ np.diag(d) @ V - - if estimate_scale: - # Eq. (41) and (42). - scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) - else: - scale = scale - - T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) - T[:dim, :dim] *= scale - - return T, scale - - -class FaceWarpException(Exception): - def __str__(self): - return 'In File {}:{}'.format( - __file__, super.__str__(self)) - - -def get_reference_facial_points(output_size=None, - inner_padding_factor=0.0, - outer_padding=(0, 0), - default_square=False): - tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) - tmp_crop_size = np.array(DEFAULT_CROP_SIZE) - - # 0) make the inner region a square - if default_square: - size_diff = max(tmp_crop_size) - tmp_crop_size - tmp_5pts += size_diff / 2 - tmp_crop_size += size_diff - - if (output_size and - output_size[0] == tmp_crop_size[0] and - output_size[1] == tmp_crop_size[1]): - print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size)) - return tmp_5pts - - if (inner_padding_factor == 0 and - outer_padding == (0, 0)): - if output_size is None: - print('No paddings to do: return default reference points') - return tmp_5pts - else: - raise FaceWarpException( - 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) - - # check output size - if not (0 <= inner_padding_factor <= 1.0): - raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') - - if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) - and output_size is None): - output_size = tmp_crop_size * \ - (1 + inner_padding_factor * 2).astype(np.int32) - output_size += np.array(outer_padding) - print(' deduced from paddings, output_size = ', output_size) - - if not (outer_padding[0] < output_size[0] - and outer_padding[1] < output_size[1]): - raise FaceWarpException('Not (outer_padding[0] < output_size[0]' - 'and outer_padding[1] < output_size[1])') - - # 1) pad the inner region according inner_padding_factor - # print('---> STEP1: pad the inner region according inner_padding_factor') - if inner_padding_factor > 0: - size_diff = tmp_crop_size * inner_padding_factor * 2 - tmp_5pts += size_diff / 2 - tmp_crop_size += np.round(size_diff).astype(np.int32) - - # print(' crop_size = ', tmp_crop_size) - # print(' reference_5pts = ', tmp_5pts) - - # 2) resize the padded inner region - # print('---> STEP2: resize the padded inner region') - size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 - # print(' crop_size = ', tmp_crop_size) - # print(' size_bf_outer_pad = ', size_bf_outer_pad) - - if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: - raise FaceWarpException('Must have (output_size - outer_padding)' - '= some_scale * (crop_size * (1.0 + inner_padding_factor)') - - scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] - # print(' resize scale_factor = ', scale_factor) - tmp_5pts = tmp_5pts * scale_factor - # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) - # tmp_5pts = tmp_5pts + size_diff / 2 - tmp_crop_size = size_bf_outer_pad - # print(' crop_size = ', tmp_crop_size) - # print(' reference_5pts = ', tmp_5pts) - - # 3) add outer_padding to make output_size - reference_5point = tmp_5pts + np.array(outer_padding) - tmp_crop_size = output_size - # print('---> STEP3: add outer_padding to make output_size') - # print(' crop_size = ', tmp_crop_size) - # print(' reference_5pts = ', tmp_5pts) - # - # print('===> end get_reference_facial_points\n') - - return reference_5point - - -def get_affine_transform_matrix(src_pts, dst_pts): - tfm = np.float32([[1, 0, 0], [0, 1, 0]]) - n_pts = src_pts.shape[0] - ones = np.ones((n_pts, 1), src_pts.dtype) - src_pts_ = np.hstack([src_pts, ones]) - dst_pts_ = np.hstack([dst_pts, ones]) - - A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) - - if rank == 3: - tfm = np.float32([ - [A[0, 0], A[1, 0], A[2, 0]], - [A[0, 1], A[1, 1], A[2, 1]] - ]) - elif rank == 2: - tfm = np.float32([ - [A[0, 0], A[1, 0], 0], - [A[0, 1], A[1, 1], 0] - ]) - - return tfm - - -def warp_and_crop_face(src_img, - facial_pts, - reference_pts=None, - crop_size=(96, 112), - align_type='smilarity'): #smilarity cv2_affine affine - if reference_pts is None: - if crop_size[0] == 96 and crop_size[1] == 112: - reference_pts = REFERENCE_FACIAL_POINTS - else: - default_square = False - inner_padding_factor = 0 - outer_padding = (0, 0) - output_size = crop_size - - reference_pts = get_reference_facial_points(output_size, - inner_padding_factor, - outer_padding, - default_square) - - ref_pts = np.float32(reference_pts) - ref_pts_shp = ref_pts.shape - if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: - raise FaceWarpException( - 'reference_pts.shape must be (K,2) or (2,K) and K>2') - - if ref_pts_shp[0] == 2: - ref_pts = ref_pts.T - - src_pts = np.float32(facial_pts) - src_pts_shp = src_pts.shape - if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: - raise FaceWarpException( - 'facial_pts.shape must be (K,2) or (2,K) and K>2') - - if src_pts_shp[0] == 2: - src_pts = src_pts.T - - if src_pts.shape != ref_pts.shape: - raise FaceWarpException( - 'facial_pts and reference_pts must have the same shape') - - if align_type is 'cv2_affine': - tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) - tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) - elif align_type is 'affine': - tfm = get_affine_transform_matrix(src_pts, ref_pts) - tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) - else: - params, scale = _umeyama(src_pts, ref_pts) - tfm = params[:2, :] - - params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale) - tfm_inv = params[:2, :] - - face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3) - - return face_img, tfm_inv diff --git a/core/data/deg_kair_utils/utils_blindsr.py b/core/data/deg_kair_utils/utils_blindsr.py deleted file mode 100644 index 9a1a7baf99473043e216c16f464f4e168cbd94ab..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_blindsr.py +++ /dev/null @@ -1,631 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from core.data.deg_kair_utils import utils_image as util - -import random -from scipy import ndimage -import scipy -import scipy.stats as ss -from scipy.interpolate import interp2d -from scipy.linalg import orth - - - - -""" -# -------------------------------------------- -# Super-Resolution -# -------------------------------------------- -# -# Kai Zhang (cskaizhang@gmail.com) -# https://github.com/cszn -# From 2019/03--2021/08 -# -------------------------------------------- -""" - -def modcrop_np(img, sf): - ''' - Args: - img: numpy image, WxH or WxHxC - sf: scale factor - - Return: - cropped image - ''' - w, h = img.shape[:2] - im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] - - -""" -# -------------------------------------------- -# anisotropic Gaussian kernels -# -------------------------------------------- -""" -def analytic_kernel(k): - """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" - k_size = k.shape[0] - # Calculate the big kernels size - big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) - # Loop over the small kernel to fill the big one - for r in range(k_size): - for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k - # Crop the edges of the big kernel to ignore very small values and increase run time of SR - crop = k_size // 2 - cropped_big_k = big_k[crop:-crop, crop:-crop] - # Normalize to 1 - return cropped_big_k / cropped_big_k.sum() - - -def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel - Args: - ksize : e.g., 15, kernel size - theta : [0, pi], rotation angle range - l1 : [0.1,50], scaling of eigenvalues - l2 : [0.1,l1], scaling of eigenvalues - If l1 = l2, will get an isotropic Gaussian kernel. - - Returns: - k : kernel - """ - - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) - V = np.array([[v[0], v[1]], [v[1], -v[0]]]) - D = np.array([[l1, 0], [0, l2]]) - Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k - - -def gm_blur_kernel(mean, cov, size=15): - center = size / 2.0 + 0.5 - k = np.zeros([size, size]) - for y in range(size): - for x in range(size): - cy = y - center + 1 - cx = x - center + 1 - k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) - - k = k / np.sum(k) - return k - - -def shift_pixel(x, sf, upper_left=True): - """shift pixel for super-resolution with different scale factors - Args: - x: WxHxC or WxH - sf: scale factor - upper_left: shift direction - """ - h, w = x.shape[:2] - shift = (sf-1)*0.5 - xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) - if upper_left: - x1 = xv + shift - y1 = yv + shift - else: - x1 = xv - shift - y1 = yv - shift - - x1 = np.clip(x1, 0, w-1) - y1 = np.clip(y1, 0, h-1) - - if x.ndim == 2: - x = interp2d(xv, yv, x)(x1, y1) - if x.ndim == 3: - for i in range(x.shape[-1]): - x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) - - return x - - -def blur(x, k): - ''' - x: image, NxcxHxW - k: kernel, Nx1xhxw - ''' - n, c = x.shape[:2] - p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') - k = k.repeat(1,c,1,1) - k = k.view(-1, 1, k.shape[2], k.shape[3]) - x = x.view(1, -1, x.shape[2], x.shape[3]) - x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c) - x = x.view(n, c, x.shape[2], x.shape[3]) - - return x - - - -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" - # modified version of https://github.com/assafshocher/BlindSR_dataset_generator - # Kai Zhang - # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var - # max_var = 2.5 * sf - """ - # Set random eigen-vals (lambdas) and angle (theta) for COV matrix - lambda_1 = min_var + np.random.rand() * (max_var - min_var) - lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta - noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 - - # Set COV matrix using Lambdas and Theta - LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) - SIGMA = Q @ LAMBDA @ Q.T - INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] - - # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) - MU = MU[None, None, :, None] - - # Create meshgrid for Gaussian - [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) - Z = np.stack([X, Y], 2)[:, :, :, None] - - # Calcualte Gaussian for every pixel of the kernel - ZZ = Z-MU - ZZ_t = ZZ.transpose(0,1,3,2) - raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - - # shift the kernel so it will be centered - #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel - - -def fspecial_gaussian(hsize, sigma): - hsize = [hsize, hsize] - siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0] - std = sigma - [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1)) - arg = -(x*x + y*y)/(2*std*std) - h = np.exp(arg) - h[h < scipy.finfo(float).eps * h.max()] = 0 - sumh = h.sum() - if sumh != 0: - h = h/sumh - return h - - -def fspecial_laplacian(alpha): - alpha = max([0, min([alpha,1])]) - h1 = alpha/(alpha+1) - h2 = (1-alpha)/(alpha+1) - h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]] - h = np.array(h) - return h - - -def fspecial(filter_type, *args, **kwargs): - ''' - python code from: - https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': - return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': - return fspecial_laplacian(*args, **kwargs) - -""" -# -------------------------------------------- -# degradation models -# -------------------------------------------- -""" - - -def bicubic_degradation(x, sf=3): - ''' - Args: - x: HxWxC image, [0, 1] - sf: down-scale factor - - Return: - bicubicly downsampled LR image - ''' - x = util.imresize_np(x, scale=1/sf) - return x - - -def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling - - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - - Return: - downsampled LR image - - Reference: - @inproceedings{zhang2018learning, - title={Learning a single convolutional super-resolution network for multiple degradations}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={3262--3271}, - year={2018} - } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' - x = bicubic_degradation(x, sf=sf) - return x - - -def dpsr_degradation(x, k, sf=3): - - ''' bicubic downsampling + blur - - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - - Return: - downsampled LR image - - Reference: - @inproceedings{zhang2019deep, - title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={1671--1681}, - year={2019} - } - ''' - x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - return x - - -def classical_degradation(x, k, sf=3): - ''' blur + downsampling - - Args: - x: HxWxC image, [0, 1]/[0, 255] - k: hxw, double - sf: down-scale factor - - Return: - downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) - st = 0 - return x[st::sf, st::sf, ...] - - -def add_sharpening(img, weight=0.5, radius=50, threshold=10): - """USM sharpening. borrowed from real-ESRGAN - Input image: I; Blurry image: B. - 1. K = I + weight * (I - B) - 2. Mask = 1 if abs(I - B) > threshold, else: 0 - 3. Blur mask: - 4. Out = Mask * K + (1 - Mask) * I - Args: - img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. - weight (float): Sharp weight. Default: 1. - radius (float): Kernel size of Gaussian blur. Default: 50. - threshold (int): - """ - if radius % 2 == 0: - radius += 1 - blur = cv2.GaussianBlur(img, (radius, radius), 0) - residual = img - blur - mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') - soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) - - K = img + weight * residual - K = np.clip(K, 0, 1) - return soft_mask * K + (1 - soft_mask) * img - - -def add_blur(img, sf=4): - wd2 = 4.0 + sf - wd = 2.0 + 0.2*sf - if random.random() < 0.5: - l1 = wd2*random.random() - l2 = wd2*random.random() - k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2) - else: - k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') - - return img - - -def add_resize(img, sf=4): - rnum = np.random.rand() - if rnum > 0.8: # up - sf1 = random.uniform(1, 2) - elif rnum < 0.7: # down - sf1 = random.uniform(0.5/sf, 1) - else: - sf1 = 1.0 - img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - return img - - -def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - rnum = np.random.rand() - if rnum > 0.6: # add color Gaussian noise - img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32) - elif rnum < 0.4: # add grayscale Gaussian noise - img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32) - else: # add noise - L = noise_level2/255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3,3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_speckle_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - img = np.clip(img, 0.0, 1.0) - rnum = random.random() - if rnum > 0.6: - img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32) - elif rnum < 0.4: - img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32) - else: - L = noise_level2/255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3,3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. - vals = 10**(2*random.random()+2.0) # [2, 4] - if random.random() < 0.5: - img = np.random.poisson(img * vals).astype(np.float32) / vals - else: - img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. - noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray - img += noise_gray[:, :, np.newaxis] - img = np.clip(img, 0.0, 1.0) - return img - - -def add_JPEG_noise(img): - quality_factor = random.randint(30, 95) - img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) - img = cv2.imdecode(encimg, 1) - img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) - return img - - -def random_crop(lq, hq, sf=4, lq_patchsize=64): - h, w = lq.shape[:2] - rnd_h = random.randint(0, h-lq_patchsize) - rnd_w = random.randint(0, w-lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] - - rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :] - return lq, hq - - -def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) - sf: scale factor - isp_model: camera ISP model - - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = img.shape[:2] - img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop - h, w = img.shape[:2] - - if h < lq_patchsize*sf or w < lq_patchsize*sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') - - hq = img.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3])) - else: - img = util.imresize_np(img, 1/2, True) - img = np.clip(img, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: - img = add_blur(img, sf=sf) - - elif i == 2: - a, b = img.shape[1], img.shape[0] - # downsample2 - if random.random() < 0.75: - sf1 = random.uniform(1,2*sf) - img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') - img = img[0::sf, 0::sf, ...] # nearest downsampling - img = np.clip(img, 0.0, 1.0) - - elif i == 3: - # downsample3 - img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3])) - img = np.clip(img, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - img = add_JPEG_noise(img) - - elif i == 6: - # add processed camera sensor noise - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - img = add_JPEG_noise(img) - - # random crop - img, hq = random_crop(img, hq, sf_ori, lq_patchsize) - - return img, hq - - - - -def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=False, lq_patchsize=64, isp_model=None): - """ - This is an extended degradation model by combining - the degradation models of BSRGAN and Real-ESRGAN - ---------- - img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) - sf: scale factor - use_shuffle: the degradation shuffle - use_sharp: sharpening the img - - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - - h1, w1 = img.shape[:2] - img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop - h, w = img.shape[:2] - - if h < lq_patchsize*sf or w < lq_patchsize*sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') - - if use_sharp: - img = add_sharpening(img) - hq = img.copy() - - if random.random() < shuffle_prob: - shuffle_order = random.sample(range(13), 13) - else: - shuffle_order = list(range(13)) - # local shuffle for noise, JPEG is always the last one - shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) - shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) - - poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 - - for i in shuffle_order: - if i == 0: - img = add_blur(img, sf=sf) - elif i == 1: - img = add_resize(img, sf=sf) - elif i == 2: - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 3: - if random.random() < poisson_prob: - img = add_Poisson_noise(img) - elif i == 4: - if random.random() < speckle_prob: - img = add_speckle_noise(img) - elif i == 5: - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - elif i == 6: - img = add_JPEG_noise(img) - elif i == 7: - img = add_blur(img, sf=sf) - elif i == 8: - img = add_resize(img, sf=sf) - elif i == 9: - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 10: - if random.random() < poisson_prob: - img = add_Poisson_noise(img) - elif i == 11: - if random.random() < speckle_prob: - img = add_speckle_noise(img) - elif i == 12: - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - else: - print('check the shuffle!') - - # resize to desired size - img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3])) - - # add final JPEG compression noise - img = add_JPEG_noise(img) - - # random crop - img, hq = random_crop(img, hq, sf, lq_patchsize) - - return img, hq - - - -if __name__ == '__main__': - img = util.imread_uint('utils/test.png', 3) - img = util.uint2single(img) - sf = 4 - - for i in range(20): - img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72) - print(i) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0) - img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i)+'.png') - -# for i in range(10): -# img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64) -# print(i) -# lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0) -# img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1) -# util.imsave(img_concat, str(i)+'.png') - -# run utils/utils_blindsr.py diff --git a/core/data/deg_kair_utils/utils_bnorm.py b/core/data/deg_kair_utils/utils_bnorm.py deleted file mode 100644 index 9bd346e05b66efd074f81f1961068e2de45ac5da..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_bnorm.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.nn as nn - - -""" -# -------------------------------------------- -# Batch Normalization -# -------------------------------------------- - -# Kai Zhang (cskaizhang@gmail.com) -# https://github.com/cszn -# 01/Jan/2019 -# -------------------------------------------- -""" - - -# -------------------------------------------- -# remove/delete specified layer -# -------------------------------------------- -def deleteLayer(model, layer_type=nn.BatchNorm2d): - ''' Kai Zhang, 11/Jan/2019. - ''' - for k, m in list(model.named_children()): - if isinstance(m, layer_type): - del model._modules[k] - deleteLayer(m, layer_type) - - -# -------------------------------------------- -# merge bn, "conv+bn" --> "conv" -# -------------------------------------------- -def merge_bn(model): - ''' Kai Zhang, 11/Jan/2019. - merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv') - based on https://github.com/pytorch/pytorch/pull/901 - ''' - prev_m = None - for k, m in list(model.named_children()): - if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)): - - w = prev_m.weight.data - - if prev_m.bias is None: - zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type()) - prev_m.bias = nn.Parameter(zeros) - b = prev_m.bias.data - - invstd = m.running_var.clone().add_(m.eps).pow_(-0.5) - if isinstance(prev_m, nn.ConvTranspose2d): - w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w)) - else: - w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) - b.add_(-m.running_mean).mul_(invstd) - if m.affine: - if isinstance(prev_m, nn.ConvTranspose2d): - w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w)) - else: - w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) - b.mul_(m.weight.data).add_(m.bias.data) - - del model._modules[k] - prev_m = m - merge_bn(m) - - -# -------------------------------------------- -# add bn, "conv" --> "conv+bn" -# -------------------------------------------- -def add_bn(model): - ''' Kai Zhang, 11/Jan/2019. - ''' - for k, m in list(model.named_children()): - if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)): - b = nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True) - b.weight.data.fill_(1) - new_m = nn.Sequential(model._modules[k], b) - model._modules[k] = new_m - add_bn(m) - - -# -------------------------------------------- -# tidy model after removing bn -# -------------------------------------------- -def tidy_sequential(model): - ''' Kai Zhang, 11/Jan/2019. - ''' - for k, m in list(model.named_children()): - if isinstance(m, nn.Sequential): - if m.__len__() == 1: - model._modules[k] = m.__getitem__(0) - tidy_sequential(m) diff --git a/core/data/deg_kair_utils/utils_deblur.py b/core/data/deg_kair_utils/utils_deblur.py deleted file mode 100644 index 8ab5852d0cb334627abcd9476409d632740be389..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_deblur.py +++ /dev/null @@ -1,655 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import scipy -from scipy import fftpack -import torch - -from math import cos, sin -from numpy import zeros, ones, prod, array, pi, log, min, mod, arange, sum, mgrid, exp, pad, round -from numpy.random import randn, rand -from scipy.signal import convolve2d -import cv2 -import random -# import utils_image as util - -''' -modified by Kai Zhang (github: https://github.com/cszn) -03/03/2019 -''' - - -def get_uperleft_denominator(img, kernel): - ''' - img: HxWxC - kernel: hxw - denominator: HxWx1 - upperleft: HxWxC - ''' - V = psf2otf(kernel, img.shape[:2]) - denominator = np.expand_dims(np.abs(V)**2, axis=2) - upperleft = np.expand_dims(np.conj(V), axis=2) * np.fft.fft2(img, axes=[0, 1]) - return upperleft, denominator - - -def get_uperleft_denominator_pytorch(img, kernel): - ''' - img: NxCxHxW - kernel: Nx1xhxw - denominator: Nx1xHxW - upperleft: NxCxHxWx2 - ''' - V = p2o(kernel, img.shape[-2:]) # Nx1xHxWx2 - denominator = V[..., 0]**2+V[..., 1]**2 # Nx1xHxW - upperleft = cmul(cconj(V), rfft(img)) # Nx1xHxWx2 * NxCxHxWx2 - return upperleft, denominator - - -def c2c(x): - return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1)) - - -def r2c(x): - return torch.stack([x, torch.zeros_like(x)], -1) - - -def cdiv(x, y): - a, b = x[..., 0], x[..., 1] - c, d = y[..., 0], y[..., 1] - cd2 = c**2 + d**2 - return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1) - - -def cabs(x): - return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5) - - -def cmul(t1, t2): - ''' - complex multiplication - t1: NxCxHxWx2 - output: NxCxHxWx2 - ''' - real1, imag1 = t1[..., 0], t1[..., 1] - real2, imag2 = t2[..., 0], t2[..., 1] - return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1) - - -def cconj(t, inplace=False): - ''' - # complex's conjugation - t: NxCxHxWx2 - output: NxCxHxWx2 - ''' - c = t.clone() if not inplace else t - c[..., 1] *= -1 - return c - - -def rfft(t): - return torch.rfft(t, 2, onesided=False) - - -def irfft(t): - return torch.irfft(t, 2, onesided=False) - - -def fft(t): - return torch.fft(t, 2) - - -def ifft(t): - return torch.ifft(t, 2) - - -def p2o(psf, shape): - ''' - # psf: NxCxhxw - # shape: [H,W] - # otf: NxCxHxWx2 - ''' - otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) - otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) - for axis, axis_size in enumerate(psf.shape[2:]): - otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) - otf = torch.rfft(otf, 2, onesided=False) - n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) - otf[...,1][torch.abs(otf[...,1])= abs(y)] = abs(x)[abs(x) >= abs(y)] - maxxy[abs(y) >= abs(x)] = abs(y)[abs(y) >= abs(x)] - minxy = np.zeros(x.shape) - minxy[abs(x) <= abs(y)] = abs(x)[abs(x) <= abs(y)] - minxy[abs(y) <= abs(x)] = abs(y)[abs(y) <= abs(x)] - m1 = (rad**2 < (maxxy+0.5)**2 + (minxy-0.5)**2)*(minxy-0.5) +\ - (rad**2 >= (maxxy+0.5)**2 + (minxy-0.5)**2)*\ - np.sqrt((rad**2 + 0j) - (maxxy + 0.5)**2) - m2 = (rad**2 > (maxxy-0.5)**2 + (minxy+0.5)**2)*(minxy+0.5) +\ - (rad**2 <= (maxxy-0.5)**2 + (minxy+0.5)**2)*\ - np.sqrt((rad**2 + 0j) - (maxxy - 0.5)**2) - h = None - return h - - -def fspecial_gaussian(hsize, sigma): - hsize = [hsize, hsize] - siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0] - std = sigma - [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1)) - arg = -(x*x + y*y)/(2*std*std) - h = np.exp(arg) - h[h < scipy.finfo(float).eps * h.max()] = 0 - sumh = h.sum() - if sumh != 0: - h = h/sumh - return h - - -def fspecial_laplacian(alpha): - alpha = max([0, min([alpha,1])]) - h1 = alpha/(alpha+1) - h2 = (1-alpha)/(alpha+1) - h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]] - h = np.array(h) - return h - - -def fspecial_log(hsize, sigma): - raise(NotImplemented) - - -def fspecial_motion(motion_len, theta): - raise(NotImplemented) - - -def fspecial_prewitt(): - return np.array([[1, 1, 1], [0, 0, 0], [-1, -1, -1]]) - - -def fspecial_sobel(): - return np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) - - -def fspecial(filter_type, *args, **kwargs): - ''' - python code from: - https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'average': - return fspecial_average(*args, **kwargs) - if filter_type == 'disk': - return fspecial_disk(*args, **kwargs) - if filter_type == 'gaussian': - return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': - return fspecial_laplacian(*args, **kwargs) - if filter_type == 'log': - return fspecial_log(*args, **kwargs) - if filter_type == 'motion': - return fspecial_motion(*args, **kwargs) - if filter_type == 'prewitt': - return fspecial_prewitt(*args, **kwargs) - if filter_type == 'sobel': - return fspecial_sobel(*args, **kwargs) - - -def fspecial_gauss(size, sigma): - x, y = mgrid[-size // 2 + 1 : size // 2 + 1, -size // 2 + 1 : size // 2 + 1] - g = exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) - return g / g.sum() - - -def blurkernel_synthesis(h=37, w=None): - # https://github.com/tkkcc/prior/blob/879a0b6c117c810776d8cc6b63720bf29f7d0cc4/util/gen_kernel.py - w = h if w is None else w - kdims = [h, w] - x = randomTrajectory(250) - k = None - while k is None: - k = kernelFromTrajectory(x) - - # center pad to kdims - pad_width = ((kdims[0] - k.shape[0]) // 2, (kdims[1] - k.shape[1]) // 2) - pad_width = [(pad_width[0],), (pad_width[1],)] - - if pad_width[0][0]<0 or pad_width[1][0]<0: - k = k[0:h, 0:h] - else: - k = pad(k, pad_width, "constant") - x1,x2 = k.shape - if np.random.randint(0, 4) == 1: - k = cv2.resize(k, (random.randint(x1, 5*x1), random.randint(x2, 5*x2)), interpolation=cv2.INTER_LINEAR) - y1, y2 = k.shape - k = k[(y1-x1)//2: (y1-x1)//2+x1, (y2-x2)//2: (y2-x2)//2+x2] - - if sum(k)<0.1: - k = fspecial_gaussian(h, 0.1+6*np.random.rand(1)) - k = k / sum(k) - # import matplotlib.pyplot as plt - # plt.imshow(k, interpolation="nearest", cmap="gray") - # plt.show() - return k - - -def kernelFromTrajectory(x): - h = 5 - log(rand()) / 0.15 - h = round(min([h, 27])).astype(int) - h = h + 1 - h % 2 - w = h - k = zeros((h, w)) - - xmin = min(x[0]) - xmax = max(x[0]) - ymin = min(x[1]) - ymax = max(x[1]) - xthr = arange(xmin, xmax, (xmax - xmin) / w) - ythr = arange(ymin, ymax, (ymax - ymin) / h) - - for i in range(1, xthr.size): - for j in range(1, ythr.size): - idx = ( - (x[0, :] >= xthr[i - 1]) - & (x[0, :] < xthr[i]) - & (x[1, :] >= ythr[j - 1]) - & (x[1, :] < ythr[j]) - ) - k[i - 1, j - 1] = sum(idx) - if sum(k) == 0: - return - k = k / sum(k) - k = convolve2d(k, fspecial_gauss(3, 1), "same") - k = k / sum(k) - return k - - -def randomTrajectory(T): - x = zeros((3, T)) - v = randn(3, T) - r = zeros((3, T)) - trv = 1 / 1 - trr = 2 * pi / T - for t in range(1, T): - F_rot = randn(3) / (t + 1) + r[:, t - 1] - F_trans = randn(3) / (t + 1) - r[:, t] = r[:, t - 1] + trr * F_rot - v[:, t] = v[:, t - 1] + trv * F_trans - st = v[:, t] - st = rot3D(st, r[:, t]) - x[:, t] = x[:, t - 1] + st - return x - - -def rot3D(x, r): - Rx = array([[1, 0, 0], [0, cos(r[0]), -sin(r[0])], [0, sin(r[0]), cos(r[0])]]) - Ry = array([[cos(r[1]), 0, sin(r[1])], [0, 1, 0], [-sin(r[1]), 0, cos(r[1])]]) - Rz = array([[cos(r[2]), -sin(r[2]), 0], [sin(r[2]), cos(r[2]), 0], [0, 0, 1]]) - R = Rz @ Ry @ Rx - x = R @ x - return x - - -if __name__ == '__main__': - a = opt_fft_size([111]) - print(a) - - print(fspecial('gaussian', 5, 1)) - - print(p2o(torch.zeros(1,1,4,4).float(),(14,14)).shape) - - k = blurkernel_synthesis(11) - import matplotlib.pyplot as plt - plt.imshow(k, interpolation="nearest", cmap="gray") - plt.show() diff --git a/core/data/deg_kair_utils/utils_dist.py b/core/data/deg_kair_utils/utils_dist.py deleted file mode 100644 index 88811737a8fc7cb6e12d9226a9242dbf8391d86b..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_dist.py +++ /dev/null @@ -1,201 +0,0 @@ -# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 -import functools -import os -import subprocess -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - - -# ---------------------------------- -# init -# ---------------------------------- -def init_dist(launcher, backend='nccl', **kwargs): - if mp.get_start_method(allow_none=True) is None: - mp.set_start_method('spawn') - if launcher == 'pytorch': - _init_dist_pytorch(backend, **kwargs) - elif launcher == 'slurm': - _init_dist_slurm(backend, **kwargs) - else: - raise ValueError(f'Invalid launcher type: {launcher}') - - -def _init_dist_pytorch(backend, **kwargs): - rank = int(os.environ['RANK']) - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(rank % num_gpus) - dist.init_process_group(backend=backend, **kwargs) - - -def _init_dist_slurm(backend, port=None): - """Initialize slurm distributed training environment. - If argument ``port`` is not specified, then the master port will be system - environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system - environment variable, then a default port ``29500`` will be used. - Args: - backend (str): Backend of torch.distributed. - port (int, optional): Master port. Defaults to None. - """ - proc_id = int(os.environ['SLURM_PROCID']) - ntasks = int(os.environ['SLURM_NTASKS']) - node_list = os.environ['SLURM_NODELIST'] - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(proc_id % num_gpus) - addr = subprocess.getoutput( - f'scontrol show hostname {node_list} | head -n1') - # specify master port - if port is not None: - os.environ['MASTER_PORT'] = str(port) - elif 'MASTER_PORT' in os.environ: - pass # use MASTER_PORT in the environment variable - else: - # 29500 is torch.distributed default port - os.environ['MASTER_PORT'] = '29500' - os.environ['MASTER_ADDR'] = addr - os.environ['WORLD_SIZE'] = str(ntasks) - os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) - os.environ['RANK'] = str(proc_id) - dist.init_process_group(backend=backend) - - - -# ---------------------------------- -# get rank and world_size -# ---------------------------------- -def get_dist_info(): - if dist.is_available(): - initialized = dist.is_initialized() - else: - initialized = False - if initialized: - rank = dist.get_rank() - world_size = dist.get_world_size() - else: - rank = 0 - world_size = 1 - return rank, world_size - - -def get_rank(): - if not dist.is_available(): - return 0 - - if not dist.is_initialized(): - return 0 - - return dist.get_rank() - - -def get_world_size(): - if not dist.is_available(): - return 1 - - if not dist.is_initialized(): - return 1 - - return dist.get_world_size() - - -def master_only(func): - - @functools.wraps(func) - def wrapper(*args, **kwargs): - rank, _ = get_dist_info() - if rank == 0: - return func(*args, **kwargs) - - return wrapper - - - - - - -# ---------------------------------- -# operation across ranks -# ---------------------------------- -def reduce_sum(tensor): - if not dist.is_available(): - return tensor - - if not dist.is_initialized(): - return tensor - - tensor = tensor.clone() - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - - return tensor - - -def gather_grad(params): - world_size = get_world_size() - - if world_size == 1: - return - - for param in params: - if param.grad is not None: - dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) - param.grad.data.div_(world_size) - - -def all_gather(data): - world_size = get_world_size() - - if world_size == 1: - return [data] - - buffer = pickle.dumps(data) - storage = torch.ByteStorage.from_buffer(buffer) - tensor = torch.ByteTensor(storage).to('cuda') - - local_size = torch.IntTensor([tensor.numel()]).to('cuda') - size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] - dist.all_gather(size_list, local_size) - size_list = [int(size.item()) for size in size_list] - max_size = max(size_list) - - tensor_list = [] - for _ in size_list: - tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) - - if local_size != max_size: - padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') - tensor = torch.cat((tensor, padding), 0) - - dist.all_gather(tensor_list, tensor) - - data_list = [] - - for size, tensor in zip(size_list, tensor_list): - buffer = tensor.cpu().numpy().tobytes()[:size] - data_list.append(pickle.loads(buffer)) - - return data_list - - -def reduce_loss_dict(loss_dict): - world_size = get_world_size() - - if world_size < 2: - return loss_dict - - with torch.no_grad(): - keys = [] - losses = [] - - for k in sorted(loss_dict.keys()): - keys.append(k) - losses.append(loss_dict[k]) - - losses = torch.stack(losses, 0) - dist.reduce(losses, dst=0) - - if dist.get_rank() == 0: - losses /= world_size - - reduced_losses = {k: v for k, v in zip(keys, losses)} - - return reduced_losses - diff --git a/core/data/deg_kair_utils/utils_googledownload.py b/core/data/deg_kair_utils/utils_googledownload.py deleted file mode 100644 index 25533d4e0d90bac7519874a654ffd833d16ae289..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_googledownload.py +++ /dev/null @@ -1,93 +0,0 @@ -import math -import requests -from tqdm import tqdm - - -''' -borrowed from -https://github.com/xinntao/BasicSR/blob/28883e15eedc3381d23235ff3cf7c454c4be87e6/basicsr/utils/download_util.py -''' - - -def sizeof_fmt(size, suffix='B'): - """Get human readable file size. - Args: - size (int): File size. - suffix (str): Suffix. Default: 'B'. - Return: - str: Formated file siz. - """ - for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: - if abs(size) < 1024.0: - return f'{size:3.1f} {unit}{suffix}' - size /= 1024.0 - return f'{size:3.1f} Y{suffix}' - - -def download_file_from_google_drive(file_id, save_path): - """Download files from google drive. - Ref: - https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 - Args: - file_id (str): File id. - save_path (str): Save path. - """ - - session = requests.Session() - URL = 'https://docs.google.com/uc?export=download' - params = {'id': file_id} - - response = session.get(URL, params=params, stream=True) - token = get_confirm_token(response) - if token: - params['confirm'] = token - response = session.get(URL, params=params, stream=True) - - # get file size - response_file_size = session.get( - URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) - if 'Content-Range' in response_file_size.headers: - file_size = int( - response_file_size.headers['Content-Range'].split('/')[1]) - else: - file_size = None - - save_response_content(response, save_path, file_size) - - -def get_confirm_token(response): - for key, value in response.cookies.items(): - if key.startswith('download_warning'): - return value - return None - - -def save_response_content(response, - destination, - file_size=None, - chunk_size=32768): - if file_size is not None: - pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') - - readable_file_size = sizeof_fmt(file_size) - else: - pbar = None - - with open(destination, 'wb') as f: - downloaded_size = 0 - for chunk in response.iter_content(chunk_size): - downloaded_size += chunk_size - if pbar is not None: - pbar.update(1) - pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' - f'/ {readable_file_size}') - if chunk: # filter out keep-alive new chunks - f.write(chunk) - if pbar is not None: - pbar.close() - - -if __name__ == "__main__": - file_id = '1WNULM1e8gRNvsngVscsQ8tpaOqJ4mYtv' - save_path = 'BSRGAN.pth' - download_file_from_google_drive(file_id, save_path) diff --git a/core/data/deg_kair_utils/utils_image.py b/core/data/deg_kair_utils/utils_image.py deleted file mode 100644 index 0e513a8bc1594c9ce2ba47ce3fe3b497269b7f16..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_image.py +++ /dev/null @@ -1,1016 +0,0 @@ -import os -import math -import random -import numpy as np -import torch -import cv2 -from torchvision.utils import make_grid -from datetime import datetime -# import torchvision.transforms as transforms -import matplotlib.pyplot as plt -from mpl_toolkits.mplot3d import Axes3D -os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" - - -''' -# -------------------------------------------- -# Kai Zhang (github: https://github.com/cszn) -# 03/Mar/2019 -# -------------------------------------------- -# https://github.com/twhui/SRGAN-pyTorch -# https://github.com/xinntao/BasicSR -# -------------------------------------------- -''' - - -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def get_timestamp(): - return datetime.now().strftime('%y%m%d-%H%M%S') - - -def imshow(x, title=None, cbar=False, figsize=None): - plt.figure(figsize=figsize) - plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') - if title: - plt.title(title) - if cbar: - plt.colorbar() - plt.show() - - -def surf(Z, cmap='rainbow', figsize=None): - plt.figure(figsize=figsize) - ax3 = plt.axes(projection='3d') - - w, h = Z.shape[:2] - xx = np.arange(0,w,1) - yy = np.arange(0,h,1) - X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) - #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) - plt.show() - - -''' -# -------------------------------------------- -# get image pathes -# -------------------------------------------- -''' - - -def get_image_paths(dataroot): - paths = None # return None if dataroot is None - if isinstance(dataroot, str): - paths = sorted(_get_paths_from_images(dataroot)) - elif isinstance(dataroot, list): - paths = [] - for i in dataroot: - paths += sorted(_get_paths_from_images(i)) - return paths - - -def _get_paths_from_images(path): - assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) - images = [] - for dirpath, _, fnames in sorted(os.walk(path)): - for fname in sorted(fnames): - if is_image_file(fname): - img_path = os.path.join(dirpath, fname) - images.append(img_path) - assert images, '{:s} has no valid image file'.format(path) - return images - - -''' -# -------------------------------------------- -# split large images into small images -# -------------------------------------------- -''' - - -def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): - w, h = img.shape[:2] - patches = [] - if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) - # print(w1) - # print(h1) - for i in w1: - for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) - else: - patches.append(img) - - return patches - - -def imssave(imgs, img_path): - """ - imgs: list, N images of size WxHxC - """ - img_name, ext = os.path.splitext(os.path.basename(img_path)) - for i, img in enumerate(imgs): - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_{:04d}'.format(i))+'.png') - cv2.imwrite(new_path, img) - - -def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800): - """ - split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), - and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) - will be splitted. - - Args: - original_dataroot: - taget_dataroot: - p_size: size of small images - p_overlap: patch size in training is a good choice - p_max: images with smaller size than (p_max)x(p_max) keep unchanged. - """ - paths = get_image_paths(original_dataroot) - for img_path in paths: - # img_name, ext = os.path.splitext(os.path.basename(img_path)) - img = imread_uint(img_path, n_channels=n_channels) - patches = patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) - #if original_dataroot == taget_dataroot: - #del img_path - -''' -# -------------------------------------------- -# makedir -# -------------------------------------------- -''' - - -def mkdir(path): - if not os.path.exists(path): - os.makedirs(path) - - -def mkdirs(paths): - if isinstance(paths, str): - mkdir(paths) - else: - for path in paths: - mkdir(path) - - -def mkdir_and_rename(path): - if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() - print('Path already exists. Rename it to [{:s}]'.format(new_name)) - os.rename(path, new_name) - os.makedirs(path) - - -''' -# -------------------------------------------- -# read image from path -# opencv is fast, but read BGR numpy image -# -------------------------------------------- -''' - - -# -------------------------------------------- -# get uint8 image of size HxWxn_channles (RGB) -# -------------------------------------------- -def imread_uint(path, n_channels=3): - # input: path - # output: HxWx3(RGB or GGG), or HxWx1 (G) - if n_channels == 1: - img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE - img = np.expand_dims(img, axis=2) # HxWx1 - elif n_channels == 3: - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G - if img.ndim == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG - else: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB - return img - - -# -------------------------------------------- -# matlab's imwrite -# -------------------------------------------- -def imsave(img, img_path): - img = np.squeeze(img) - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - cv2.imwrite(img_path, img) - -def imwrite(img, img_path): - img = np.squeeze(img) - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - cv2.imwrite(img_path, img) - - - -# -------------------------------------------- -# get single image of size HxWxn_channles (BGR) -# -------------------------------------------- -def read_img(path): - # read image by cv2 - # return: Numpy float32, HWC, BGR, [0,1] - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE - img = img.astype(np.float32) / 255. - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - # some images have 4 channels - if img.shape[2] > 3: - img = img[:, :, :3] - return img - - -''' -# -------------------------------------------- -# image format conversion -# -------------------------------------------- -# numpy(single) <---> numpy(uint) -# numpy(single) <---> tensor -# numpy(uint) <---> tensor -# -------------------------------------------- -''' - - -# -------------------------------------------- -# numpy(single) [0, 1] <---> numpy(uint) -# -------------------------------------------- - - -def uint2single(img): - - return np.float32(img/255.) - - -def single2uint(img): - - return np.uint8((img.clip(0, 1)*255.).round()) - - -def uint162single(img): - - return np.float32(img/65535.) - - -def single2uint16(img): - - return np.uint16((img.clip(0, 1)*65535.).round()) - - -# -------------------------------------------- -# numpy(uint) (HxWxC or HxW) <---> tensor -# -------------------------------------------- - - -# convert uint to 4-dimensional torch tensor -def uint2tensor4(img): - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) - - -# convert uint to 3-dimensional torch tensor -def uint2tensor3(img): - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) - - -# convert 2/3/4-dimensional torch tensor to uint -def tensor2uint(img): - img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) - - -# -------------------------------------------- -# numpy(single) (HxWxC) <---> tensor -# -------------------------------------------- - - -# convert single (HxWxC) to 3-dimensional torch tensor -def single2tensor3(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() - - -# convert single (HxWxC) to 4-dimensional torch tensor -def single2tensor4(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) - - -# convert torch tensor to single -def tensor2single(img): - img = img.data.squeeze().float().cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - - return img - -# convert torch tensor to single -def tensor2single3(img): - img = img.data.squeeze().float().cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - elif img.ndim == 2: - img = np.expand_dims(img, axis=2) - return img - - -def single2tensor5(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) - - -def single32tensor5(img): - return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) - - -def single42tensor4(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() - - -# from skimage.io import imread, imsave -def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - ''' - Converts a torch Tensor into an image Numpy array of BGR channel order - Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order - Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - ''' - tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp - tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] - n_dim = tensor.dim() - if n_dim == 4: - n_img = len(tensor) - img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR - elif n_dim == 3: - img_np = tensor.numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR - elif n_dim == 2: - img_np = tensor.numpy() - else: - raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) - if out_type == np.uint8: - img_np = (img_np * 255.0).round() - # Important. Unlike matlab, numpy.uint8() WILL NOT round by default. - return img_np.astype(out_type) - - -''' -# -------------------------------------------- -# Augmentation, flipe and/or rotate -# -------------------------------------------- -# The following two are enough. -# (1) augmet_img: numpy image of WxHxC or WxH -# (2) augment_img_tensor4: tensor image 1xCxWxH -# -------------------------------------------- -''' - - -def augment_img(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - if mode == 0: - return img - elif mode == 1: - return np.flipud(np.rot90(img)) - elif mode == 2: - return np.flipud(img) - elif mode == 3: - return np.rot90(img, k=3) - elif mode == 4: - return np.flipud(np.rot90(img, k=2)) - elif mode == 5: - return np.rot90(img) - elif mode == 6: - return np.rot90(img, k=2) - elif mode == 7: - return np.flipud(np.rot90(img, k=3)) - - -def augment_img_tensor4(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - if mode == 0: - return img - elif mode == 1: - return img.rot90(1, [2, 3]).flip([2]) - elif mode == 2: - return img.flip([2]) - elif mode == 3: - return img.rot90(3, [2, 3]) - elif mode == 4: - return img.rot90(2, [2, 3]).flip([2]) - elif mode == 5: - return img.rot90(1, [2, 3]) - elif mode == 6: - return img.rot90(2, [2, 3]) - elif mode == 7: - return img.rot90(3, [2, 3]).flip([2]) - - -def augment_img_tensor(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - img_size = img.size() - img_np = img.data.cpu().numpy() - if len(img_size) == 3: - img_np = np.transpose(img_np, (1, 2, 0)) - elif len(img_size) == 4: - img_np = np.transpose(img_np, (2, 3, 1, 0)) - img_np = augment_img(img_np, mode=mode) - img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) - if len(img_size) == 3: - img_tensor = img_tensor.permute(2, 0, 1) - elif len(img_size) == 4: - img_tensor = img_tensor.permute(3, 2, 0, 1) - - return img_tensor.type_as(img) - - -def augment_img_np3(img, mode=0): - if mode == 0: - return img - elif mode == 1: - return img.transpose(1, 0, 2) - elif mode == 2: - return img[::-1, :, :] - elif mode == 3: - img = img[::-1, :, :] - img = img.transpose(1, 0, 2) - return img - elif mode == 4: - return img[:, ::-1, :] - elif mode == 5: - img = img[:, ::-1, :] - img = img.transpose(1, 0, 2) - return img - elif mode == 6: - img = img[:, ::-1, :] - img = img[::-1, :, :] - return img - elif mode == 7: - img = img[:, ::-1, :] - img = img[::-1, :, :] - img = img.transpose(1, 0, 2) - return img - - -def augment_imgs(img_list, hflip=True, rot=True): - # horizontal flip OR rotate - hflip = hflip and random.random() < 0.5 - vflip = rot and random.random() < 0.5 - rot90 = rot and random.random() < 0.5 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - return img - - return [_augment(img) for img in img_list] - - -''' -# -------------------------------------------- -# modcrop and shave -# -------------------------------------------- -''' - - -def modcrop(img_in, scale): - # img_in: Numpy, HWC or HW - img = np.copy(img_in) - if img.ndim == 2: - H, W = img.shape - H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r] - elif img.ndim == 3: - H, W, C = img.shape - H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r, :] - else: - raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) - return img - - -def shave(img_in, border=0): - # img_in: Numpy, HWC or HW - img = np.copy(img_in) - h, w = img.shape[:2] - img = img[border:h-border, border:w-border] - return img - - -''' -# -------------------------------------------- -# image processing process on numpy image -# channel_convert(in_c, tar_type, img_list): -# rgb2ycbcr(img, only_y=True): -# bgr2ycbcr(img, only_y=True): -# ycbcr2rgb(img): -# -------------------------------------------- -''' - - -def rgb2ycbcr(img, only_y=True): - '''same as matlab rgb2ycbcr - only_y: only return Y channel - Input: - uint8, [0, 255] - float, [0, 1] - ''' - in_img_type = img.dtype - img.astype(np.float32) - if in_img_type != np.uint8: - img *= 255. - # convert - if only_y: - rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 - else: - rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] - if in_img_type == np.uint8: - rlt = rlt.round() - else: - rlt /= 255. - return rlt.astype(in_img_type) - - -def ycbcr2rgb(img): - '''same as matlab ycbcr2rgb - Input: - uint8, [0, 255] - float, [0, 1] - ''' - in_img_type = img.dtype - img.astype(np.float32) - if in_img_type != np.uint8: - img *= 255. - # convert - rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], - [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] - rlt = np.clip(rlt, 0, 255) - if in_img_type == np.uint8: - rlt = rlt.round() - else: - rlt /= 255. - return rlt.astype(in_img_type) - - -def bgr2ycbcr(img, only_y=True): - '''bgr version of rgb2ycbcr - only_y: only return Y channel - Input: - uint8, [0, 255] - float, [0, 1] - ''' - in_img_type = img.dtype - img.astype(np.float32) - if in_img_type != np.uint8: - img *= 255. - # convert - if only_y: - rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 - else: - rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] - if in_img_type == np.uint8: - rlt = rlt.round() - else: - rlt /= 255. - return rlt.astype(in_img_type) - - -def channel_convert(in_c, tar_type, img_list): - # conversion among BGR, gray and y - if in_c == 3 and tar_type == 'gray': # BGR to gray - gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] - return [np.expand_dims(img, axis=2) for img in gray_list] - elif in_c == 3 and tar_type == 'y': # BGR to y - y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] - return [np.expand_dims(img, axis=2) for img in y_list] - elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR - return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] - else: - return img_list - - -''' -# -------------------------------------------- -# metric, PSNR, SSIM and PSNRB -# -------------------------------------------- -''' - - -# -------------------------------------------- -# PSNR -# -------------------------------------------- -def calculate_psnr(img1, img2, border=0): - # img1 and img2 have range [0, 255] - #img1 = img1.squeeze() - #img2 = img2.squeeze() - if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') - h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] - - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - mse = np.mean((img1 - img2)**2) - if mse == 0: - return float('inf') - return 20 * math.log10(255.0 / math.sqrt(mse)) - - -# -------------------------------------------- -# SSIM -# -------------------------------------------- -def calculate_ssim(img1, img2, border=0): - '''calculate SSIM - the same outputs as MATLAB's - img1, img2: [0, 255] - ''' - #img1 = img1.squeeze() - #img2 = img2.squeeze() - if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') - h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] - - if img1.ndim == 2: - return ssim(img1, img2) - elif img1.ndim == 3: - if img1.shape[2] == 3: - ssims = [] - for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) - return np.array(ssims).mean() - elif img1.shape[2] == 1: - return ssim(np.squeeze(img1), np.squeeze(img2)) - else: - raise ValueError('Wrong input image dimensions.') - - -def ssim(img1, img2): - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 - - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - kernel = cv2.getGaussianKernel(11, 1.5) - window = np.outer(kernel, kernel.transpose()) - - mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid - mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] - mu1_sq = mu1**2 - mu2_sq = mu2**2 - mu1_mu2 = mu1 * mu2 - sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq - sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq - sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) - return ssim_map.mean() - - -def _blocking_effect_factor(im): - block_size = 8 - - block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8) - block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8) - - horizontal_block_difference = ( - (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum( - 3).sum(2).sum(1) - vertical_block_difference = ( - (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum( - 2).sum(1) - - nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions) - nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions) - - horizontal_nonblock_difference = ( - (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum( - 3).sum(2).sum(1) - vertical_nonblock_difference = ( - (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum( - 3).sum(2).sum(1) - - n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1) - n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1) - boundary_difference = (horizontal_block_difference + vertical_block_difference) / ( - n_boundary_horiz + n_boundary_vert) - - n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz - n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert - nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / ( - n_nonboundary_horiz + n_nonboundary_vert) - - scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]])) - bef = scaler * (boundary_difference - nonboundary_difference) - - bef[boundary_difference <= nonboundary_difference] = 0 - return bef - - -def calculate_psnrb(img1, img2, border=0): - """Calculate PSNR-B (Peak Signal-to-Noise Ratio). - Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation - # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py - Args: - img1 (ndarray): Images with range [0, 255]. - img2 (ndarray): Images with range [0, 255]. - border (int): Cropped pixels in each edge of an image. These - pixels are not involved in the PSNR calculation. - test_y_channel (bool): Test on Y channel of YCbCr. Default: False. - Returns: - float: psnr result. - """ - - if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') - - if img1.ndim == 2: - img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2) - - h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] - - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - - # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py - img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255. - img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255. - - total = 0 - for c in range(img1.shape[1]): - mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none') - bef = _blocking_effect_factor(img1[:, c:c + 1, :, :]) - - mse = mse.view(mse.shape[0], -1).mean(1) - total += 10 * torch.log10(1 / (mse + bef)) - - return float(total) / img1.shape[1] - -''' -# -------------------------------------------- -# matlab's bicubic imresize (numpy and torch) [0, 1] -# -------------------------------------------- -''' - - -# matlab 'imresize' function, now only support 'bicubic' -def cubic(x): - absx = torch.abs(x) - absx2 = absx**2 - absx3 = absx**3 - return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ - (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) - - -def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): - if (scale < 1) and (antialiasing): - # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width - kernel_width = kernel_width / scale - - # Output-space coordinates - x = torch.linspace(1, out_length, out_length) - - # Input-space coordinates. Calculate the inverse mapping such that 0.5 - # in output space maps to 0.5 in input space, and 0.5+scale in output - # space maps to 1.5 in input space. - u = x / scale + 0.5 * (1 - 1 / scale) - - # What is the left-most pixel that can be involved in the computation? - left = torch.floor(u - kernel_width / 2) - - # What is the maximum number of pixels that can be involved in the - # computation? Note: it's OK to use an extra pixel here; if the - # corresponding weights are all zero, it will be eliminated at the end - # of this function. - P = math.ceil(kernel_width) + 2 - - # The indices of the input pixels involved in computing the k-th output - # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) - - # The weights used to compute the k-th output pixel are in row k of the - # weights matrix. - distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices - # apply cubic kernel - if (scale < 1) and (antialiasing): - weights = scale * cubic(distance_to_center * scale) - else: - weights = cubic(distance_to_center) - # Normalize the weights matrix so that each row sums to 1. - weights_sum = torch.sum(weights, 1).view(out_length, 1) - weights = weights / weights_sum.expand(out_length, P) - - # If a column in weights is all zero, get rid of it. only consider the first and last column. - weights_zero_tmp = torch.sum((weights == 0), 0) - if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): - indices = indices.narrow(1, 1, P - 2) - weights = weights.narrow(1, 1, P - 2) - if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): - indices = indices.narrow(1, 0, P - 2) - weights = weights.narrow(1, 0, P - 2) - weights = weights.contiguous() - indices = indices.contiguous() - sym_len_s = -indices.min() + 1 - sym_len_e = indices.max() - in_length - indices = indices + sym_len_s - 1 - return weights, indices, int(sym_len_s), int(sym_len_e) - - -# -------------------------------------------- -# imresize for tensor image [0, 1] -# -------------------------------------------- -def imresize(img, scale, antialiasing=True): - # Now the scale should be the same for H and W - # input: img: pytorch tensor, CHW or HW [0,1] - # output: CHW or HW [0,1] w/o round - need_squeeze = True if img.dim() == 2 else False - if need_squeeze: - img.unsqueeze_(0) - in_C, in_H, in_W = img.size() - out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) - kernel_width = 4 - kernel = 'cubic' - - # Return the desired dimension order for performing the resize. The - # strategy is to perform the resize first along the dimension with the - # smallest scale factor. - # Now we do not support this. - - # get weights and indices - weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) - weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) - # process H dimension - # symmetric copying - img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) - img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) - - sym_patch = img[:, :sym_len_Hs, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) - - sym_patch = img[:, -sym_len_He:, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) - - out_1 = torch.FloatTensor(in_C, out_H, in_W) - kernel_width = weights_H.size(1) - for i in range(out_H): - idx = int(indices_H[i][0]) - for j in range(out_C): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) - - # process W dimension - # symmetric copying - out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) - out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) - - sym_patch = out_1[:, :, :sym_len_Ws] - inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(2, inv_idx) - out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) - - sym_patch = out_1[:, :, -sym_len_We:] - inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(2, inv_idx) - out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) - - out_2 = torch.FloatTensor(in_C, out_H, out_W) - kernel_width = weights_W.size(1) - for i in range(out_W): - idx = int(indices_W[i][0]) - for j in range(out_C): - out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) - if need_squeeze: - out_2.squeeze_() - return out_2 - - -# -------------------------------------------- -# imresize for numpy image [0, 1] -# -------------------------------------------- -def imresize_np(img, scale, antialiasing=True): - # Now the scale should be the same for H and W - # input: img: Numpy, HWC or HW [0,1] - # output: HWC or HW [0,1] w/o round - img = torch.from_numpy(img) - need_squeeze = True if img.dim() == 2 else False - if need_squeeze: - img.unsqueeze_(2) - - in_H, in_W, in_C = img.size() - out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) - kernel_width = 4 - kernel = 'cubic' - - # Return the desired dimension order for performing the resize. The - # strategy is to perform the resize first along the dimension with the - # smallest scale factor. - # Now we do not support this. - - # get weights and indices - weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) - weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) - # process H dimension - # symmetric copying - img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) - img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) - - sym_patch = img[:sym_len_Hs, :, :] - inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(0, inv_idx) - img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) - - sym_patch = img[-sym_len_He:, :, :] - inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(0, inv_idx) - img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) - - out_1 = torch.FloatTensor(out_H, in_W, in_C) - kernel_width = weights_H.size(1) - for i in range(out_H): - idx = int(indices_H[i][0]) - for j in range(out_C): - out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) - - # process W dimension - # symmetric copying - out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) - out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) - - sym_patch = out_1[:, :sym_len_Ws, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) - - sym_patch = out_1[:, -sym_len_We:, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) - - out_2 = torch.FloatTensor(out_H, out_W, in_C) - kernel_width = weights_W.size(1) - for i in range(out_W): - idx = int(indices_W[i][0]) - for j in range(out_C): - out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) - if need_squeeze: - out_2.squeeze_() - - return out_2.numpy() - - -if __name__ == '__main__': - img = imread_uint('test.bmp', 3) -# img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) -# imshow(single2uint(img_bicubic)) -# -# img_tensor = single2tensor4(img) -# for i in range(8): -# imshow(np.concatenate((augment_img(img, i), tensor2single(augment_img_tensor4(img_tensor, i))), 1)) - -# patches = patches_from_image(img, p_size=128, p_overlap=0, p_max=200) -# imssave(patches,'a.png') - - - - - - - diff --git a/core/data/deg_kair_utils/utils_lmdb.py b/core/data/deg_kair_utils/utils_lmdb.py deleted file mode 100644 index 75192c346bb9c0b96f8b09635ed548bd6e797d89..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_lmdb.py +++ /dev/null @@ -1,205 +0,0 @@ -import cv2 -import lmdb -import sys -from multiprocessing import Pool -from os import path as osp -from tqdm import tqdm - - -def make_lmdb_from_imgs(data_path, - lmdb_path, - img_path_list, - keys, - batch=5000, - compress_level=1, - multiprocessing_read=False, - n_thread=40, - map_size=None): - """Make lmdb from images. - - Contents of lmdb. The file structure is: - example.lmdb - ├── data.mdb - ├── lock.mdb - ├── meta_info.txt - - The data.mdb and lock.mdb are standard lmdb files and you can refer to - https://lmdb.readthedocs.io/en/release/ for more details. - - The meta_info.txt is a specified txt file to record the meta information - of our datasets. It will be automatically created when preparing - datasets by our provided dataset tools. - Each line in the txt file records 1)image name (with extension), - 2)image shape, and 3)compression level, separated by a white space. - - For example, the meta information could be: - `000_00000000.png (720,1280,3) 1`, which means: - 1) image name (with extension): 000_00000000.png; - 2) image shape: (720,1280,3); - 3) compression level: 1 - - We use the image name without extension as the lmdb key. - - If `multiprocessing_read` is True, it will read all the images to memory - using multiprocessing. Thus, your server needs to have enough memory. - - Args: - data_path (str): Data path for reading images. - lmdb_path (str): Lmdb save path. - img_path_list (str): Image path list. - keys (str): Used for lmdb keys. - batch (int): After processing batch images, lmdb commits. - Default: 5000. - compress_level (int): Compress level when encoding images. Default: 1. - multiprocessing_read (bool): Whether use multiprocessing to read all - the images to memory. Default: False. - n_thread (int): For multiprocessing. - map_size (int | None): Map size for lmdb env. If None, use the - estimated size from images. Default: None - """ - - assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' - f'but got {len(img_path_list)} and {len(keys)}') - print(f'Create lmdb for {data_path}, save to {lmdb_path}...') - print(f'Totoal images: {len(img_path_list)}') - if not lmdb_path.endswith('.lmdb'): - raise ValueError("lmdb_path must end with '.lmdb'.") - if osp.exists(lmdb_path): - print(f'Folder {lmdb_path} already exists. Exit.') - sys.exit(1) - - if multiprocessing_read: - # read all the images to memory (multiprocessing) - dataset = {} # use dict to keep the order for multiprocessing - shapes = {} - print(f'Read images with multiprocessing, #thread: {n_thread} ...') - pbar = tqdm(total=len(img_path_list), unit='image') - - def callback(arg): - """get the image data and update pbar.""" - key, dataset[key], shapes[key] = arg - pbar.update(1) - pbar.set_description(f'Read {key}') - - pool = Pool(n_thread) - for path, key in zip(img_path_list, keys): - pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) - pool.close() - pool.join() - pbar.close() - print(f'Finish reading {len(img_path_list)} images.') - - # create lmdb environment - if map_size is None: - # obtain data size for one image - img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) - _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) - data_size_per_img = img_byte.nbytes - print('Data size per image is: ', data_size_per_img) - data_size = data_size_per_img * len(img_path_list) - map_size = data_size * 10 - - env = lmdb.open(lmdb_path, map_size=map_size) - - # write data to lmdb - pbar = tqdm(total=len(img_path_list), unit='chunk') - txn = env.begin(write=True) - txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') - for idx, (path, key) in enumerate(zip(img_path_list, keys)): - pbar.update(1) - pbar.set_description(f'Write {key}') - key_byte = key.encode('ascii') - if multiprocessing_read: - img_byte = dataset[key] - h, w, c = shapes[key] - else: - _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) - h, w, c = img_shape - - txn.put(key_byte, img_byte) - # write meta information - txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') - if idx % batch == 0: - txn.commit() - txn = env.begin(write=True) - pbar.close() - txn.commit() - env.close() - txt_file.close() - print('\nFinish writing lmdb.') - - -def read_img_worker(path, key, compress_level): - """Read image worker. - - Args: - path (str): Image path. - key (str): Image key. - compress_level (int): Compress level when encoding images. - - Returns: - str: Image key. - byte: Image byte. - tuple[int]: Image shape. - """ - - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - # deal with `libpng error: Read Error` - if img is None: - print(f'To deal with `libpng error: Read Error`, use PIL to load {path}') - from PIL import Image - import numpy as np - img = Image.open(path) - img = np.asanyarray(img) - img = img[:, :, [2, 1, 0]] - - if img.ndim == 2: - h, w = img.shape - c = 1 - else: - h, w, c = img.shape - _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) - return (key, img_byte, (h, w, c)) - - -class LmdbMaker(): - """LMDB Maker. - - Args: - lmdb_path (str): Lmdb save path. - map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. - batch (int): After processing batch images, lmdb commits. - Default: 5000. - compress_level (int): Compress level when encoding images. Default: 1. - """ - - def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): - if not lmdb_path.endswith('.lmdb'): - raise ValueError("lmdb_path must end with '.lmdb'.") - if osp.exists(lmdb_path): - print(f'Folder {lmdb_path} already exists. Exit.') - sys.exit(1) - - self.lmdb_path = lmdb_path - self.batch = batch - self.compress_level = compress_level - self.env = lmdb.open(lmdb_path, map_size=map_size) - self.txn = self.env.begin(write=True) - self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') - self.counter = 0 - - def put(self, img_byte, key, img_shape): - self.counter += 1 - key_byte = key.encode('ascii') - self.txn.put(key_byte, img_byte) - # write meta information - h, w, c = img_shape - self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') - if self.counter % self.batch == 0: - self.txn.commit() - self.txn = self.env.begin(write=True) - - def close(self): - self.txn.commit() - self.env.close() - self.txt_file.close() diff --git a/core/data/deg_kair_utils/utils_logger.py b/core/data/deg_kair_utils/utils_logger.py deleted file mode 100644 index 3067190e1b09b244814e0ccc4496b18f06e22b54..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_logger.py +++ /dev/null @@ -1,66 +0,0 @@ -import sys -import datetime -import logging - - -''' -# -------------------------------------------- -# Kai Zhang (github: https://github.com/cszn) -# 03/Mar/2019 -# -------------------------------------------- -# https://github.com/xinntao/BasicSR -# -------------------------------------------- -''' - - -def log(*args, **kwargs): - print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) - - -''' -# -------------------------------------------- -# logger -# -------------------------------------------- -''' - - -def logger_info(logger_name, log_path='default_logger.log'): - ''' set up logger - modified by Kai Zhang (github: https://github.com/cszn) - ''' - log = logging.getLogger(logger_name) - if log.hasHandlers(): - print('LogHandlers exist!') - else: - print('LogHandlers setup!') - level = logging.INFO - formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') - fh = logging.FileHandler(log_path, mode='a') - fh.setFormatter(formatter) - log.setLevel(level) - log.addHandler(fh) - # print(len(log.handlers)) - - sh = logging.StreamHandler() - sh.setFormatter(formatter) - log.addHandler(sh) - - -''' -# -------------------------------------------- -# print to file and std_out simultaneously -# -------------------------------------------- -''' - - -class logger_print(object): - def __init__(self, log_path="default.log"): - self.terminal = sys.stdout - self.log = open(log_path, 'a') - - def write(self, message): - self.terminal.write(message) - self.log.write(message) # write the message - - def flush(self): - pass diff --git a/core/data/deg_kair_utils/utils_mat.py b/core/data/deg_kair_utils/utils_mat.py deleted file mode 100644 index cd25d500c0eae77a3b815b8e956205b737ee43d4..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_mat.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -import json -import scipy.io as spio -import pandas as pd - - -def loadmat(filename): - ''' - this function should be called instead of direct spio.loadmat - as it cures the problem of not properly recovering python dictionaries - from mat files. It calls the function check keys to cure all entries - which are still mat-objects - ''' - data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True) - return dict_to_nonedict(_check_keys(data)) - -def _check_keys(dict): - ''' - checks if entries in dictionary are mat-objects. If yes - todict is called to change them to nested dictionaries - ''' - for key in dict: - if isinstance(dict[key], spio.matlab.mio5_params.mat_struct): - dict[key] = _todict(dict[key]) - return dict - -def _todict(matobj): - ''' - A recursive function which constructs from matobjects nested dictionaries - ''' - dict = {} - for strg in matobj._fieldnames: - elem = matobj.__dict__[strg] - if isinstance(elem, spio.matlab.mio5_params.mat_struct): - dict[strg] = _todict(elem) - else: - dict[strg] = elem - return dict - - -def dict_to_nonedict(opt): - if isinstance(opt, dict): - new_opt = dict() - for key, sub_opt in opt.items(): - new_opt[key] = dict_to_nonedict(sub_opt) - return NoneDict(**new_opt) - elif isinstance(opt, list): - return [dict_to_nonedict(sub_opt) for sub_opt in opt] - else: - return opt - - -class NoneDict(dict): - def __missing__(self, key): - return None - - -def mat2json(mat_path=None, filepath = None): - """ - Converts .mat file to .json and writes new file - Parameters - ---------- - mat_path: Str - path/filename .mat存放路径 - filepath: Str - 如果需è¦ä¿å­˜æˆjson, 添加这一路径. å¦åˆ™ä¸ä¿å­˜ - Returns - 返回转化的字典 - ------- - None - Examples - -------- - >>> mat2json(blah blah) - """ - - matlabFile = loadmat(mat_path) - #pop all those dumb fields that don't let you jsonize file - matlabFile.pop('__header__') - matlabFile.pop('__version__') - matlabFile.pop('__globals__') - #jsonize the file - orientation is 'index' - matlabFile = pd.Series(matlabFile).to_json() - - if filepath: - json_path = os.path.splitext(os.path.split(mat_path)[1])[0] + '.json' - with open(json_path, 'w') as f: - f.write(matlabFile) - return matlabFile \ No newline at end of file diff --git a/core/data/deg_kair_utils/utils_matconvnet.py b/core/data/deg_kair_utils/utils_matconvnet.py deleted file mode 100644 index 506dc47805ae07976022b236ca64c98e9a6f78b3..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_matconvnet.py +++ /dev/null @@ -1,197 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -from collections import OrderedDict - -# import scipy.io as io -import hdf5storage - -""" -# -------------------------------------------- -# Convert matconvnet SimpleNN model into pytorch model -# -------------------------------------------- -# Kai Zhang (cskaizhang@gmail.com) -# https://github.com/cszn -# 28/Nov/2019 -# -------------------------------------------- -""" - - -def weights2tensor(x, squeeze=False, in_features=None, out_features=None): - """Modified version of https://github.com/albanie/pytorch-mcn - Adjust memory layout and load weights as torch tensor - Args: - x (ndaray): a numpy array, corresponding to a set of network weights - stored in column major order - squeeze (bool) [False]: whether to squeeze the tensor (i.e. remove - singletons from the trailing dimensions. So after converting to - pytorch layout (C_out, C_in, H, W), if the shape is (A, B, 1, 1) - it will be reshaped to a matrix with shape (A,B). - in_features (int :: None): used to reshape weights for a linear block. - out_features (int :: None): used to reshape weights for a linear block. - Returns: - torch.tensor: a permuted sets of weights, matching the pytorch layout - convention - """ - if x.ndim == 4: - x = x.transpose((3, 2, 0, 1)) -# for FFDNet, pixel-shuffle layer -# if x.shape[1]==13: -# x=x[:,[0,2,1,3, 4,6,5,7, 8,10,9,11, 12],:,:] -# if x.shape[0]==12: -# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] -# if x.shape[1]==5: -# x=x[:,[0,2,1,3, 4],:,:] -# if x.shape[0]==4: -# x=x[[0,2,1,3],:,:,:] -## for SRMD, pixel-shuffle layer -# if x.shape[0]==12: -# x=x[[0,2,1,3, 4,6,5,7, 8,10,9,11],:,:,:] -# if x.shape[0]==27: -# x=x[[0,3,6,1,4,7,2,5,8, 0+9,3+9,6+9,1+9,4+9,7+9,2+9,5+9,8+9, 0+18,3+18,6+18,1+18,4+18,7+18,2+18,5+18,8+18],:,:,:] -# if x.shape[0]==48: -# x=x[[0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15, 0+16,4+16,8+16,12+16,1+16,5+16,9+16,13+16,2+16,6+16,10+16,14+16,3+16,7+16,11+16,15+16, 0+32,4+32,8+32,12+32,1+32,5+32,9+32,13+32,2+32,6+32,10+32,14+32,3+32,7+32,11+32,15+32],:,:,:] - - elif x.ndim == 3: # add by Kai - x = x[:,:,:,None] - x = x.transpose((3, 2, 0, 1)) - elif x.ndim == 2: - if x.shape[1] == 1: - x = x.flatten() - if squeeze: - if in_features and out_features: - x = x.reshape((out_features, in_features)) - x = np.squeeze(x) - return torch.from_numpy(np.ascontiguousarray(x)) - - -def save_model(network, save_path): - state_dict = network.state_dict() - for key, param in state_dict.items(): - state_dict[key] = param.cpu() - torch.save(state_dict, save_path) - - -if __name__ == '__main__': - - -# from utils import utils_logger -# import logging -# utils_logger.logger_info('a', 'a.log') -# logger = logging.getLogger('a') -# - # mcn = hdf5storage.loadmat('/model_zoo/matfile/FFDNet_Clip_gray.mat') - mcn = hdf5storage.loadmat('models/modelcolor.mat') - - - #logger.info(mcn['CNNdenoiser'][0][0][0][1][0][0][0][0]) - - mat_net = OrderedDict() - for idx in range(25): - mat_net[str(idx)] = OrderedDict() - count = -1 - - print(idx) - for i in range(13): - - if mcn['CNNdenoiser'][0][idx][0][i][0][0][0][0] == 'conv': - - count += 1 - w = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][0] - # print(w.shape) - w = weights2tensor(w) - # print(w.shape) - - b = mcn['CNNdenoiser'][0][idx][0][i][0][1][0][1] - b = weights2tensor(b) - print(b.shape) - - mat_net[str(idx)]['model.{:d}.weight'.format(count*2)] = w - mat_net[str(idx)]['model.{:d}.bias'.format(count*2)] = b - - torch.save(mat_net, 'model_zoo/modelcolor.pth') - - - -# from models.network_dncnn import IRCNN as net -# network = net(in_nc=3, out_nc=3, nc=64) -# state_dict = network.state_dict() -# -# #show_kv(state_dict) -# -# for i in range(len(mcn['net'][0][0][0])): -# print(mcn['net'][0][0][0][i][0][0][0][0]) -# -# count = -1 -# mat_net = OrderedDict() -# for i in range(len(mcn['net'][0][0][0])): -# if mcn['net'][0][0][0][i][0][0][0][0] == 'conv': -# -# count += 1 -# w = mcn['net'][0][0][0][i][0][1][0][0] -# print(w.shape) -# w = weights2tensor(w) -# print(w.shape) -# -# b = mcn['net'][0][0][0][i][0][1][0][1] -# b = weights2tensor(b) -# print(b.shape) -# -# mat_net['model.{:d}.weight'.format(count*2)] = w -# mat_net['model.{:d}.bias'.format(count*2)] = b -# -# torch.save(mat_net, 'E:/pytorch/KAIR_ongoing/model_zoo/ffdnet_gray_clip.pth') -# -# -# -# crt_net = torch.load('E:/pytorch/KAIR_ongoing/model_zoo/imdn_x4.pth') -# def show_kv(net): -# for k, v in net.items(): -# print(k) -# -# show_kv(crt_net) - - -# from models.network_dncnn import DnCNN as net -# network = net(in_nc=2, out_nc=1, nc=64, nb=20, act_mode='R') - -# from models.network_srmd import SRMD as net -# #network = net(in_nc=1, out_nc=1, nc=64, nb=15, act_mode='R') -# network = net(in_nc=19, out_nc=3, nc=128, nb=12, upscale=4, act_mode='R', upsample_mode='pixelshuffle') -# -# from models.network_rrdb import RRDB as net -# network = net(in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv') -# -# state_dict = network.state_dict() -# for key, param in state_dict.items(): -# print(key) -# from models.network_imdn import IMDN as net -# network = net(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4, act_mode='L', upsample_mode='pixelshuffle') -# state_dict = network.state_dict() -# mat_net = OrderedDict() -# for ((key, param),(key2, param2)) in zip(state_dict.items(), crt_net.items()): -# mat_net[key] = param2 -# torch.save(mat_net, 'model_zoo/imdn_x4_1.pth') -# - -# net_old = torch.load('net_old.pth') -# def show_kv(net): -# for k, v in net.items(): -# print(k) -# -# show_kv(net_old) -# from models.network_dpsr import MSRResNet_prior as net -# model = net(in_nc=4, out_nc=3, nc=96, nb=16, upscale=4, act_mode='R', upsample_mode='pixelshuffle') -# state_dict = network.state_dict() -# net_new = OrderedDict() -# for ((key, param),(key_old, param_old)) in zip(state_dict.items(), net_old.items()): -# net_new[key] = param_old -# torch.save(net_new, 'net_new.pth') - - - # print(key) - # print(param.size()) - - - - # run utils/utils_matconvnet.py diff --git a/core/data/deg_kair_utils/utils_model.py b/core/data/deg_kair_utils/utils_model.py deleted file mode 100644 index a4d9e6ac651784c7ed36e623c3a6175883123c2b..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_model.py +++ /dev/null @@ -1,330 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -from utils import utils_image as util -import re -import glob -import os - - -''' -# -------------------------------------------- -# Model -# -------------------------------------------- -# Kai Zhang (github: https://github.com/cszn) -# 03/Mar/2019 -# -------------------------------------------- -''' - - -def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): - """ - # --------------------------------------- - # Kai Zhang (github: https://github.com/cszn) - # 03/Mar/2019 - # --------------------------------------- - Args: - save_dir: model folder - net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' - pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path - - Return: - init_iter: iteration number - init_path: model path - # --------------------------------------- - """ - - file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) - if file_list: - iter_exist = [] - for file_ in file_list: - iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) - iter_exist.append(int(iter_current[0])) - init_iter = max(iter_exist) - init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) - else: - init_iter = 0 - init_path = pretrained_path - return init_iter, init_path - - -def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1): - ''' - # --------------------------------------- - # Kai Zhang (github: https://github.com/cszn) - # 03/Mar/2019 - # --------------------------------------- - Args: - model: trained model - L: input Low-quality image - mode: - (0) normal: test(model, L) - (1) pad: test_pad(model, L, modulo=16) - (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1) - (3) x8: test_x8(model, L, modulo=1) ^_^ - (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1) - refield: effective receptive filed of the network, 32 is enough - useful when split, i.e., mode=2, 4 - min_size: min_sizeXmin_size image, e.g., 256X256 image - useful when split, i.e., mode=2, 4 - sf: scale factor for super-resolution, otherwise 1 - modulo: 1 if split - useful when pad, i.e., mode=1 - - Returns: - E: estimated image - # --------------------------------------- - ''' - if mode == 0: - E = test(model, L) - elif mode == 1: - E = test_pad(model, L, modulo, sf) - elif mode == 2: - E = test_split(model, L, refield, min_size, sf, modulo) - elif mode == 3: - E = test_x8(model, L, modulo, sf) - elif mode == 4: - E = test_split_x8(model, L, refield, min_size, sf, modulo) - return E - - -''' -# -------------------------------------------- -# normal (0) -# -------------------------------------------- -''' - - -def test(model, L): - E = model(L) - return E - - -''' -# -------------------------------------------- -# pad (1) -# -------------------------------------------- -''' - - -def test_pad(model, L, modulo=16, sf=1): - h, w = L.size()[-2:] - paddingBottom = int(np.ceil(h/modulo)*modulo-h) - paddingRight = int(np.ceil(w/modulo)*modulo-w) - L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L) - E = model(L) - E = E[..., :h*sf, :w*sf] - return E - - -''' -# -------------------------------------------- -# split (function) -# -------------------------------------------- -''' - - -def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1): - """ - Args: - model: trained model - L: input Low-quality image - refield: effective receptive filed of the network, 32 is enough - min_size: min_sizeXmin_size image, e.g., 256X256 image - sf: scale factor for super-resolution, otherwise 1 - modulo: 1 if split - - Returns: - E: estimated result - """ - h, w = L.size()[-2:] - if h*w <= min_size**2: - L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L) - E = model(L) - E = E[..., :h*sf, :w*sf] - else: - top = slice(0, (h//2//refield+1)*refield) - bottom = slice(h - (h//2//refield+1)*refield, h) - left = slice(0, (w//2//refield+1)*refield) - right = slice(w - (w//2//refield+1)*refield, w) - Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]] - - if h * w <= 4*(min_size**2): - Es = [model(Ls[i]) for i in range(4)] - else: - Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)] - - b, c = Es[0].size()[:2] - E = torch.zeros(b, c, sf * h, sf * w).type_as(L) - - E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf] - E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:] - E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf] - E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:] - return E - - -''' -# -------------------------------------------- -# split (2) -# -------------------------------------------- -''' - - -def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1): - E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo) - return E - - -''' -# -------------------------------------------- -# x8 (3) -# -------------------------------------------- -''' - - -def test_x8(model, L, modulo=1, sf=1): - E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)] - for i in range(len(E_list)): - if i == 3 or i == 5: - E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i) - else: - E_list[i] = util.augment_img_tensor4(E_list[i], mode=i) - output_cat = torch.stack(E_list, dim=0) - E = output_cat.mean(dim=0, keepdim=False) - return E - - -''' -# -------------------------------------------- -# split and x8 (4) -# -------------------------------------------- -''' - - -def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1): - E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)] - for k, i in enumerate(range(len(E_list))): - if i==3 or i==5: - E_list[k] = util.augment_img_tensor4(E_list[k], mode=8-i) - else: - E_list[k] = util.augment_img_tensor4(E_list[k], mode=i) - output_cat = torch.stack(E_list, dim=0) - E = output_cat.mean(dim=0, keepdim=False) - return E - - -''' -# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- -# _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^ -# ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- -''' - - -''' -# -------------------------------------------- -# print -# -------------------------------------------- -''' - - -# -------------------------------------------- -# print model -# -------------------------------------------- -def print_model(model): - msg = describe_model(model) - print(msg) - - -# -------------------------------------------- -# print params -# -------------------------------------------- -def print_params(model): - msg = describe_params(model) - print(msg) - - -''' -# -------------------------------------------- -# information -# -------------------------------------------- -''' - - -# -------------------------------------------- -# model inforation -# -------------------------------------------- -def info_model(model): - msg = describe_model(model) - return msg - - -# -------------------------------------------- -# params inforation -# -------------------------------------------- -def info_params(model): - msg = describe_params(model) - return msg - - -''' -# -------------------------------------------- -# description -# -------------------------------------------- -''' - - -# -------------------------------------------- -# model name and total number of parameters -# -------------------------------------------- -def describe_model(model): - if isinstance(model, torch.nn.DataParallel): - model = model.module - msg = '\n' - msg += 'models name: {}'.format(model.__class__.__name__) + '\n' - msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n' - msg += 'Net structure:\n{}'.format(str(model)) + '\n' - return msg - - -# -------------------------------------------- -# parameters description -# -------------------------------------------- -def describe_params(model): - if isinstance(model, torch.nn.DataParallel): - model = model.module - msg = '\n' - msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n' - for name, param in model.state_dict().items(): - if not 'num_batches_tracked' in name: - v = param.data.clone().float() - msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n' - return msg - - -if __name__ == '__main__': - - class Net(torch.nn.Module): - def __init__(self, in_channels=3, out_channels=3): - super(Net, self).__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) - - def forward(self, x): - x = self.conv(x) - return x - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - - model = Net() - model = model.eval() - print_model(model) - print_params(model) - x = torch.randn((2,3,401,401)) - torch.cuda.empty_cache() - with torch.no_grad(): - for mode in range(5): - y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1) - print(y.shape) - - # run utils/utils_model.py diff --git a/core/data/deg_kair_utils/utils_modelsummary.py b/core/data/deg_kair_utils/utils_modelsummary.py deleted file mode 100644 index 5e040e31d8ddffbb8b7b2e2dc4ddf0b9cdca6a23..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_modelsummary.py +++ /dev/null @@ -1,485 +0,0 @@ -import torch.nn as nn -import torch -import numpy as np - -''' ----- 1) FLOPs: floating point operations ----- 2) #Activations: the number of elements of all ‘Conv2d’ outputs ----- 3) #Conv2d: the number of ‘Conv2d’ layers -# -------------------------------------------- -# Kai Zhang (github: https://github.com/cszn) -# 21/July/2020 -# -------------------------------------------- -# Reference -https://github.com/sovrasov/flops-counter.pytorch.git - -# If you use this code, please consider the following citation: - -@inproceedings{zhang2020aim, % - title={AIM 2020 Challenge on Efficient Super-Resolution: Methods and Results}, - author={Kai Zhang and Martin Danelljan and Yawei Li and Radu Timofte and others}, - booktitle={European Conference on Computer Vision Workshops}, - year={2020} -} -# -------------------------------------------- -''' - -def get_model_flops(model, input_res, print_per_layer_stat=True, - input_constructor=None): - assert type(input_res) is tuple, 'Please provide the size of the input image.' - assert len(input_res) >= 3, 'Input image should have 3 dimensions.' - flops_model = add_flops_counting_methods(model) - flops_model.eval().start_flops_count() - if input_constructor: - input = input_constructor(input_res) - _ = flops_model(**input) - else: - device = list(flops_model.parameters())[-1].device - batch = torch.FloatTensor(1, *input_res).to(device) - _ = flops_model(batch) - - if print_per_layer_stat: - print_model_with_flops(flops_model) - flops_count = flops_model.compute_average_flops_cost() - flops_model.stop_flops_count() - - return flops_count - -def get_model_activation(model, input_res, input_constructor=None): - assert type(input_res) is tuple, 'Please provide the size of the input image.' - assert len(input_res) >= 3, 'Input image should have 3 dimensions.' - activation_model = add_activation_counting_methods(model) - activation_model.eval().start_activation_count() - if input_constructor: - input = input_constructor(input_res) - _ = activation_model(**input) - else: - device = list(activation_model.parameters())[-1].device - batch = torch.FloatTensor(1, *input_res).to(device) - _ = activation_model(batch) - - activation_count, num_conv = activation_model.compute_average_activation_cost() - activation_model.stop_activation_count() - - return activation_count, num_conv - - -def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True, - input_constructor=None): - assert type(input_res) is tuple - assert len(input_res) >= 3 - flops_model = add_flops_counting_methods(model) - flops_model.eval().start_flops_count() - if input_constructor: - input = input_constructor(input_res) - _ = flops_model(**input) - else: - batch = torch.FloatTensor(1, *input_res) - _ = flops_model(batch) - - if print_per_layer_stat: - print_model_with_flops(flops_model) - flops_count = flops_model.compute_average_flops_cost() - params_count = get_model_parameters_number(flops_model) - flops_model.stop_flops_count() - - if as_strings: - return flops_to_string(flops_count), params_to_string(params_count) - - return flops_count, params_count - - -def flops_to_string(flops, units='GMac', precision=2): - if units is None: - if flops // 10**9 > 0: - return str(round(flops / 10.**9, precision)) + ' GMac' - elif flops // 10**6 > 0: - return str(round(flops / 10.**6, precision)) + ' MMac' - elif flops // 10**3 > 0: - return str(round(flops / 10.**3, precision)) + ' KMac' - else: - return str(flops) + ' Mac' - else: - if units == 'GMac': - return str(round(flops / 10.**9, precision)) + ' ' + units - elif units == 'MMac': - return str(round(flops / 10.**6, precision)) + ' ' + units - elif units == 'KMac': - return str(round(flops / 10.**3, precision)) + ' ' + units - else: - return str(flops) + ' Mac' - - -def params_to_string(params_num): - if params_num // 10 ** 6 > 0: - return str(round(params_num / 10 ** 6, 2)) + ' M' - elif params_num // 10 ** 3: - return str(round(params_num / 10 ** 3, 2)) + ' k' - else: - return str(params_num) - - -def print_model_with_flops(model, units='GMac', precision=3): - total_flops = model.compute_average_flops_cost() - - def accumulate_flops(self): - if is_supported_instance(self): - return self.__flops__ / model.__batch_counter__ - else: - sum = 0 - for m in self.children(): - sum += m.accumulate_flops() - return sum - - def flops_repr(self): - accumulated_flops_cost = self.accumulate_flops() - return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision), - '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), - self.original_extra_repr()]) - - def add_extra_repr(m): - m.accumulate_flops = accumulate_flops.__get__(m) - flops_extra_repr = flops_repr.__get__(m) - if m.extra_repr != flops_extra_repr: - m.original_extra_repr = m.extra_repr - m.extra_repr = flops_extra_repr - assert m.extra_repr != m.original_extra_repr - - def del_extra_repr(m): - if hasattr(m, 'original_extra_repr'): - m.extra_repr = m.original_extra_repr - del m.original_extra_repr - if hasattr(m, 'accumulate_flops'): - del m.accumulate_flops - - model.apply(add_extra_repr) - print(model) - model.apply(del_extra_repr) - - -def get_model_parameters_number(model): - params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) - return params_num - - -def add_flops_counting_methods(net_main_module): - # adding additional methods to the existing module object, - # this is done this way so that each function has access to self object - # embed() - net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) - net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) - net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) - net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) - - net_main_module.reset_flops_count() - return net_main_module - - -def compute_average_flops_cost(self): - """ - A method that will be available after add_flops_counting_methods() is called - on a desired net object. - - Returns current mean flops consumption per image. - - """ - - flops_sum = 0 - for module in self.modules(): - if is_supported_instance(module): - flops_sum += module.__flops__ - - return flops_sum - - -def start_flops_count(self): - """ - A method that will be available after add_flops_counting_methods() is called - on a desired net object. - - Activates the computation of mean flops consumption per image. - Call it before you run the network. - - """ - self.apply(add_flops_counter_hook_function) - - -def stop_flops_count(self): - """ - A method that will be available after add_flops_counting_methods() is called - on a desired net object. - - Stops computing the mean flops consumption per image. - Call whenever you want to pause the computation. - - """ - self.apply(remove_flops_counter_hook_function) - - -def reset_flops_count(self): - """ - A method that will be available after add_flops_counting_methods() is called - on a desired net object. - - Resets statistics computed so far. - - """ - self.apply(add_flops_counter_variable_or_reset) - - -def add_flops_counter_hook_function(module): - if is_supported_instance(module): - if hasattr(module, '__flops_handle__'): - return - - if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)): - handle = module.register_forward_hook(conv_flops_counter_hook) - elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)): - handle = module.register_forward_hook(relu_flops_counter_hook) - elif isinstance(module, nn.Linear): - handle = module.register_forward_hook(linear_flops_counter_hook) - elif isinstance(module, (nn.BatchNorm2d)): - handle = module.register_forward_hook(bn_flops_counter_hook) - else: - handle = module.register_forward_hook(empty_flops_counter_hook) - module.__flops_handle__ = handle - - -def remove_flops_counter_hook_function(module): - if is_supported_instance(module): - if hasattr(module, '__flops_handle__'): - module.__flops_handle__.remove() - del module.__flops_handle__ - - -def add_flops_counter_variable_or_reset(module): - if is_supported_instance(module): - module.__flops__ = 0 - - -# ---- Internal functions -def is_supported_instance(module): - if isinstance(module, - ( - nn.Conv2d, nn.ConvTranspose2d, - nn.BatchNorm2d, - nn.Linear, - nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6, - )): - return True - - return False - - -def conv_flops_counter_hook(conv_module, input, output): - # Can have multiple inputs, getting the first one - # input = input[0] - - batch_size = output.shape[0] - output_dims = list(output.shape[2:]) - - kernel_dims = list(conv_module.kernel_size) - in_channels = conv_module.in_channels - out_channels = conv_module.out_channels - groups = conv_module.groups - - filters_per_channel = out_channels // groups - conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel - - active_elements_count = batch_size * np.prod(output_dims) - overall_conv_flops = int(conv_per_position_flops) * int(active_elements_count) - - # overall_flops = overall_conv_flops - - conv_module.__flops__ += int(overall_conv_flops) - # conv_module.__output_dims__ = output_dims - - -def relu_flops_counter_hook(module, input, output): - active_elements_count = output.numel() - module.__flops__ += int(active_elements_count) - # print(module.__flops__, id(module)) - # print(module) - - -def linear_flops_counter_hook(module, input, output): - input = input[0] - if len(input.shape) == 1: - batch_size = 1 - module.__flops__ += int(batch_size * input.shape[0] * output.shape[0]) - else: - batch_size = input.shape[0] - module.__flops__ += int(batch_size * input.shape[1] * output.shape[1]) - - -def bn_flops_counter_hook(module, input, output): - # input = input[0] - # TODO: need to check here - # batch_flops = np.prod(input.shape) - # if module.affine: - # batch_flops *= 2 - # module.__flops__ += int(batch_flops) - batch = output.shape[0] - output_dims = output.shape[2:] - channels = module.num_features - batch_flops = batch * channels * np.prod(output_dims) - if module.affine: - batch_flops *= 2 - module.__flops__ += int(batch_flops) - - -# ---- Count the number of convolutional layers and the activation -def add_activation_counting_methods(net_main_module): - # adding additional methods to the existing module object, - # this is done this way so that each function has access to self object - # embed() - net_main_module.start_activation_count = start_activation_count.__get__(net_main_module) - net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module) - net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module) - net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module) - - net_main_module.reset_activation_count() - return net_main_module - - -def compute_average_activation_cost(self): - """ - A method that will be available after add_activation_counting_methods() is called - on a desired net object. - - Returns current mean activation consumption per image. - - """ - - activation_sum = 0 - num_conv = 0 - for module in self.modules(): - if is_supported_instance_for_activation(module): - activation_sum += module.__activation__ - num_conv += module.__num_conv__ - return activation_sum, num_conv - - -def start_activation_count(self): - """ - A method that will be available after add_activation_counting_methods() is called - on a desired net object. - - Activates the computation of mean activation consumption per image. - Call it before you run the network. - - """ - self.apply(add_activation_counter_hook_function) - - -def stop_activation_count(self): - """ - A method that will be available after add_activation_counting_methods() is called - on a desired net object. - - Stops computing the mean activation consumption per image. - Call whenever you want to pause the computation. - - """ - self.apply(remove_activation_counter_hook_function) - - -def reset_activation_count(self): - """ - A method that will be available after add_activation_counting_methods() is called - on a desired net object. - - Resets statistics computed so far. - - """ - self.apply(add_activation_counter_variable_or_reset) - - -def add_activation_counter_hook_function(module): - if is_supported_instance_for_activation(module): - if hasattr(module, '__activation_handle__'): - return - - if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): - handle = module.register_forward_hook(conv_activation_counter_hook) - module.__activation_handle__ = handle - - -def remove_activation_counter_hook_function(module): - if is_supported_instance_for_activation(module): - if hasattr(module, '__activation_handle__'): - module.__activation_handle__.remove() - del module.__activation_handle__ - - -def add_activation_counter_variable_or_reset(module): - if is_supported_instance_for_activation(module): - module.__activation__ = 0 - module.__num_conv__ = 0 - - -def is_supported_instance_for_activation(module): - if isinstance(module, - ( - nn.Conv2d, nn.ConvTranspose2d, - )): - return True - - return False - -def conv_activation_counter_hook(module, input, output): - """ - Calculate the activations in the convolutional operation. - Reference: Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár, Designing Network Design Spaces. - :param module: - :param input: - :param output: - :return: - """ - module.__activation__ += output.numel() - module.__num_conv__ += 1 - - -def empty_flops_counter_hook(module, input, output): - module.__flops__ += 0 - - -def upsample_flops_counter_hook(module, input, output): - output_size = output[0] - batch_size = output_size.shape[0] - output_elements_count = batch_size - for val in output_size.shape[1:]: - output_elements_count *= val - module.__flops__ += int(output_elements_count) - - -def pool_flops_counter_hook(module, input, output): - input = input[0] - module.__flops__ += int(np.prod(input.shape)) - - -def dconv_flops_counter_hook(dconv_module, input, output): - input = input[0] - - batch_size = input.shape[0] - output_dims = list(output.shape[2:]) - - m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape - out_channels, _, kernel_dim2, _, = dconv_module.projection.shape - # groups = dconv_module.groups - - # filters_per_channel = out_channels // groups - conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels - conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels - active_elements_count = batch_size * np.prod(output_dims) - - overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count - overall_flops = overall_conv_flops - - dconv_module.__flops__ += int(overall_flops) - # dconv_module.__output_dims__ = output_dims - - - - - diff --git a/core/data/deg_kair_utils/utils_option.py b/core/data/deg_kair_utils/utils_option.py deleted file mode 100644 index cf096210e2d8ea553b06a91ac5cdaa21127d837c..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_option.py +++ /dev/null @@ -1,255 +0,0 @@ -import os -from collections import OrderedDict -from datetime import datetime -import json -import re -import glob - - -''' -# -------------------------------------------- -# Kai Zhang (github: https://github.com/cszn) -# 03/Mar/2019 -# -------------------------------------------- -# https://github.com/xinntao/BasicSR -# -------------------------------------------- -''' - - -def get_timestamp(): - return datetime.now().strftime('_%y%m%d_%H%M%S') - - -def parse(opt_path, is_train=True): - - # ---------------------------------------- - # remove comments starting with '//' - # ---------------------------------------- - json_str = '' - with open(opt_path, 'r') as f: - for line in f: - line = line.split('//')[0] + '\n' - json_str += line - - # ---------------------------------------- - # initialize opt - # ---------------------------------------- - opt = json.loads(json_str, object_pairs_hook=OrderedDict) - - opt['opt_path'] = opt_path - opt['is_train'] = is_train - - # ---------------------------------------- - # set default - # ---------------------------------------- - if 'merge_bn' not in opt: - opt['merge_bn'] = False - opt['merge_bn_startpoint'] = -1 - - if 'scale' not in opt: - opt['scale'] = 1 - - # ---------------------------------------- - # datasets - # ---------------------------------------- - for phase, dataset in opt['datasets'].items(): - phase = phase.split('_')[0] - dataset['phase'] = phase - dataset['scale'] = opt['scale'] # broadcast - dataset['n_channels'] = opt['n_channels'] # broadcast - if 'dataroot_H' in dataset and dataset['dataroot_H'] is not None: - dataset['dataroot_H'] = os.path.expanduser(dataset['dataroot_H']) - if 'dataroot_L' in dataset and dataset['dataroot_L'] is not None: - dataset['dataroot_L'] = os.path.expanduser(dataset['dataroot_L']) - - # ---------------------------------------- - # path - # ---------------------------------------- - for key, path in opt['path'].items(): - if path and key in opt['path']: - opt['path'][key] = os.path.expanduser(path) - - path_task = os.path.join(opt['path']['root'], opt['task']) - opt['path']['task'] = path_task - opt['path']['log'] = path_task - opt['path']['options'] = os.path.join(path_task, 'options') - - if is_train: - opt['path']['models'] = os.path.join(path_task, 'models') - opt['path']['images'] = os.path.join(path_task, 'images') - else: # test - opt['path']['images'] = os.path.join(path_task, 'test_images') - - # ---------------------------------------- - # network - # ---------------------------------------- - opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1 - - # ---------------------------------------- - # GPU devices - # ---------------------------------------- - gpu_list = ','.join(str(x) for x in opt['gpu_ids']) - os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list - print('export CUDA_VISIBLE_DEVICES=' + gpu_list) - - # ---------------------------------------- - # default setting for distributeddataparallel - # ---------------------------------------- - if 'find_unused_parameters' not in opt: - opt['find_unused_parameters'] = True - if 'use_static_graph' not in opt: - opt['use_static_graph'] = False - if 'dist' not in opt: - opt['dist'] = False - opt['num_gpu'] = len(opt['gpu_ids']) - print('number of GPUs is: ' + str(opt['num_gpu'])) - - # ---------------------------------------- - # default setting for perceptual loss - # ---------------------------------------- - if 'F_feature_layer' not in opt['train']: - opt['train']['F_feature_layer'] = 34 # 25; [2,7,16,25,34] - if 'F_weights' not in opt['train']: - opt['train']['F_weights'] = 1.0 # 1.0; [0.1,0.1,1.0,1.0,1.0] - if 'F_lossfn_type' not in opt['train']: - opt['train']['F_lossfn_type'] = 'l1' - if 'F_use_input_norm' not in opt['train']: - opt['train']['F_use_input_norm'] = True - if 'F_use_range_norm' not in opt['train']: - opt['train']['F_use_range_norm'] = False - - # ---------------------------------------- - # default setting for optimizer - # ---------------------------------------- - if 'G_optimizer_type' not in opt['train']: - opt['train']['G_optimizer_type'] = "adam" - if 'G_optimizer_betas' not in opt['train']: - opt['train']['G_optimizer_betas'] = [0.9,0.999] - if 'G_scheduler_restart_weights' not in opt['train']: - opt['train']['G_scheduler_restart_weights'] = 1 - if 'G_optimizer_wd' not in opt['train']: - opt['train']['G_optimizer_wd'] = 0 - if 'G_optimizer_reuse' not in opt['train']: - opt['train']['G_optimizer_reuse'] = False - if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']: - opt['train']['D_optimizer_reuse'] = False - - # ---------------------------------------- - # default setting of strict for model loading - # ---------------------------------------- - if 'G_param_strict' not in opt['train']: - opt['train']['G_param_strict'] = True - if 'netD' in opt and 'D_param_strict' not in opt['path']: - opt['train']['D_param_strict'] = True - if 'E_param_strict' not in opt['path']: - opt['train']['E_param_strict'] = True - - # ---------------------------------------- - # Exponential Moving Average - # ---------------------------------------- - if 'E_decay' not in opt['train']: - opt['train']['E_decay'] = 0 - - # ---------------------------------------- - # default setting for discriminator - # ---------------------------------------- - if 'netD' in opt: - if 'net_type' not in opt['netD']: - opt['netD']['net_type'] = 'discriminator_patchgan' # discriminator_unet - if 'in_nc' not in opt['netD']: - opt['netD']['in_nc'] = 3 - if 'base_nc' not in opt['netD']: - opt['netD']['base_nc'] = 64 - if 'n_layers' not in opt['netD']: - opt['netD']['n_layers'] = 3 - if 'norm_type' not in opt['netD']: - opt['netD']['norm_type'] = 'spectral' - - - return opt - - -def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): - """ - Args: - save_dir: model folder - net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' - pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path - - Return: - init_iter: iteration number - init_path: model path - """ - file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) - if file_list: - iter_exist = [] - for file_ in file_list: - iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) - iter_exist.append(int(iter_current[0])) - init_iter = max(iter_exist) - init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) - else: - init_iter = 0 - init_path = pretrained_path - return init_iter, init_path - - -''' -# -------------------------------------------- -# convert the opt into json file -# -------------------------------------------- -''' - - -def save(opt): - opt_path = opt['opt_path'] - opt_path_copy = opt['path']['options'] - dirname, filename_ext = os.path.split(opt_path) - filename, ext = os.path.splitext(filename_ext) - dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext) - with open(dump_path, 'w') as dump_file: - json.dump(opt, dump_file, indent=2) - - -''' -# -------------------------------------------- -# dict to string for logger -# -------------------------------------------- -''' - - -def dict2str(opt, indent_l=1): - msg = '' - for k, v in opt.items(): - if isinstance(v, dict): - msg += ' ' * (indent_l * 2) + k + ':[\n' - msg += dict2str(v, indent_l + 1) - msg += ' ' * (indent_l * 2) + ']\n' - else: - msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' - return msg - - -''' -# -------------------------------------------- -# convert OrderedDict to NoneDict, -# return None for missing key -# -------------------------------------------- -''' - - -def dict_to_nonedict(opt): - if isinstance(opt, dict): - new_opt = dict() - for key, sub_opt in opt.items(): - new_opt[key] = dict_to_nonedict(sub_opt) - return NoneDict(**new_opt) - elif isinstance(opt, list): - return [dict_to_nonedict(sub_opt) for sub_opt in opt] - else: - return opt - - -class NoneDict(dict): - def __missing__(self, key): - return None diff --git a/core/data/deg_kair_utils/utils_params.py b/core/data/deg_kair_utils/utils_params.py deleted file mode 100644 index def1cb79e11472b9b8ebbaae4bd83e7216af2ccb..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_params.py +++ /dev/null @@ -1,135 +0,0 @@ -import torch - -import torchvision - -from models import basicblock as B - -def show_kv(net): - for k, v in net.items(): - print(k) - -# should run train debug mode first to get an initial model -#crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth') -# -#for k, v in crt_net.items(): -# print(k) -#for k, v in crt_net.items(): -# if k in pretrained_net: -# crt_net[k] = pretrained_net[k] -# print('replace ... ', k) - -# x2 -> x4 -#crt_net['model.5.weight'] = pretrained_net['model.2.weight'] -#crt_net['model.5.bias'] = pretrained_net['model.2.bias'] -#crt_net['model.8.weight'] = pretrained_net['model.5.weight'] -#crt_net['model.8.bias'] = pretrained_net['model.5.bias'] -#crt_net['model.10.weight'] = pretrained_net['model.7.weight'] -#crt_net['model.10.bias'] = pretrained_net['model.7.bias'] -#torch.save(crt_net, '../pretrained_tmp.pth') - -# x2 -> x3 -''' -in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3 -new_filter = torch.Tensor(576, 64, 3, 3) -new_filter[0:256, :, :, :] = in_filter -new_filter[256:512, :, :, :] = in_filter -new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :] -crt_net['model.2.weight'] = new_filter - -in_bias = pretrained_net['model.2.bias'] # 256, 64, 3, 3 -new_bias = torch.Tensor(576) -new_bias[0:256] = in_bias -new_bias[256:512] = in_bias -new_bias[512:] = in_bias[0:576 - 512] -crt_net['model.2.bias'] = new_bias - -torch.save(crt_net, '../pretrained_tmp.pth') -''' - -# x2 -> x8 -''' -crt_net['model.5.weight'] = pretrained_net['model.2.weight'] -crt_net['model.5.bias'] = pretrained_net['model.2.bias'] -crt_net['model.8.weight'] = pretrained_net['model.2.weight'] -crt_net['model.8.bias'] = pretrained_net['model.2.bias'] -crt_net['model.11.weight'] = pretrained_net['model.5.weight'] -crt_net['model.11.bias'] = pretrained_net['model.5.bias'] -crt_net['model.13.weight'] = pretrained_net['model.7.weight'] -crt_net['model.13.bias'] = pretrained_net['model.7.bias'] -torch.save(crt_net, '../pretrained_tmp.pth') -''' - -# x3/4/8 RGB -> Y - -def rgb2gray_net(net, only_input=True): - - if only_input: - in_filter = net['0.weight'] - in_new_filter = in_filter[:,0,:,:]*0.2989 + in_filter[:,1,:,:]*0.587 + in_filter[:,2,:,:]*0.114 - in_new_filter.unsqueeze_(1) - net['0.weight'] = in_new_filter - -# out_filter = pretrained_net['model.13.weight'] -# out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \ -# out_filter[2, :, :, :] * 0.114 -# out_new_filter.unsqueeze_(0) -# crt_net['model.13.weight'] = out_new_filter -# out_bias = pretrained_net['model.13.bias'] -# out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114 -# out_new_bias = torch.Tensor(1).fill_(out_new_bias) -# crt_net['model.13.bias'] = out_new_bias - -# torch.save(crt_net, '../pretrained_tmp.pth') - - return net - - - -if __name__ == '__main__': - - net = torchvision.models.vgg19(pretrained=True) - for k,v in net.features.named_parameters(): - if k=='0.weight': - in_new_filter = v[:,0,:,:]*0.2989 + v[:,1,:,:]*0.587 + v[:,2,:,:]*0.114 - in_new_filter.unsqueeze_(1) - v = in_new_filter - print(v.shape) - print(v[0,0,0,0]) - if k=='0.bias': - in_new_bias = v - print(v[0]) - - print(net.features[0]) - - net.features[0] = B.conv(1, 64, mode='C') - - print(net.features[0]) - net.features[0].weight.data=in_new_filter - net.features[0].bias.data=in_new_bias - - for k,v in net.features.named_parameters(): - if k=='0.weight': - print(v[0,0,0,0]) - if k=='0.bias': - print(v[0]) - - # transfer parameters of old model to new one - model_old = torch.load(model_path) - state_dict = model.state_dict() - for ((key, param),(key2, param2)) in zip(model_old.items(), state_dict.items()): - state_dict[key2] = param - print([key, key2]) - # print([param.size(), param2.size()]) - torch.save(state_dict, 'model_new.pth') - - - # rgb2gray_net(net) - - - - - - - - - diff --git a/core/data/deg_kair_utils/utils_receptivefield.py b/core/data/deg_kair_utils/utils_receptivefield.py deleted file mode 100644 index 82ad613b9e744189e13b721a558dbc0f42c57b30..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_receptivefield.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- - -# online calculation: https://fomoro.com/research/article/receptive-field-calculator# - -# [filter size, stride, padding] -#Assume the two dimensions are the same -#Each kernel requires the following parameters: -# - k_i: kernel size -# - s_i: stride -# - p_i: padding (if padding is uneven, right padding will higher than left padding; "SAME" option in tensorflow) -# -#Each layer i requires the following parameters to be fully represented: -# - n_i: number of feature (data layer has n_1 = imagesize ) -# - j_i: distance (projected to image pixel distance) between center of two adjacent features -# - r_i: receptive field of a feature in layer i -# - start_i: position of the first feature's receptive field in layer i (idx start from 0, negative means the center fall into padding) - -import math - -def outFromIn(conv, layerIn): - n_in = layerIn[0] - j_in = layerIn[1] - r_in = layerIn[2] - start_in = layerIn[3] - k = conv[0] - s = conv[1] - p = conv[2] - - n_out = math.floor((n_in - k + 2*p)/s) + 1 - actualP = (n_out-1)*s - n_in + k - pR = math.ceil(actualP/2) - pL = math.floor(actualP/2) - - j_out = j_in * s - r_out = r_in + (k - 1)*j_in - start_out = start_in + ((k-1)/2 - pL)*j_in - return n_out, j_out, r_out, start_out - -def printLayer(layer, layer_name): - print(layer_name + ":") - print(" n features: %s jump: %s receptive size: %s start: %s " % (layer[0], layer[1], layer[2], layer[3])) - - - -layerInfos = [] -if __name__ == '__main__': - - convnet = [[3,1,1],[3,1,1],[3,1,1],[4,2,1],[2,2,0],[3,1,1]] - layer_names = ['conv1','conv2','conv3','conv4','conv5','conv6','conv7','conv8','conv9','conv10','conv11','conv12'] - imsize = 128 - - print ("-------Net summary------") - currentLayer = [imsize, 1, 1, 0.5] - printLayer(currentLayer, "input image") - for i in range(len(convnet)): - currentLayer = outFromIn(convnet[i], currentLayer) - layerInfos.append(currentLayer) - printLayer(currentLayer, layer_names[i]) - - -# run utils/utils_receptivefield.py - \ No newline at end of file diff --git a/core/data/deg_kair_utils/utils_regularizers.py b/core/data/deg_kair_utils/utils_regularizers.py deleted file mode 100644 index 17e7c8524b716f36e10b41d72fee2e375af69454..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_regularizers.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -import torch.nn as nn - - -''' -# -------------------------------------------- -# Kai Zhang (github: https://github.com/cszn) -# 03/Mar/2019 -# -------------------------------------------- -''' - - -# -------------------------------------------- -# SVD Orthogonal Regularization -# -------------------------------------------- -def regularizer_orth(m): - """ - # ---------------------------------------- - # SVD Orthogonal Regularization - # ---------------------------------------- - # Applies regularization to the training by performing the - # orthogonalization technique described in the paper - # This function is to be called by the torch.nn.Module.apply() method, - # which applies svd_orthogonalization() to every layer of the model. - # usage: net.apply(regularizer_orth) - # ---------------------------------------- - """ - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - w = m.weight.data.clone() - c_out, c_in, f1, f2 = w.size() - # dtype = m.weight.data.type() - w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) - # self.netG.apply(svd_orthogonalization) - u, s, v = torch.svd(w) - s[s > 1.5] = s[s > 1.5] - 1e-4 - s[s < 0.5] = s[s < 0.5] + 1e-4 - w = torch.mm(torch.mm(u, torch.diag(s)), v.t()) - m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype) - else: - pass - - -# -------------------------------------------- -# SVD Orthogonal Regularization -# -------------------------------------------- -def regularizer_orth2(m): - """ - # ---------------------------------------- - # Applies regularization to the training by performing the - # orthogonalization technique described in the paper - # This function is to be called by the torch.nn.Module.apply() method, - # which applies svd_orthogonalization() to every layer of the model. - # usage: net.apply(regularizer_orth2) - # ---------------------------------------- - """ - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - w = m.weight.data.clone() - c_out, c_in, f1, f2 = w.size() - # dtype = m.weight.data.type() - w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) - u, s, v = torch.svd(w) - s_mean = s.mean() - s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4 - s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4 - w = torch.mm(torch.mm(u, torch.diag(s)), v.t()) - m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1) # .type(dtype) - else: - pass - - - -def regularizer_clip(m): - """ - # ---------------------------------------- - # usage: net.apply(regularizer_clip) - # ---------------------------------------- - """ - eps = 1e-4 - c_min = -1.5 - c_max = 1.5 - - classname = m.__class__.__name__ - if classname.find('Conv') != -1 or classname.find('Linear') != -1: - w = m.weight.data.clone() - w[w > c_max] -= eps - w[w < c_min] += eps - m.weight.data = w - - if m.bias is not None: - b = m.bias.data.clone() - b[b > c_max] -= eps - b[b < c_min] += eps - m.bias.data = b - -# elif classname.find('BatchNorm2d') != -1: -# -# rv = m.running_var.data.clone() -# rm = m.running_mean.data.clone() -# -# if m.affine: -# m.weight.data -# m.bias.data diff --git a/core/data/deg_kair_utils/utils_sisr.py b/core/data/deg_kair_utils/utils_sisr.py deleted file mode 100644 index e9edbd72ce53351d9e306c9774073a0e2eb0bdb3..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_sisr.py +++ /dev/null @@ -1,848 +0,0 @@ -# -*- coding: utf-8 -*- -from utils import utils_image as util -import random - -import scipy -import scipy.stats as ss -import scipy.io as io -from scipy import ndimage -from scipy.interpolate import interp2d - -import numpy as np -import torch - - -""" -# -------------------------------------------- -# Super-Resolution -# -------------------------------------------- -# -# Kai Zhang (cskaizhang@gmail.com) -# https://github.com/cszn -# modified by Kai Zhang (github: https://github.com/cszn) -# 03/03/2020 -# -------------------------------------------- -""" - - -""" -# -------------------------------------------- -# anisotropic Gaussian kernels -# -------------------------------------------- -""" - - -def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel - Args: - ksize : e.g., 15, kernel size - theta : [0, pi], rotation angle range - l1 : [0.1,50], scaling of eigenvalues - l2 : [0.1,l1], scaling of eigenvalues - If l1 = l2, will get an isotropic Gaussian kernel. - Returns: - k : kernel - """ - - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) - V = np.array([[v[0], v[1]], [v[1], -v[0]]]) - D = np.array([[l1, 0], [0, l2]]) - Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k - - -def gm_blur_kernel(mean, cov, size=15): - center = size / 2.0 + 0.5 - k = np.zeros([size, size]) - for y in range(size): - for x in range(size): - cy = y - center + 1 - cx = x - center + 1 - k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) - - k = k / np.sum(k) - return k - - -""" -# -------------------------------------------- -# calculate PCA projection matrix -# -------------------------------------------- -""" - - -def get_pca_matrix(x, dim_pca=15): - """ - Args: - x: 225x10000 matrix - dim_pca: 15 - Returns: - pca_matrix: 15x225 - """ - C = np.dot(x, x.T) - w, v = scipy.linalg.eigh(C) - pca_matrix = v[:, -dim_pca:].T - - return pca_matrix - - -def show_pca(x): - """ - x: PCA projection matrix, e.g., 15x225 - """ - for i in range(x.shape[0]): - xc = np.reshape(x[i, :], (int(np.sqrt(x.shape[1])), -1), order="F") - util.surf(xc) - - -def cal_pca_matrix(path='PCA_matrix.mat', ksize=15, l_max=12.0, dim_pca=15, num_samples=500): - kernels = np.zeros([ksize*ksize, num_samples], dtype=np.float32) - for i in range(num_samples): - - theta = np.pi*np.random.rand(1) - l1 = 0.1+l_max*np.random.rand(1) - l2 = 0.1+(l1-0.1)*np.random.rand(1) - - k = anisotropic_Gaussian(ksize=ksize, theta=theta[0], l1=l1[0], l2=l2[0]) - - # util.imshow(k) - - kernels[:, i] = np.reshape(k, (-1), order="F") # k.flatten(order='F') - - # io.savemat('k.mat', {'k': kernels}) - - pca_matrix = get_pca_matrix(kernels, dim_pca=dim_pca) - - io.savemat(path, {'p': pca_matrix}) - - return pca_matrix - - -""" -# -------------------------------------------- -# shifted anisotropic Gaussian kernels -# -------------------------------------------- -""" - - -def shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" - # modified version of https://github.com/assafshocher/BlindSR_dataset_generator - # Kai Zhang - # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var - # max_var = 2.5 * sf - """ - # Set random eigen-vals (lambdas) and angle (theta) for COV matrix - lambda_1 = min_var + np.random.rand() * (max_var - min_var) - lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta - noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 - - # Set COV matrix using Lambdas and Theta - LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) - SIGMA = Q @ LAMBDA @ Q.T - INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] - - # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) - MU = MU[None, None, :, None] - - # Create meshgrid for Gaussian - [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) - Z = np.stack([X, Y], 2)[:, :, :, None] - - # Calcualte Gaussian for every pixel of the kernel - ZZ = Z-MU - ZZ_t = ZZ.transpose(0,1,3,2) - raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - - # shift the kernel so it will be centered - #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel - - -def gen_kernel(k_size=np.array([25, 25]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=12., noise_level=0): - """" - # modified version of https://github.com/assafshocher/BlindSR_dataset_generator - # Kai Zhang - # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var - # max_var = 2.5 * sf - """ - sf = random.choice([1, 2, 3, 4]) - scale_factor = np.array([sf, sf]) - # Set random eigen-vals (lambdas) and angle (theta) for COV matrix - lambda_1 = min_var + np.random.rand() * (max_var - min_var) - lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta - noise = 0#-noise_level + np.random.rand(*k_size) * noise_level * 2 - - # Set COV matrix using Lambdas and Theta - LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) - SIGMA = Q @ LAMBDA @ Q.T - INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] - - # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) - MU = MU[None, None, :, None] - - # Create meshgrid for Gaussian - [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) - Z = np.stack([X, Y], 2)[:, :, :, None] - - # Calcualte Gaussian for every pixel of the kernel - ZZ = Z-MU - ZZ_t = ZZ.transpose(0,1,3,2) - raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - - # shift the kernel so it will be centered - #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel - - -""" -# -------------------------------------------- -# degradation models -# -------------------------------------------- -""" - - -def bicubic_degradation(x, sf=3): - ''' - Args: - x: HxWxC image, [0, 1] - sf: down-scale factor - Return: - bicubicly downsampled LR image - ''' - x = util.imresize_np(x, scale=1/sf) - return x - - -def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2018learning, - title={Learning a single convolutional super-resolution network for multiple degradations}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={3262--3271}, - year={2018} - } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' - x = bicubic_degradation(x, sf=sf) - return x - - -def dpsr_degradation(x, k, sf=3): - - ''' bicubic downsampling + blur - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2019deep, - title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={1671--1681}, - year={2019} - } - ''' - x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - return x - - -def classical_degradation(x, k, sf=3): - ''' blur + downsampling - - Args: - x: HxWxC image, [0, 1]/[0, 255] - k: hxw, double - sf: down-scale factor - - Return: - downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) - st = 0 - return x[st::sf, st::sf, ...] - - -def modcrop_np(img, sf): - ''' - Args: - img: numpy image, WxH or WxHxC - sf: scale factor - Return: - cropped image - ''' - w, h = img.shape[:2] - im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] - - -''' -# ================= -# Numpy -# ================= -''' - - -def shift_pixel(x, sf, upper_left=True): - """shift pixel for super-resolution with different scale factors - Args: - x: WxHxC or WxH, image or kernel - sf: scale factor - upper_left: shift direction - """ - h, w = x.shape[:2] - shift = (sf-1)*0.5 - xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) - if upper_left: - x1 = xv + shift - y1 = yv + shift - else: - x1 = xv - shift - y1 = yv - shift - - x1 = np.clip(x1, 0, w-1) - y1 = np.clip(y1, 0, h-1) - - if x.ndim == 2: - x = interp2d(xv, yv, x)(x1, y1) - if x.ndim == 3: - for i in range(x.shape[-1]): - x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) - - return x - - -''' -# ================= -# pytorch -# ================= -''' - - -def splits(a, sf): - ''' - a: tensor NxCxWxHx2 - sf: scale factor - out: tensor NxCx(W/sf)x(H/sf)x2x(sf^2) - ''' - b = torch.stack(torch.chunk(a, sf, dim=2), dim=5) - b = torch.cat(torch.chunk(b, sf, dim=3), dim=5) - return b - - -def c2c(x): - return torch.from_numpy(np.stack([np.float32(x.real), np.float32(x.imag)], axis=-1)) - - -def r2c(x): - return torch.stack([x, torch.zeros_like(x)], -1) - - -def cdiv(x, y): - a, b = x[..., 0], x[..., 1] - c, d = y[..., 0], y[..., 1] - cd2 = c**2 + d**2 - return torch.stack([(a*c+b*d)/cd2, (b*c-a*d)/cd2], -1) - - -def csum(x, y): - return torch.stack([x[..., 0] + y, x[..., 1]], -1) - - -def cabs(x): - return torch.pow(x[..., 0]**2+x[..., 1]**2, 0.5) - - -def cmul(t1, t2): - ''' - complex multiplication - t1: NxCxHxWx2 - output: NxCxHxWx2 - ''' - real1, imag1 = t1[..., 0], t1[..., 1] - real2, imag2 = t2[..., 0], t2[..., 1] - return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim=-1) - - -def cconj(t, inplace=False): - ''' - # complex's conjugation - t: NxCxHxWx2 - output: NxCxHxWx2 - ''' - c = t.clone() if not inplace else t - c[..., 1] *= -1 - return c - - -def rfft(t): - return torch.rfft(t, 2, onesided=False) - - -def irfft(t): - return torch.irfft(t, 2, onesided=False) - - -def fft(t): - return torch.fft(t, 2) - - -def ifft(t): - return torch.ifft(t, 2) - - -def p2o(psf, shape): - ''' - Args: - psf: NxCxhxw - shape: [H,W] - - Returns: - otf: NxCxHxWx2 - ''' - otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) - otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) - for axis, axis_size in enumerate(psf.shape[2:]): - otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) - otf = torch.rfft(otf, 2, onesided=False) - n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) - otf[...,1][torch.abs(otf[...,1]) x[N, 1, W + 2 pad, H + 2 pad] (pariodic padding) - ''' - x = torch.cat([x, x[:, :, 0:pad, :]], dim=2) - x = torch.cat([x, x[:, :, :, 0:pad]], dim=3) - x = torch.cat([x[:, :, -2 * pad:-pad, :], x], dim=2) - x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3) - return x - - -def pad_circular(input, padding): - # type: (Tensor, List[int]) -> Tensor - """ - Arguments - :param input: tensor of shape :math:`(N, C_{\text{in}}, H, [W, D]))` - :param padding: (tuple): m-elem tuple where m is the degree of convolution - Returns - :return: tensor of shape :math:`(N, C_{\text{in}}, [D + 2 * padding[0], - H + 2 * padding[1]], W + 2 * padding[2]))` - """ - offset = 3 - for dimension in range(input.dim() - offset + 1): - input = dim_pad_circular(input, padding[dimension], dimension + offset) - return input - - -def dim_pad_circular(input, padding, dimension): - # type: (Tensor, int, int) -> Tensor - input = torch.cat([input, input[[slice(None)] * (dimension - 1) + - [slice(0, padding)]]], dim=dimension - 1) - input = torch.cat([input[[slice(None)] * (dimension - 1) + - [slice(-2 * padding, -padding)]], input], dim=dimension - 1) - return input - - -def imfilter(x, k): - ''' - x: image, NxcxHxW - k: kernel, cx1xhxw - ''' - x = pad_circular(x, padding=((k.shape[-2]-1)//2, (k.shape[-1]-1)//2)) - x = torch.nn.functional.conv2d(x, k, groups=x.shape[1]) - return x - - -def G(x, k, sf=3, center=False): - ''' - x: image, NxcxHxW - k: kernel, cx1xhxw - sf: scale factor - center: the first one or the moddle one - - Matlab function: - tmp = imfilter(x,h,'circular'); - y = downsample2(tmp,K); - ''' - x = downsample(imfilter(x, k), sf=sf, center=center) - return x - - -def Gt(x, k, sf=3, center=False): - ''' - x: image, NxcxHxW - k: kernel, cx1xhxw - sf: scale factor - center: the first one or the moddle one - - Matlab function: - tmp = upsample2(x,K); - y = imfilter(tmp,h,'circular'); - ''' - x = imfilter(upsample(x, sf=sf, center=center), k) - return x - - -def interpolation_down(x, sf, center=False): - mask = torch.zeros_like(x) - if center: - start = torch.tensor((sf-1)//2) - mask[..., start::sf, start::sf] = torch.tensor(1).type_as(x) - LR = x[..., start::sf, start::sf] - else: - mask[..., ::sf, ::sf] = torch.tensor(1).type_as(x) - LR = x[..., ::sf, ::sf] - y = x.mul(mask) - - return LR, y, mask - - -''' -# ================= -Numpy -# ================= -''' - - -def blockproc(im, blocksize, fun): - xblocks = np.split(im, range(blocksize[0], im.shape[0], blocksize[0]), axis=0) - xblocks_proc = [] - for xb in xblocks: - yblocks = np.split(xb, range(blocksize[1], im.shape[1], blocksize[1]), axis=1) - yblocks_proc = [] - for yb in yblocks: - yb_proc = fun(yb) - yblocks_proc.append(yb_proc) - xblocks_proc.append(np.concatenate(yblocks_proc, axis=1)) - - proc = np.concatenate(xblocks_proc, axis=0) - - return proc - - -def fun_reshape(a): - return np.reshape(a, (-1,1,a.shape[-1]), order='F') - - -def fun_mul(a, b): - return a*b - - -def BlockMM(nr, nc, Nb, m, x1): - ''' - myfun = @(block_struct) reshape(block_struct.data,m,1); - x1 = blockproc(x1,[nr nc],myfun); - x1 = reshape(x1,m,Nb); - x1 = sum(x1,2); - x = reshape(x1,nr,nc); - ''' - fun = fun_reshape - x1 = blockproc(x1, blocksize=(nr, nc), fun=fun) - x1 = np.reshape(x1, (m, Nb, x1.shape[-1]), order='F') - x1 = np.sum(x1, 1) - x = np.reshape(x1, (nr, nc, x1.shape[-1]), order='F') - return x - - -def INVLS(FB, FBC, F2B, FR, tau, Nb, nr, nc, m): - ''' - x1 = FB.*FR; - FBR = BlockMM(nr,nc,Nb,m,x1); - invW = BlockMM(nr,nc,Nb,m,F2B); - invWBR = FBR./(invW + tau*Nb); - fun = @(block_struct) block_struct.data.*invWBR; - FCBinvWBR = blockproc(FBC,[nr,nc],fun); - FX = (FR-FCBinvWBR)/tau; - Xest = real(ifft2(FX)); - ''' - x1 = FB*FR - FBR = BlockMM(nr, nc, Nb, m, x1) - invW = BlockMM(nr, nc, Nb, m, F2B) - invWBR = FBR/(invW + tau*Nb) - FCBinvWBR = blockproc(FBC, [nr, nc], lambda im: fun_mul(im, invWBR)) - FX = (FR-FCBinvWBR)/tau - Xest = np.real(np.fft.ifft2(FX, axes=(0, 1))) - return Xest - - -def psf2otf(psf, shape=None): - """ - Convert point-spread function to optical transfer function. - Compute the Fast Fourier Transform (FFT) of the point-spread - function (PSF) array and creates the optical transfer function (OTF) - array that is not influenced by the PSF off-centering. - By default, the OTF array is the same size as the PSF array. - To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF - post-pads the PSF array (down or to the right) with zeros to match - dimensions specified in OUTSIZE, then circularly shifts the values of - the PSF array up (or to the left) until the central pixel reaches (1,1) - position. - Parameters - ---------- - psf : `numpy.ndarray` - PSF array - shape : int - Output shape of the OTF array - Returns - ------- - otf : `numpy.ndarray` - OTF array - Notes - ----- - Adapted from MATLAB psf2otf function - """ - if type(shape) == type(None): - shape = psf.shape - shape = np.array(shape) - if np.all(psf == 0): - # return np.zeros_like(psf) - return np.zeros(shape) - if len(psf.shape) == 1: - psf = psf.reshape((1, psf.shape[0])) - inshape = psf.shape - psf = zero_pad(psf, shape, position='corner') - for axis, axis_size in enumerate(inshape): - psf = np.roll(psf, -int(axis_size / 2), axis=axis) - # Compute the OTF - otf = np.fft.fft2(psf, axes=(0, 1)) - # Estimate the rough number of operations involved in the FFT - # and discard the PSF imaginary part if within roundoff error - # roundoff error = machine epsilon = sys.float_info.epsilon - # or np.finfo().eps - n_ops = np.sum(psf.size * np.log2(psf.shape)) - otf = np.real_if_close(otf, tol=n_ops) - return otf - - -def zero_pad(image, shape, position='corner'): - """ - Extends image to a certain size with zeros - Parameters - ---------- - image: real 2d `numpy.ndarray` - Input image - shape: tuple of int - Desired output shape of the image - position : str, optional - The position of the input image in the output one: - * 'corner' - top-left corner (default) - * 'center' - centered - Returns - ------- - padded_img: real `numpy.ndarray` - The zero-padded image - """ - shape = np.asarray(shape, dtype=int) - imshape = np.asarray(image.shape, dtype=int) - if np.alltrue(imshape == shape): - return image - if np.any(shape <= 0): - raise ValueError("ZERO_PAD: null or negative shape given") - dshape = shape - imshape - if np.any(dshape < 0): - raise ValueError("ZERO_PAD: target size smaller than source one") - pad_img = np.zeros(shape, dtype=image.dtype) - idx, idy = np.indices(imshape) - if position == 'center': - if np.any(dshape % 2 != 0): - raise ValueError("ZERO_PAD: source and target shapes " - "have different parity.") - offx, offy = dshape // 2 - else: - offx, offy = (0, 0) - pad_img[idx + offx, idy + offy] = image - return pad_img - - -def upsample_np(x, sf=3, center=False): - st = (sf-1)//2 if center else 0 - z = np.zeros((x.shape[0]*sf, x.shape[1]*sf, x.shape[2])) - z[st::sf, st::sf, ...] = x - return z - - -def downsample_np(x, sf=3, center=False): - st = (sf-1)//2 if center else 0 - return x[st::sf, st::sf, ...] - - -def imfilter_np(x, k): - ''' - x: image, NxcxHxW - k: kernel, cx1xhxw - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - return x - - -def G_np(x, k, sf=3, center=False): - ''' - x: image, NxcxHxW - k: kernel, cx1xhxw - - Matlab function: - tmp = imfilter(x,h,'circular'); - y = downsample2(tmp,K); - ''' - x = downsample_np(imfilter_np(x, k), sf=sf, center=center) - return x - - -def Gt_np(x, k, sf=3, center=False): - ''' - x: image, NxcxHxW - k: kernel, cx1xhxw - - Matlab function: - tmp = upsample2(x,K); - y = imfilter(tmp,h,'circular'); - ''' - x = imfilter_np(upsample_np(x, sf=sf, center=center), k) - return x - - -if __name__ == '__main__': - img = util.imread_uint('test.bmp', 3) - - img = util.uint2single(img) - k = anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6) - util.imshow(k*10) - - - for sf in [2, 3, 4]: - - # modcrop - img = modcrop_np(img, sf=sf) - - # 1) bicubic degradation - img_b = bicubic_degradation(img, sf=sf) - print(img_b.shape) - - # 2) srmd degradation - img_s = srmd_degradation(img, k, sf=sf) - print(img_s.shape) - - # 3) dpsr degradation - img_d = dpsr_degradation(img, k, sf=sf) - print(img_d.shape) - - # 4) classical degradation - img_d = classical_degradation(img, k, sf=sf) - print(img_d.shape) - - k = anisotropic_Gaussian(ksize=7, theta=0.25*np.pi, l1=0.01, l2=0.01) - #print(k) -# util.imshow(k*10) - - k = shifted_anisotropic_Gaussian(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.8, max_var=10.8, noise_level=0.0) -# util.imshow(k*10) - - - # PCA -# pca_matrix = cal_pca_matrix(ksize=15, l_max=10.0, dim_pca=15, num_samples=12500) -# print(pca_matrix.shape) -# show_pca(pca_matrix) - # run utils/utils_sisr.py - # run utils_sisr.py - - - - - - - diff --git a/core/data/deg_kair_utils/utils_video.py b/core/data/deg_kair_utils/utils_video.py deleted file mode 100644 index 596dd4203098cf7b36f3d8499ccbf299623381ae..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_video.py +++ /dev/null @@ -1,493 +0,0 @@ -import os -import cv2 -import numpy as np -import torch -import random -from os import path as osp -from torch.nn import functional as F -from abc import ABCMeta, abstractmethod - - -def scandir(dir_path, suffix=None, recursive=False, full_path=False): - """Scan a directory to find the interested files. - - Args: - dir_path (str): Path of the directory. - suffix (str | tuple(str), optional): File suffix that we are - interested in. Default: None. - recursive (bool, optional): If set to True, recursively scan the - directory. Default: False. - full_path (bool, optional): If set to True, include the dir_path. - Default: False. - - Returns: - A generator for all the interested files with relative paths. - """ - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - root = dir_path - - def _scandir(dir_path, suffix, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - if full_path: - return_path = entry.path - else: - return_path = osp.relpath(entry.path, root) - - if suffix is None: - yield return_path - elif return_path.endswith(suffix): - yield return_path - else: - if recursive: - yield from _scandir(entry.path, suffix=suffix, recursive=recursive) - else: - continue - - return _scandir(dir_path, suffix=suffix, recursive=recursive) - - -def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False): - """Read a sequence of images from a given folder path. - - Args: - path (list[str] | str): List of image paths or image folder path. - require_mod_crop (bool): Require mod crop for each image. - Default: False. - scale (int): Scale factor for mod_crop. Default: 1. - return_imgname(bool): Whether return image names. Default False. - - Returns: - Tensor: size (t, c, h, w), RGB, [0, 1]. - list[str]: Returned image name list. - """ - if isinstance(path, list): - img_paths = path - else: - img_paths = sorted(list(scandir(path, full_path=True))) - imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] - - if require_mod_crop: - imgs = [mod_crop(img, scale) for img in imgs] - imgs = img2tensor(imgs, bgr2rgb=True, float32=True) - imgs = torch.stack(imgs, dim=0) - - if return_imgname: - imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths] - return imgs, imgnames - else: - return imgs - - -def img2tensor(imgs, bgr2rgb=True, float32=True): - """Numpy array to tensor. - - Args: - imgs (list[ndarray] | ndarray): Input images. - bgr2rgb (bool): Whether to change bgr to rgb. - float32 (bool): Whether to change to float32. - - Returns: - list[tensor] | tensor: Tensor images. If returned results only have - one element, just return tensor. - """ - - def _totensor(img, bgr2rgb, float32): - if img.shape[2] == 3 and bgr2rgb: - if img.dtype == 'float64': - img = img.astype('float32') - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(img.transpose(2, 0, 1)) - if float32: - img = img.float() - return img - - if isinstance(imgs, list): - return [_totensor(img, bgr2rgb, float32) for img in imgs] - else: - return _totensor(imgs, bgr2rgb, float32) - - -def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): - """Convert torch Tensors into image numpy arrays. - - After clamping to [min, max], values will be normalized to [0, 1]. - - Args: - tensor (Tensor or list[Tensor]): Accept shapes: - 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); - 2) 3D Tensor of shape (3/1 x H x W); - 3) 2D Tensor of shape (H x W). - Tensor channel should be in RGB order. - rgb2bgr (bool): Whether to change rgb to bgr. - out_type (numpy type): output types. If ``np.uint8``, transform outputs - to uint8 type with range [0, 255]; otherwise, float type with - range [0, 1]. Default: ``np.uint8``. - min_max (tuple[int]): min and max values for clamp. - - Returns: - (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of - shape (H x W). The channel order is BGR. - """ - if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): - raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') - - if torch.is_tensor(tensor): - tensor = [tensor] - result = [] - for _tensor in tensor: - _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) - _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) - - n_dim = _tensor.dim() - if n_dim == 4: - img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() - img_np = img_np.transpose(1, 2, 0) - if rgb2bgr: - img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) - elif n_dim == 3: - img_np = _tensor.numpy() - img_np = img_np.transpose(1, 2, 0) - if img_np.shape[2] == 1: # gray image - img_np = np.squeeze(img_np, axis=2) - else: - if rgb2bgr: - img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) - elif n_dim == 2: - img_np = _tensor.numpy() - else: - raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') - if out_type == np.uint8: - # Unlike MATLAB, numpy.unit8() WILL NOT round by default. - img_np = (img_np * 255.0).round() - img_np = img_np.astype(out_type) - result.append(img_np) - if len(result) == 1: - result = result[0] - return result - - -def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): - """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). - - We use vertical flip and transpose for rotation implementation. - All the images in the list use the same augmentation. - - Args: - imgs (list[ndarray] | ndarray): Images to be augmented. If the input - is an ndarray, it will be transformed to a list. - hflip (bool): Horizontal flip. Default: True. - rotation (bool): Ratotation. Default: True. - flows (list[ndarray]: Flows to be augmented. If the input is an - ndarray, it will be transformed to a list. - Dimension is (h, w, 2). Default: None. - return_status (bool): Return the status of flip and rotation. - Default: False. - - Returns: - list[ndarray] | ndarray: Augmented images and flows. If returned - results only have one element, just return ndarray. - - """ - hflip = hflip and random.random() < 0.5 - vflip = rotation and random.random() < 0.5 - rot90 = rotation and random.random() < 0.5 - - def _augment(img): - if hflip: # horizontal - cv2.flip(img, 1, img) - if vflip: # vertical - cv2.flip(img, 0, img) - if rot90: - img = img.transpose(1, 0, 2) - return img - - def _augment_flow(flow): - if hflip: # horizontal - cv2.flip(flow, 1, flow) - flow[:, :, 0] *= -1 - if vflip: # vertical - cv2.flip(flow, 0, flow) - flow[:, :, 1] *= -1 - if rot90: - flow = flow.transpose(1, 0, 2) - flow = flow[:, :, [1, 0]] - return flow - - if not isinstance(imgs, list): - imgs = [imgs] - imgs = [_augment(img) for img in imgs] - if len(imgs) == 1: - imgs = imgs[0] - - if flows is not None: - if not isinstance(flows, list): - flows = [flows] - flows = [_augment_flow(flow) for flow in flows] - if len(flows) == 1: - flows = flows[0] - return imgs, flows - else: - if return_status: - return imgs, (hflip, vflip, rot90) - else: - return imgs - - -def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): - """Paired random crop. Support Numpy array and Tensor inputs. - - It crops lists of lq and gt images with corresponding locations. - - Args: - img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images - should have the same shape. If the input is an ndarray, it will - be transformed to a list containing itself. - img_lqs (list[ndarray] | ndarray): LQ images. Note that all images - should have the same shape. If the input is an ndarray, it will - be transformed to a list containing itself. - gt_patch_size (int): GT patch size. - scale (int): Scale factor. - gt_path (str): Path to ground-truth. Default: None. - - Returns: - list[ndarray] | ndarray: GT images and LQ images. If returned results - only have one element, just return ndarray. - """ - - if not isinstance(img_gts, list): - img_gts = [img_gts] - if not isinstance(img_lqs, list): - img_lqs = [img_lqs] - - # determine input type: Numpy array or Tensor - input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' - - if input_type == 'Tensor': - h_lq, w_lq = img_lqs[0].size()[-2:] - h_gt, w_gt = img_gts[0].size()[-2:] - else: - h_lq, w_lq = img_lqs[0].shape[0:2] - h_gt, w_gt = img_gts[0].shape[0:2] - lq_patch_size = gt_patch_size // scale - - if h_gt != h_lq * scale or w_gt != w_lq * scale: - raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', - f'multiplication of LQ ({h_lq}, {w_lq}).') - if h_lq < lq_patch_size or w_lq < lq_patch_size: - raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' - f'({lq_patch_size}, {lq_patch_size}). ' - f'Please remove {gt_path}.') - - # randomly choose top and left coordinates for lq patch - top = random.randint(0, h_lq - lq_patch_size) - left = random.randint(0, w_lq - lq_patch_size) - - # crop lq patch - if input_type == 'Tensor': - img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] - else: - img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] - - # crop corresponding gt patch - top_gt, left_gt = int(top * scale), int(left * scale) - if input_type == 'Tensor': - img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] - else: - img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] - if len(img_gts) == 1: - img_gts = img_gts[0] - if len(img_lqs) == 1: - img_lqs = img_lqs[0] - return img_gts, img_lqs - - -# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 -class BaseStorageBackend(metaclass=ABCMeta): - """Abstract class of storage backends. - - All backends need to implement two apis: ``get()`` and ``get_text()``. - ``get()`` reads the file as a byte stream and ``get_text()`` reads the file - as texts. - """ - - @abstractmethod - def get(self, filepath): - pass - - @abstractmethod - def get_text(self, filepath): - pass - - -class MemcachedBackend(BaseStorageBackend): - """Memcached storage backend. - - Attributes: - server_list_cfg (str): Config file for memcached server list. - client_cfg (str): Config file for memcached client. - sys_path (str | None): Additional path to be appended to `sys.path`. - Default: None. - """ - - def __init__(self, server_list_cfg, client_cfg, sys_path=None): - if sys_path is not None: - import sys - sys.path.append(sys_path) - try: - import mc - except ImportError: - raise ImportError('Please install memcached to enable MemcachedBackend.') - - self.server_list_cfg = server_list_cfg - self.client_cfg = client_cfg - self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) - # mc.pyvector servers as a point which points to a memory cache - self._mc_buffer = mc.pyvector() - - def get(self, filepath): - filepath = str(filepath) - import mc - self._client.Get(filepath, self._mc_buffer) - value_buf = mc.ConvertBuffer(self._mc_buffer) - return value_buf - - def get_text(self, filepath): - raise NotImplementedError - - -class HardDiskBackend(BaseStorageBackend): - """Raw hard disks storage backend.""" - - def get(self, filepath): - filepath = str(filepath) - with open(filepath, 'rb') as f: - value_buf = f.read() - return value_buf - - def get_text(self, filepath): - filepath = str(filepath) - with open(filepath, 'r') as f: - value_buf = f.read() - return value_buf - - -class LmdbBackend(BaseStorageBackend): - """Lmdb storage backend. - - Args: - db_paths (str | list[str]): Lmdb database paths. - client_keys (str | list[str]): Lmdb client keys. Default: 'default'. - readonly (bool, optional): Lmdb environment parameter. If True, - disallow any write operations. Default: True. - lock (bool, optional): Lmdb environment parameter. If False, when - concurrent access occurs, do not lock the database. Default: False. - readahead (bool, optional): Lmdb environment parameter. If False, - disable the OS filesystem readahead mechanism, which may improve - random read performance when a database is larger than RAM. - Default: False. - - Attributes: - db_paths (list): Lmdb database path. - _client (list): A list of several lmdb envs. - """ - - def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): - try: - import lmdb - except ImportError: - raise ImportError('Please install lmdb to enable LmdbBackend.') - - if isinstance(client_keys, str): - client_keys = [client_keys] - - if isinstance(db_paths, list): - self.db_paths = [str(v) for v in db_paths] - elif isinstance(db_paths, str): - self.db_paths = [str(db_paths)] - assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' - f'but received {len(client_keys)} and {len(self.db_paths)}.') - - self._client = {} - for client, path in zip(client_keys, self.db_paths): - self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) - - def get(self, filepath, client_key): - """Get values according to the filepath from one lmdb named client_key. - - Args: - filepath (str | obj:`Path`): Here, filepath is the lmdb key. - client_key (str): Used for distinguishing different lmdb envs. - """ - filepath = str(filepath) - assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') - client = self._client[client_key] - with client.begin(write=False) as txn: - value_buf = txn.get(filepath.encode('ascii')) - return value_buf - - def get_text(self, filepath): - raise NotImplementedError - - -class FileClient(object): - """A general file client to access files in different backend. - - The client loads a file or text in a specified backend from its path - and return it as a binary file. it can also register other backend - accessor with a given name and backend class. - - Attributes: - backend (str): The storage backend type. Options are "disk", - "memcached" and "lmdb". - client (:obj:`BaseStorageBackend`): The backend object. - """ - - _backends = { - 'disk': HardDiskBackend, - 'memcached': MemcachedBackend, - 'lmdb': LmdbBackend, - } - - def __init__(self, backend='disk', **kwargs): - if backend not in self._backends: - raise ValueError(f'Backend {backend} is not supported. Currently supported ones' - f' are {list(self._backends.keys())}') - self.backend = backend - self.client = self._backends[backend](**kwargs) - - def get(self, filepath, client_key='default'): - # client_key is used only for lmdb, where different fileclients have - # different lmdb environments. - if self.backend == 'lmdb': - return self.client.get(filepath, client_key) - else: - return self.client.get(filepath) - - def get_text(self, filepath): - return self.client.get_text(filepath) - - -def imfrombytes(content, flag='color', float32=False): - """Read an image from bytes. - - Args: - content (bytes): Image bytes got from files or other streams. - flag (str): Flags specifying the color type of a loaded image, - candidates are `color`, `grayscale` and `unchanged`. - float32 (bool): Whether to change to float32., If True, will also norm - to [0, 1]. Default: False. - - Returns: - ndarray: Loaded image array. - """ - img_np = np.frombuffer(content, np.uint8) - imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} - img = cv2.imdecode(img_np, imread_flags[flag]) - if float32: - img = img.astype(np.float32) / 255. - return img - diff --git a/core/data/deg_kair_utils/utils_videoio.py b/core/data/deg_kair_utils/utils_videoio.py deleted file mode 100644 index 5be8c7f06802d5aaa7155a1cdcb27d2838a0882c..0000000000000000000000000000000000000000 --- a/core/data/deg_kair_utils/utils_videoio.py +++ /dev/null @@ -1,555 +0,0 @@ -import os -import cv2 -import numpy as np -import torch -import random -from os import path as osp -from torchvision.utils import make_grid -import sys -from pathlib import Path -import six -from collections import OrderedDict -import math -import glob -import av -import io -from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT, - CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH, - CAP_PROP_POS_FRAMES, VideoWriter_fourcc) - -if sys.version_info <= (3, 3): - FileNotFoundError = IOError -else: - FileNotFoundError = FileNotFoundError - - -def is_str(x): - """Whether the input is an string instance.""" - return isinstance(x, six.string_types) - - -def is_filepath(x): - return is_str(x) or isinstance(x, Path) - - -def fopen(filepath, *args, **kwargs): - if is_str(filepath): - return open(filepath, *args, **kwargs) - elif isinstance(filepath, Path): - return filepath.open(*args, **kwargs) - raise ValueError('`filepath` should be a string or a Path') - - -def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): - if not osp.isfile(filename): - raise FileNotFoundError(msg_tmpl.format(filename)) - - -def mkdir_or_exist(dir_name, mode=0o777): - if dir_name == '': - return - dir_name = osp.expanduser(dir_name) - os.makedirs(dir_name, mode=mode, exist_ok=True) - - -def symlink(src, dst, overwrite=True, **kwargs): - if os.path.lexists(dst) and overwrite: - os.remove(dst) - os.symlink(src, dst, **kwargs) - - -def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): - """Scan a directory to find the interested files. - Args: - dir_path (str | :obj:`Path`): Path of the directory. - suffix (str | tuple(str), optional): File suffix that we are - interested in. Default: None. - recursive (bool, optional): If set to True, recursively scan the - directory. Default: False. - case_sensitive (bool, optional) : If set to False, ignore the case of - suffix. Default: True. - Returns: - A generator for all the interested files with relative paths. - """ - if isinstance(dir_path, (str, Path)): - dir_path = str(dir_path) - else: - raise TypeError('"dir_path" must be a string or Path object') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - if suffix is not None and not case_sensitive: - suffix = suffix.lower() if isinstance(suffix, str) else tuple( - item.lower() for item in suffix) - - root = dir_path - - def _scandir(dir_path, suffix, recursive, case_sensitive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - _rel_path = rel_path if case_sensitive else rel_path.lower() - if suffix is None or _rel_path.endswith(suffix): - yield rel_path - elif recursive and os.path.isdir(entry.path): - # scan recursively if entry.path is a directory - yield from _scandir(entry.path, suffix, recursive, - case_sensitive) - - return _scandir(dir_path, suffix, recursive, case_sensitive) - - -class Cache: - - def __init__(self, capacity): - self._cache = OrderedDict() - self._capacity = int(capacity) - if capacity <= 0: - raise ValueError('capacity must be a positive integer') - - @property - def capacity(self): - return self._capacity - - @property - def size(self): - return len(self._cache) - - def put(self, key, val): - if key in self._cache: - return - if len(self._cache) >= self.capacity: - self._cache.popitem(last=False) - self._cache[key] = val - - def get(self, key, default=None): - val = self._cache[key] if key in self._cache else default - return val - - -class VideoReader: - """Video class with similar usage to a list object. - - This video warpper class provides convenient apis to access frames. - There exists an issue of OpenCV's VideoCapture class that jumping to a - certain frame may be inaccurate. It is fixed in this class by checking - the position after jumping each time. - Cache is used when decoding videos. So if the same frame is visited for - the second time, there is no need to decode again if it is stored in the - cache. - - """ - - def __init__(self, filename, cache_capacity=10): - # Check whether the video path is a url - if not filename.startswith(('https://', 'http://')): - check_file_exist(filename, 'Video file not found: ' + filename) - self._vcap = cv2.VideoCapture(filename) - assert cache_capacity > 0 - self._cache = Cache(cache_capacity) - self._position = 0 - # get basic info - self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH)) - self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT)) - self._fps = self._vcap.get(CAP_PROP_FPS) - self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT)) - self._fourcc = self._vcap.get(CAP_PROP_FOURCC) - - @property - def vcap(self): - """:obj:`cv2.VideoCapture`: The raw VideoCapture object.""" - return self._vcap - - @property - def opened(self): - """bool: Indicate whether the video is opened.""" - return self._vcap.isOpened() - - @property - def width(self): - """int: Width of video frames.""" - return self._width - - @property - def height(self): - """int: Height of video frames.""" - return self._height - - @property - def resolution(self): - """tuple: Video resolution (width, height).""" - return (self._width, self._height) - - @property - def fps(self): - """float: FPS of the video.""" - return self._fps - - @property - def frame_cnt(self): - """int: Total frames of the video.""" - return self._frame_cnt - - @property - def fourcc(self): - """str: "Four character code" of the video.""" - return self._fourcc - - @property - def position(self): - """int: Current cursor position, indicating frame decoded.""" - return self._position - - def _get_real_position(self): - return int(round(self._vcap.get(CAP_PROP_POS_FRAMES))) - - def _set_real_position(self, frame_id): - self._vcap.set(CAP_PROP_POS_FRAMES, frame_id) - pos = self._get_real_position() - for _ in range(frame_id - pos): - self._vcap.read() - self._position = frame_id - - def read(self): - """Read the next frame. - - If the next frame have been decoded before and in the cache, then - return it directly, otherwise decode, cache and return it. - - Returns: - ndarray or None: Return the frame if successful, otherwise None. - """ - # pos = self._position - if self._cache: - img = self._cache.get(self._position) - if img is not None: - ret = True - else: - if self._position != self._get_real_position(): - self._set_real_position(self._position) - ret, img = self._vcap.read() - if ret: - self._cache.put(self._position, img) - else: - ret, img = self._vcap.read() - if ret: - self._position += 1 - return img - - def get_frame(self, frame_id): - """Get frame by index. - - Args: - frame_id (int): Index of the expected frame, 0-based. - - Returns: - ndarray or None: Return the frame if successful, otherwise None. - """ - if frame_id < 0 or frame_id >= self._frame_cnt: - raise IndexError( - f'"frame_id" must be between 0 and {self._frame_cnt - 1}') - if frame_id == self._position: - return self.read() - if self._cache: - img = self._cache.get(frame_id) - if img is not None: - self._position = frame_id + 1 - return img - self._set_real_position(frame_id) - ret, img = self._vcap.read() - if ret: - if self._cache: - self._cache.put(self._position, img) - self._position += 1 - return img - - def current_frame(self): - """Get the current frame (frame that is just visited). - - Returns: - ndarray or None: If the video is fresh, return None, otherwise - return the frame. - """ - if self._position == 0: - return None - return self._cache.get(self._position - 1) - - def cvt2frames(self, - frame_dir, - file_start=0, - filename_tmpl='{:06d}.jpg', - start=0, - max_num=0, - show_progress=False): - """Convert a video to frame images. - - Args: - frame_dir (str): Output directory to store all the frame images. - file_start (int): Filenames will start from the specified number. - filename_tmpl (str): Filename template with the index as the - placeholder. - start (int): The starting frame index. - max_num (int): Maximum number of frames to be written. - show_progress (bool): Whether to show a progress bar. - """ - mkdir_or_exist(frame_dir) - if max_num == 0: - task_num = self.frame_cnt - start - else: - task_num = min(self.frame_cnt - start, max_num) - if task_num <= 0: - raise ValueError('start must be less than total frame number') - if start > 0: - self._set_real_position(start) - - def write_frame(file_idx): - img = self.read() - if img is None: - return - filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) - cv2.imwrite(filename, img) - - if show_progress: - pass - #track_progress(write_frame, range(file_start,file_start + task_num)) - else: - for i in range(task_num): - write_frame(file_start + i) - - def __len__(self): - return self.frame_cnt - - def __getitem__(self, index): - if isinstance(index, slice): - return [ - self.get_frame(i) - for i in range(*index.indices(self.frame_cnt)) - ] - # support negative indexing - if index < 0: - index += self.frame_cnt - if index < 0: - raise IndexError('index out of range') - return self.get_frame(index) - - def __iter__(self): - self._set_real_position(0) - return self - - def __next__(self): - img = self.read() - if img is not None: - return img - else: - raise StopIteration - - next = __next__ - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self._vcap.release() - - -def frames2video(frame_dir, - video_file, - fps=30, - fourcc='XVID', - filename_tmpl='{:06d}.jpg', - start=0, - end=0, - show_progress=False): - """Read the frame images from a directory and join them as a video. - - Args: - frame_dir (str): The directory containing video frames. - video_file (str): Output filename. - fps (float): FPS of the output video. - fourcc (str): Fourcc of the output video, this should be compatible - with the output file type. - filename_tmpl (str): Filename template with the index as the variable. - start (int): Starting frame index. - end (int): Ending frame index. - show_progress (bool): Whether to show a progress bar. - """ - if end == 0: - ext = filename_tmpl.split('.')[-1] - end = len([name for name in scandir(frame_dir, ext)]) - first_file = osp.join(frame_dir, filename_tmpl.format(start)) - check_file_exist(first_file, 'The start frame not found: ' + first_file) - img = cv2.imread(first_file) - height, width = img.shape[:2] - resolution = (width, height) - vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps, - resolution) - - def write_frame(file_idx): - filename = osp.join(frame_dir, filename_tmpl.format(file_idx)) - img = cv2.imread(filename) - vwriter.write(img) - - if show_progress: - pass - # track_progress(write_frame, range(start, end)) - else: - for i in range(start, end): - write_frame(i) - vwriter.release() - - -def video2images(video_path, output_dir): - vidcap = cv2.VideoCapture(video_path) - in_fps = vidcap.get(cv2.CAP_PROP_FPS) - print('video fps:', in_fps) - if not os.path.isdir(output_dir): - os.makedirs(output_dir) - loaded, frame = vidcap.read() - total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) - print(f'number of total frames is: {total_frames:06}') - for i_frame in range(total_frames): - if i_frame % 100 == 0: - print(f'{i_frame:06} / {total_frames:06}') - frame_name = os.path.join(output_dir, f'{i_frame:06}' + '.png') - cv2.imwrite(frame_name, frame) - loaded, frame = vidcap.read() - - -def images2video(image_dir, video_path, fps=24, image_ext='png'): - ''' - #codec = cv2.VideoWriter_fourcc(*'XVID') - #codec = cv2.VideoWriter_fourcc('A','V','C','1') - #codec = cv2.VideoWriter_fourcc('Y','U','V','1') - #codec = cv2.VideoWriter_fourcc('P','I','M','1') - #codec = cv2.VideoWriter_fourcc('M','J','P','G') - codec = cv2.VideoWriter_fourcc('M','P','4','2') - #codec = cv2.VideoWriter_fourcc('D','I','V','3') - #codec = cv2.VideoWriter_fourcc('D','I','V','X') - #codec = cv2.VideoWriter_fourcc('U','2','6','3') - #codec = cv2.VideoWriter_fourcc('I','2','6','3') - #codec = cv2.VideoWriter_fourcc('F','L','V','1') - #codec = cv2.VideoWriter_fourcc('H','2','6','4') - #codec = cv2.VideoWriter_fourcc('A','Y','U','V') - #codec = cv2.VideoWriter_fourcc('I','U','Y','V') - ç¼–ç å™¨å¸¸ç”¨çš„几ç§ï¼š - cv2.VideoWriter_fourcc("I", "4", "2", "0") - 压缩的yuv颜色编ç å™¨ï¼Œ4:2:0色彩度å­é‡‡æ · 兼容性好,产生很大的视频 avi - cv2.VideoWriter_fourcc("P", I", "M", "1") - 采用mpeg-1ç¼–ç ï¼Œæ–‡ä»¶ä¸ºavi - cv2.VideoWriter_fourcc("X", "V", "T", "D") - 采用mpeg-4ç¼–ç ï¼Œå¾—到视频大å°å¹³å‡ 拓展åavi - cv2.VideoWriter_fourcc("T", "H", "E", "O") - Ogg Vorbis, 拓展å为ogv - cv2.VideoWriter_fourcc("F", "L", "V", "1") - FLASH视频,拓展å为.flv - ''' - image_files = sorted(glob.glob(os.path.join(image_dir, '*.{}'.format(image_ext)))) - print(len(image_files)) - height, width, _ = cv2.imread(image_files[0]).shape - out_fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') # cv2.VideoWriter_fourcc(*'MP4V') - out_video = cv2.VideoWriter(video_path, out_fourcc, fps, (width, height)) - - for image_file in image_files: - img = cv2.imread(image_file) - img = cv2.resize(img, (width, height), interpolation=3) - out_video.write(img) - out_video.release() - - -def add_video_compression(imgs): - codec_type = ['libx264', 'h264', 'mpeg4'] - codec_prob = [1 / 3., 1 / 3., 1 / 3.] - codec = random.choices(codec_type, codec_prob)[0] - # codec = 'mpeg4' - bitrate = [1e4, 1e5] - bitrate = np.random.randint(bitrate[0], bitrate[1] + 1) - - buf = io.BytesIO() - with av.open(buf, 'w', 'mp4') as container: - stream = container.add_stream(codec, rate=1) - stream.height = imgs[0].shape[0] - stream.width = imgs[0].shape[1] - stream.pix_fmt = 'yuv420p' - stream.bit_rate = bitrate - - for img in imgs: - img = np.uint8((img.clip(0, 1)*255.).round()) - frame = av.VideoFrame.from_ndarray(img, format='rgb24') - frame.pict_type = 'NONE' - # pdb.set_trace() - for packet in stream.encode(frame): - container.mux(packet) - - # Flush stream - for packet in stream.encode(): - container.mux(packet) - - outputs = [] - with av.open(buf, 'r', 'mp4') as container: - if container.streams.video: - for frame in container.decode(**{'video': 0}): - outputs.append( - frame.to_rgb().to_ndarray().astype(np.float32) / 255.) - - #outputs = np.stack(outputs, axis=0) - return outputs - - -if __name__ == '__main__': - - # ----------------------------------- - # test VideoReader(filename, cache_capacity=10) - # ----------------------------------- -# video_reader = VideoReader('utils/test.mp4') -# from utils import utils_image as util -# inputs = [] -# for frame in video_reader: -# print(frame.dtype) -# util.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) -# #util.imshow(np.flip(frame, axis=2)) - - # ----------------------------------- - # test video2images(video_path, output_dir) - # ----------------------------------- -# video2images('utils/test.mp4', 'frames') - - # ----------------------------------- - # test images2video(image_dir, video_path, fps=24, image_ext='png') - # ----------------------------------- -# images2video('frames', 'video_02.mp4', fps=30, image_ext='png') - - - # ----------------------------------- - # test frames2video(frame_dir, video_file, fps=30, fourcc='XVID', filename_tmpl='{:06d}.png') - # ----------------------------------- -# frames2video('frames', 'video_01.mp4', filename_tmpl='{:06d}.png') - - - # ----------------------------------- - # test add_video_compression(imgs) - # ----------------------------------- -# imgs = [] -# image_ext = 'png' -# frames = 'frames' -# from utils import utils_image as util -# image_files = sorted(glob.glob(os.path.join(frames, '*.{}'.format(image_ext)))) -# for i, image_file in enumerate(image_files): -# if i < 7: -# img = util.imread_uint(image_file, 3) -# img = util.uint2single(img) -# imgs.append(img) -# -# results = add_video_compression(imgs) -# for i, img in enumerate(results): -# util.imshow(util.single2uint(img)) -# util.imsave(util.single2uint(img),f'{i:05}.png') - - # run utils/utils_video.py - - - - - - - diff --git a/core/scripts/__init__.py b/core/scripts/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/core/scripts/cli.py b/core/scripts/cli.py deleted file mode 100644 index bfe3ecc330ecf9f0b3af1e7dc6b3758673712cc7..0000000000000000000000000000000000000000 --- a/core/scripts/cli.py +++ /dev/null @@ -1,41 +0,0 @@ -import sys -import argparse -from .. import WarpCore -from .. import templates - - -def template_init(args): - return '''' - - - '''.strip() - - -def init_template(args): - parser = argparse.ArgumentParser(description='WarpCore template init tool') - parser.add_argument('-t', '--template', type=str, default='WarpCore') - args = parser.parse_args(args) - - if args.template == 'WarpCore': - template_cls = WarpCore - else: - try: - template_cls = __import__(args.template) - except ModuleNotFoundError: - template_cls = getattr(templates, args.template) - print(template_cls) - - -def main(): - if len(sys.argv) < 2: - print('Usage: core ') - sys.exit(1) - if sys.argv[1] == 'init': - init_template(sys.argv[2:]) - else: - print('Unknown command') - sys.exit(1) - - -if __name__ == '__main__': - main() diff --git a/core/templates/__init__.py b/core/templates/__init__.py deleted file mode 100644 index 570f16de78bcce68aa49ff0a5d0fad63284f6948..0000000000000000000000000000000000000000 --- a/core/templates/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .diffusion import DiffusionCore \ No newline at end of file diff --git a/core/templates/diffusion.py b/core/templates/diffusion.py deleted file mode 100644 index f36dc3f5efa14669cc36cc3c0cffcc8def037289..0000000000000000000000000000000000000000 --- a/core/templates/diffusion.py +++ /dev/null @@ -1,236 +0,0 @@ -from .. import WarpCore -from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary -from abc import abstractmethod -from dataclasses import dataclass -import torch -from torch import nn -from torch.utils.data import DataLoader -from gdf import GDF -import numpy as np -from tqdm import tqdm -import wandb - -import webdataset as wds -from webdataset.handlers import warn_and_continue -from torch.distributed import barrier -from enum import Enum - -class TargetReparametrization(Enum): - EPSILON = 'epsilon' - X0 = 'x0' - -class DiffusionCore(WarpCore): - @dataclass(frozen=True) - class Config(WarpCore.Config): - # TRAINING PARAMS - lr: float = EXPECTED_TRAIN - grad_accum_steps: int = EXPECTED_TRAIN - batch_size: int = EXPECTED_TRAIN - updates: int = EXPECTED_TRAIN - warmup_updates: int = EXPECTED_TRAIN - save_every: int = 500 - backup_every: int = 20000 - use_fsdp: bool = True - - # EMA UPDATE - ema_start_iters: int = None - ema_iters: int = None - ema_beta: float = None - - # GDF setting - gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0 - - @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED - class Info(WarpCore.Info): - ema_loss: float = None - - @dataclass(frozen=True) - class Models(WarpCore.Models): - generator : nn.Module = EXPECTED - generator_ema : nn.Module = None # optional - - @dataclass(frozen=True) - class Optimizers(WarpCore.Optimizers): - generator : any = EXPECTED - - @dataclass(frozen=True) - class Schedulers(WarpCore.Schedulers): - generator: any = None - - @dataclass(frozen=True) - class Extras(WarpCore.Extras): - gdf: GDF = EXPECTED - sampling_configs: dict = EXPECTED - - # -------------------------------------------- - info: Info - config: Config - - @abstractmethod - def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False): - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def webdataset_path(self, extras: Extras): - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def webdataset_filters(self, extras: Extras): - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def webdataset_preprocessors(self, extras: Extras): - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def sample(self, models: Models, data: WarpCore.Data, extras: Extras): - raise NotImplementedError("This method needs to be overriden") - # ------------- - - def setup_data(self, extras: Extras) -> WarpCore.Data: - # SETUP DATASET - dataset_path = self.webdataset_path(extras) - preprocessors = self.webdataset_preprocessors(extras) - filters = self.webdataset_filters(extras) - - handler = warn_and_continue # None - # handler = None - dataset = wds.WebDataset( - dataset_path, resampled=True, handler=handler - ).select(filters).shuffle(690, handler=handler).decode( - "pilrgb", handler=handler - ).to_tuple( - *[p[0] for p in preprocessors], handler=handler - ).map_tuple( - *[p[1] for p in preprocessors], handler=handler - ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)}) - - # SETUP DATALOADER - real_batch_size = self.config.batch_size//(self.world_size*self.config.grad_accum_steps) - dataloader = DataLoader( - dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True - ) - - return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader)) - - def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): - batch = next(data.iterator) - - with torch.no_grad(): - conditions = self.get_conditions(batch, models, extras) - latents = self.encode_latents(batch, models, extras) - noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) - - # FORWARD PASS - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - pred = models.generator(noised, noise_cond, **conditions) - if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON: - pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # transform whatever prediction to epsilon to use in the loss - target = noise - elif self.config.gdf_target_reparametrization == TargetReparametrization.X0: - pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # transform whatever prediction to x0 to use in the loss - target = latents - loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) - loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps - - return loss, loss_adjusted - - def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers): - start_iter = self.info.iter+1 - max_iters = self.config.updates * self.config.grad_accum_steps - if self.is_main_node: - print(f"STARTING AT STEP: {start_iter}/{max_iters}") - - pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP - models.generator.train() - for i in pbar: - # FORWARD PASS - loss, loss_adjusted = self.forward_pass(data, extras, models) - - # BACKWARD PASS - if i % self.config.grad_accum_steps == 0 or i == max_iters: - loss_adjusted.backward() - grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) - optimizers_dict = optimizers.to_dict() - for k in optimizers_dict: - optimizers_dict[k].step() - schedulers_dict = schedulers.to_dict() - for k in schedulers_dict: - schedulers_dict[k].step() - models.generator.zero_grad(set_to_none=True) - self.info.total_steps += 1 - else: - with models.generator.no_sync(): - loss_adjusted.backward() - self.info.iter = i - - # UPDATE EMA - if models.generator_ema is not None and i % self.config.ema_iters == 0: - update_weights_ema( - models.generator_ema, models.generator, - beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) - ) - - # UPDATE LOSS METRICS - self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 - - if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): - wandb.alert( - title=f"NaN value encountered in training run {self.info.wandb_run_id}", - text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}", - wait_duration=60*30 - ) - - if self.is_main_node: - logs = { - 'loss': self.info.ema_loss, - 'raw_loss': loss.mean().item(), - 'grad_norm': grad_norm.item(), - 'lr': optimizers.generator.param_groups[0]['lr'], - 'total_steps': self.info.total_steps, - } - - pbar.set_postfix(logs) - if self.config.wandb_project is not None: - wandb.log(logs) - - if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters: - # SAVE AND CHECKPOINT STUFF - if np.isnan(loss.mean().item()): - if self.is_main_node and self.config.wandb_project is not None: - tqdm.write("Skipping sampling & checkpoint because the loss is NaN") - wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") - else: - self.save_checkpoints(models, optimizers) - if self.is_main_node: - create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') - self.sample(models, data, extras) - - def models_to_save(self): - return ['generator', 'generator_ema'] - - def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): - barrier() - suffix = '' if suffix is None else suffix - self.save_info(self.info, suffix=suffix) - models_dict = models.to_dict() - optimizers_dict = optimizers.to_dict() - for key in self.models_to_save(): - model = models_dict[key] - if model is not None: - self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) - for key in optimizers_dict: - optimizer = optimizers_dict[key] - if optimizer is not None: - self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None) - if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: - self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k") - torch.cuda.empty_cache() diff --git a/core/utils/__init__.py b/core/utils/__init__.py deleted file mode 100644 index 2e71b37e8d1690a00ab1e0958320775bc822b6f5..0000000000000000000000000000000000000000 --- a/core/utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .base_dto import Base, nested_dto, EXPECTED, EXPECTED_TRAIN -from .save_and_load import create_folder_if_necessary, safe_save, load_or_fail - -# MOVE IT SOMERWHERE ELSE -def update_weights_ema(tgt_model, src_model, beta=0.999): - for self_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): - self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1-beta) - for self_buffers, src_buffers in zip(tgt_model.buffers(), src_model.buffers()): - self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1-beta) \ No newline at end of file diff --git a/core/utils/__pycache__/__init__.cpython-310.pyc b/core/utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 63c0a7e0fbf358f557d6bea755a0f550b4010a48..0000000000000000000000000000000000000000 Binary files a/core/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/core/utils/__pycache__/__init__.cpython-39.pyc b/core/utils/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 6f18d6921da3c9d93087c1b6d8eacd7a5e46a8e5..0000000000000000000000000000000000000000 Binary files a/core/utils/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/core/utils/__pycache__/base_dto.cpython-310.pyc b/core/utils/__pycache__/base_dto.cpython-310.pyc deleted file mode 100644 index de093eb65813d4abf69edfbb6923f2cabab21ad7..0000000000000000000000000000000000000000 Binary files a/core/utils/__pycache__/base_dto.cpython-310.pyc and /dev/null differ diff --git a/core/utils/__pycache__/base_dto.cpython-39.pyc b/core/utils/__pycache__/base_dto.cpython-39.pyc deleted file mode 100644 index b80d348c7959338709ec24c3ac24dfc4f6dab3dc..0000000000000000000000000000000000000000 Binary files a/core/utils/__pycache__/base_dto.cpython-39.pyc and /dev/null differ diff --git a/core/utils/__pycache__/save_and_load.cpython-310.pyc b/core/utils/__pycache__/save_and_load.cpython-310.pyc deleted file mode 100644 index a7a0f63ac8bbaf073dcd8a046ed112cec181d33a..0000000000000000000000000000000000000000 Binary files a/core/utils/__pycache__/save_and_load.cpython-310.pyc and /dev/null differ diff --git a/core/utils/__pycache__/save_and_load.cpython-39.pyc b/core/utils/__pycache__/save_and_load.cpython-39.pyc deleted file mode 100644 index ec04e9aba6f83ab76f0bbc243bb95fda07ad8d16..0000000000000000000000000000000000000000 Binary files a/core/utils/__pycache__/save_and_load.cpython-39.pyc and /dev/null differ diff --git a/core/utils/base_dto.py b/core/utils/base_dto.py deleted file mode 100644 index 7cf185f00e5c6f56d23774cec8591b8d4554971e..0000000000000000000000000000000000000000 --- a/core/utils/base_dto.py +++ /dev/null @@ -1,56 +0,0 @@ -import dataclasses -from dataclasses import dataclass, _MISSING_TYPE -from munch import Munch - -EXPECTED = "___REQUIRED___" -EXPECTED_TRAIN = "___REQUIRED_TRAIN___" - -# pylint: disable=invalid-field-call -def nested_dto(x, raw=False): - return dataclasses.field(default_factory=lambda: x if raw else Munch.fromDict(x)) - -@dataclass(frozen=True) -class Base: - training: bool = None - def __new__(cls, **kwargs): - training = kwargs.get('training', True) - setteable_fields = cls.setteable_fields(**kwargs) - mandatory_fields = cls.mandatory_fields(**kwargs) - invalid_kwargs = [ - {k: v} for k, v in kwargs.items() if k not in setteable_fields or v == EXPECTED or (v == EXPECTED_TRAIN and training is not False) - ] - print(mandatory_fields) - assert ( - len(invalid_kwargs) == 0 - ), f"Invalid fields detected when initializing this DTO: {invalid_kwargs}.\nDeclare this field and set it to None or EXPECTED in order to make it setteable." - missing_kwargs = [f for f in mandatory_fields if f not in kwargs] - assert ( - len(missing_kwargs) == 0 - ), f"Required fields missing initializing this DTO: {missing_kwargs}." - return object.__new__(cls) - - - @classmethod - def setteable_fields(cls, **kwargs): - return [f.name for f in dataclasses.fields(cls) if f.default is None or isinstance(f.default, _MISSING_TYPE) or f.default == EXPECTED or f.default == EXPECTED_TRAIN] - - @classmethod - def mandatory_fields(cls, **kwargs): - training = kwargs.get('training', True) - return [f.name for f in dataclasses.fields(cls) if isinstance(f.default, _MISSING_TYPE) and isinstance(f.default_factory, _MISSING_TYPE) or f.default == EXPECTED or (f.default == EXPECTED_TRAIN and training is not False)] - - @classmethod - def from_dict(cls, kwargs): - for k in kwargs: - if isinstance(kwargs[k], (dict, list, tuple)): - kwargs[k] = Munch.fromDict(kwargs[k]) - return cls(**kwargs) - - def to_dict(self): - # selfdict = dataclasses.asdict(self) # needs to pickle stuff, doesn't support some more complex classes - selfdict = {} - for k in dataclasses.fields(self): - selfdict[k.name] = getattr(self, k.name) - if isinstance(selfdict[k.name], Munch): - selfdict[k.name] = selfdict[k.name].toDict() - return selfdict diff --git a/core/utils/save_and_load.py b/core/utils/save_and_load.py deleted file mode 100644 index 0215f664f5a8e738147d0828b6a7e65b9c3a8507..0000000000000000000000000000000000000000 --- a/core/utils/save_and_load.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import torch -import json -from pathlib import Path -import safetensors -import wandb - - -def create_folder_if_necessary(path): - path = "/".join(path.split("/")[:-1]) - Path(path).mkdir(parents=True, exist_ok=True) - - -def safe_save(ckpt, path): - try: - os.remove(f"{path}.bak") - except OSError: - pass - try: - os.rename(path, f"{path}.bak") - except OSError: - pass - if path.endswith(".pt") or path.endswith(".ckpt"): - torch.save(ckpt, path) - elif path.endswith(".json"): - with open(path, "w", encoding="utf-8") as f: - json.dump(ckpt, f, indent=4) - elif path.endswith(".safetensors"): - safetensors.torch.save_file(ckpt, path) - else: - raise ValueError(f"File extension not supported: {path}") - - -def load_or_fail(path, wandb_run_id=None): - accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"] - try: - assert any( - [path.endswith(ext) for ext in accepted_extensions] - ), f"Automatic loading not supported for this extension: {path}" - if not os.path.exists(path): - checkpoint = None - elif path.endswith(".pt") or path.endswith(".ckpt"): - checkpoint = torch.load(path, map_location="cpu") - elif path.endswith(".json"): - with open(path, "r", encoding="utf-8") as f: - checkpoint = json.load(f) - elif path.endswith(".safetensors"): - checkpoint = {} - with safetensors.safe_open(path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - return checkpoint - except Exception as e: - if wandb_run_id is not None: - wandb.alert( - title=f"Corrupt checkpoint for run {wandb_run_id}", - text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed", - ) - raise e diff --git a/gdf/__init__.py b/gdf/__init__.py deleted file mode 100644 index 753b52e2e07e2540385594627a6faf4f6091b0a0..0000000000000000000000000000000000000000 --- a/gdf/__init__.py +++ /dev/null @@ -1,205 +0,0 @@ -import torch -from .scalers import * -from .targets import * -from .schedulers import * -from .noise_conditions import * -from .loss_weights import * -from .samplers import * -import torch.nn.functional as F -import math -class GDF(): - def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0): - self.schedule = schedule - self.input_scaler = input_scaler - self.target = target - self.noise_cond = noise_cond - self.loss_weight = loss_weight - self.offset_noise = offset_noise - - def setup_limits(self, stretch_max=True, stretch_min=True, shift=1): - stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift) - return stretched_limits - - def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None): - if epsilon is None: - epsilon = torch.randn_like(x0) - if self.offset_noise > 0: - if offset is None: - offset = torch.randn([x0.size(0), x0.size(1)] + [1]*(len(x0.shape)-2)).to(x0.device) - epsilon = epsilon + offset * self.offset_noise - logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device) - a, b = self.input_scaler(logSNR) # B - if len(a.shape) == 1: - a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) # BxCxHxW - #print('in line 33 a b', a.shape, b.shape, x0.shape, logSNR.shape, logSNR, self.noise_cond(logSNR)) - target = self.target(x0, epsilon, logSNR, a, b) - - # noised, noise, logSNR, t_cond - #noised, noise, target, logSNR, noise_cond, loss_weight - return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift) - - def undiffuse(self, x, logSNR, pred): - a, b = self.input_scaler(logSNR) - if len(a.shape) == 1: - a, b = a.view(-1, *[1]*(len(x.shape)-1)), b.view(-1, *[1]*(len(x.shape)-1)) - return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b) - - def sample(self, model, model_inputs, shape, unconditional_inputs=None, sampler=None, schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"): - sampler_params = {} if sampler_params is None else sampler_params - if sampler is None: - sampler = DDPMSampler(self) - r_range = torch.linspace(t_start, t_end, timesteps+1) - schedule = self.schedule if schedule is None else schedule - logSNR_range = schedule(r_range, shift=shift)[:, None].expand( - -1, shape[0] if x_init is None else x_init.size(0) - ).to(device) - - x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() - - if cfg is not None: - if unconditional_inputs is None: - unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} - model_inputs = { - k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor) - else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list) - else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict) - else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) - } - - for i in range(0, timesteps): - noise_cond = self.noise_cond(logSNR_range[i]) - if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): - cfg_val = cfg - if isinstance(cfg_val, (list, tuple)): - assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" - cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) - - pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2) - - pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) - if cfg_rho > 0: - std_pos, std_cfg = pred.std(), pred_cfg.std() - pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) - else: - pred = pred_cfg - else: - pred = model(x, noise_cond, **model_inputs) - x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) - x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params) - #print('in line 86', x0.shape, x.shape, i, ) - altered_vars = yield (x0, x, pred) - - # Update some running variables if the user wants - if altered_vars is not None: - cfg = altered_vars.get('cfg', cfg) - cfg_rho = altered_vars.get('cfg_rho', cfg_rho) - sampler = altered_vars.get('sampler', sampler) - model_inputs = altered_vars.get('model_inputs', model_inputs) - x = altered_vars.get('x', x) - x_init = altered_vars.get('x_init', x_init) - -class GDF_dual_fixlrt(GDF): - def ref_noise(self, noised, x0, logSNR): - a, b = self.input_scaler(logSNR) - if len(a.shape) == 1: - a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) - #print('in line 210', a.shape, b.shape, x0.shape, noised.shape) - return self.target.noise_givenx0_noised(x0, noised, logSNR, a, b) - - def sample(self, model, model_inputs, shape, shape_lr, unconditional_inputs=None, sampler=None, - schedule=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_t_stop=None, - cfg_t_start=None, cfg_rho=0.7, sampler_params=None, shift=1, device="cpu"): - sampler_params = {} if sampler_params is None else sampler_params - if sampler is None: - sampler = DDPMSampler(self) - r_range = torch.linspace(t_start, t_end, timesteps+1) - schedule = self.schedule if schedule is None else schedule - logSNR_range = schedule(r_range, shift=shift)[:, None].expand( - -1, shape[0] if x_init is None else x_init.size(0) - ).to(device) - - x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone() - x_lr = sampler.init_x(shape_lr).to(device) if x_init is None else x_init.clone() - if cfg is not None: - if unconditional_inputs is None: - unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} - model_inputs = { - k: torch.cat([v, v_u], dim=0) if isinstance(v, torch.Tensor) - else [torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None for vi, vi_u in zip(v, v_u)] if isinstance(v, list) - else {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v} if isinstance(v, dict) - else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items()) - } - - ###############################################lr sampling - - guide_feas = [None] * timesteps - - for i in range(0, timesteps): - noise_cond = self.noise_cond(logSNR_range[i]) - if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): - cfg_val = cfg - if isinstance(cfg_val, (list, tuple)): - assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" - cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) - - - - if i == timesteps -1 : - output, guide_lr_enc, guide_lr_dec = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs) - guide_feas[i] = ([f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_enc], [f.chunk(2)[0].repeat(2, 1, 1, 1) for f in guide_lr_dec]) - else: - output, _, _ = model(torch.cat([x_lr, x_lr], dim=0), noise_cond.repeat(2), reuire_f=True, **model_inputs) - - pred, pred_unconditional = output.chunk(2) - - - pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) - if cfg_rho > 0: - std_pos, std_cfg = pred.std(), pred_cfg.std() - pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) - else: - pred = pred_cfg - else: - pred = model(x_lr, noise_cond, **model_inputs) - x0_lr, epsilon_lr = self.undiffuse(x_lr, logSNR_range[i], pred) - x_lr = sampler(x_lr, x0_lr, epsilon_lr, logSNR_range[i], logSNR_range[i+1], **sampler_params) - - ###############################################hr HR sampling - for i in range(0, timesteps): - noise_cond = self.noise_cond(logSNR_range[i]) - if cfg is not None and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop) and (cfg_t_start is None or r_range[i].item() <= cfg_t_start): - cfg_val = cfg - if isinstance(cfg_val, (list, tuple)): - assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2" - cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1-r_range[i].item()) - - out_pred, t_emb = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), \ - lr_guide=guide_feas[timesteps -1] if i <=19 else None , **model_inputs, require_t=True, guide_weight=1 - i/timesteps) - pred, pred_unconditional = out_pred.chunk(2) - pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val) - if cfg_rho > 0: - std_pos, std_cfg = pred.std(), pred_cfg.std() - pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) - else: - pred = pred_cfg - else: - pred = model(x, noise_cond, guide_lr=(guide_lr_enc, guide_lr_dec), **model_inputs) - x0, epsilon = self.undiffuse(x, logSNR_range[i], pred) - - x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i+1], **sampler_params) - altered_vars = yield (x0, x, pred, x_lr) - - - - # Update some running variables if the user wants - if altered_vars is not None: - cfg = altered_vars.get('cfg', cfg) - cfg_rho = altered_vars.get('cfg_rho', cfg_rho) - sampler = altered_vars.get('sampler', sampler) - model_inputs = altered_vars.get('model_inputs', model_inputs) - x = altered_vars.get('x', x) - x_init = altered_vars.get('x_init', x_init) - - - - diff --git a/gdf/loss_weights.py b/gdf/loss_weights.py deleted file mode 100644 index d14ddaadeeb3f8de6c68aea4c364d9b852f2f15c..0000000000000000000000000000000000000000 --- a/gdf/loss_weights.py +++ /dev/null @@ -1,101 +0,0 @@ -import torch -import numpy as np - -# --- Loss Weighting -class BaseLossWeight(): - def weight(self, logSNR): - raise NotImplementedError("this method needs to be overridden") - - def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): - clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range - if shift != 1: - logSNR = logSNR.clone() + 2 * np.log(shift) - return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) - -class ComposedLossWeight(BaseLossWeight): - def __init__(self, div, mul): - self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul - self.div = [div] if isinstance(div, BaseLossWeight) else div - - def weight(self, logSNR): - prod, div = 1, 1 - for m in self.mul: - prod *= m.weight(logSNR) - for d in self.div: - div *= d.weight(logSNR) - return prod/div - -class ConstantLossWeight(BaseLossWeight): - def __init__(self, v=1): - self.v = v - - def weight(self, logSNR): - return torch.ones_like(logSNR) * self.v - -class SNRLossWeight(BaseLossWeight): - def weight(self, logSNR): - return logSNR.exp() - -class P2LossWeight(BaseLossWeight): - def __init__(self, k=1.0, gamma=1.0, s=1.0): - self.k, self.gamma, self.s = k, gamma, s - - def weight(self, logSNR): - return (self.k + (logSNR * self.s).exp()) ** -self.gamma - -class SNRPlusOneLossWeight(BaseLossWeight): - def weight(self, logSNR): - return logSNR.exp() + 1 - -class MinSNRLossWeight(BaseLossWeight): - def __init__(self, max_snr=5): - self.max_snr = max_snr - - def weight(self, logSNR): - return logSNR.exp().clamp(max=self.max_snr) - -class MinSNRPlusOneLossWeight(BaseLossWeight): - def __init__(self, max_snr=5): - self.max_snr = max_snr - - def weight(self, logSNR): - return (logSNR.exp() + 1).clamp(max=self.max_snr) - -class TruncatedSNRLossWeight(BaseLossWeight): - def __init__(self, min_snr=1): - self.min_snr = min_snr - - def weight(self, logSNR): - return logSNR.exp().clamp(min=self.min_snr) - -class SechLossWeight(BaseLossWeight): - def __init__(self, div=2): - self.div = div - - def weight(self, logSNR): - return 1/(logSNR/self.div).cosh() - -class DebiasedLossWeight(BaseLossWeight): - def weight(self, logSNR): - return 1/logSNR.exp().sqrt() - -class SigmoidLossWeight(BaseLossWeight): - def __init__(self, s=1): - self.s = s - - def weight(self, logSNR): - return (logSNR * self.s).sigmoid() - -class AdaptiveLossWeight(BaseLossWeight): - def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): - self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets-1) - self.bucket_losses = torch.ones(buckets) - self.weight_range = weight_range - - def weight(self, logSNR): - indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) - return (1/self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) - - def update_buckets(self, logSNR, loss, beta=0.99): - indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() - self.bucket_losses[indices] = self.bucket_losses[indices]*beta + loss.detach().cpu() * (1-beta) diff --git a/gdf/noise_conditions.py b/gdf/noise_conditions.py deleted file mode 100644 index dc2791f50a6f63eff8f9bed9b827f87517cc0be8..0000000000000000000000000000000000000000 --- a/gdf/noise_conditions.py +++ /dev/null @@ -1,102 +0,0 @@ -import torch -import numpy as np - -class BaseNoiseCond(): - def __init__(self, *args, shift=1, clamp_range=None, **kwargs): - clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range - self.shift = shift - self.clamp_range = clamp_range - self.setup(*args, **kwargs) - - def setup(self, *args, **kwargs): - pass # this method is optional, override it if required - - def cond(self, logSNR): - raise NotImplementedError("this method needs to be overriden") - - def __call__(self, logSNR): - if self.shift != 1: - logSNR = logSNR.clone() + 2 * np.log(self.shift) - return self.cond(logSNR).clamp(*self.clamp_range) - -class CosineTNoiseCond(BaseNoiseCond): - def setup(self, s=0.008, clamp_range=[0, 1]): # [0.0001, 0.9999] - self.s = torch.tensor([s]) - self.clamp_range = clamp_range - self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 - - def cond(self, logSNR): - var = logSNR.sigmoid() - var = var.clamp(*self.clamp_range) - s, min_var = self.s.to(var.device), self.min_var.to(var.device) - t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s - return t - -class EDMNoiseCond(BaseNoiseCond): - def cond(self, logSNR): - return -logSNR/8 - -class SigmoidNoiseCond(BaseNoiseCond): - def cond(self, logSNR): - return (-logSNR).sigmoid() - -class LogSNRNoiseCond(BaseNoiseCond): - def cond(self, logSNR): - return logSNR - -class EDMSigmaNoiseCond(BaseNoiseCond): - def setup(self, sigma_data=1): - self.sigma_data = sigma_data - - def cond(self, logSNR): - return torch.exp(-logSNR / 2) * self.sigma_data - -class RectifiedFlowsNoiseCond(BaseNoiseCond): - def cond(self, logSNR): - _a = logSNR.exp() - 1 - _a[_a == 0] = 1e-3 # Avoid division by zero - a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) - return a - -# Any NoiseCond that cannot be described easily as a continuous function of t -# It needs to define self.x and self.y in the setup() method -class PiecewiseLinearNoiseCond(BaseNoiseCond): - def setup(self): - self.x = None - self.y = None - - def piecewise_linear(self, y, xs, ys): - indices = (len(xs)-2) - torch.searchsorted(ys.flip(dims=(-1,))[:-2], y) - x_min, x_max = xs[indices], xs[indices+1] - y_min, y_max = ys[indices], ys[indices+1] - x = x_min + (x_max - x_min) * (y - y_min) / (y_max - y_min) - return x - - def cond(self, logSNR): - var = logSNR.sigmoid() - t = self.piecewise_linear(var, self.x.to(var.device), self.y.to(var.device)) # .mul(1000).round().clamp(min=0) - return t - -class StableDiffusionNoiseCond(PiecewiseLinearNoiseCond): - def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): - self.total_steps = total_steps - linear_range_sqrt = [r**0.5 for r in linear_range] - self.x = torch.linspace(0, 1, total_steps+1) - - alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 - self.y = alphas.cumprod(dim=-1) - - def cond(self, logSNR): - return super().cond(logSNR).clamp(0, 1) - -class DiscreteNoiseCond(BaseNoiseCond): - def setup(self, noise_cond, steps=1000, continuous_range=[0, 1]): - self.noise_cond = noise_cond - self.steps = steps - self.continuous_range = continuous_range - - def cond(self, logSNR): - cond = self.noise_cond(logSNR) - cond = (cond-self.continuous_range[0]) / (self.continuous_range[1]-self.continuous_range[0]) - return cond.mul(self.steps).long() - \ No newline at end of file diff --git a/gdf/readme.md b/gdf/readme.md deleted file mode 100644 index 9a63691513c9da6804fba53e36acc8e0cd7f5d7f..0000000000000000000000000000000000000000 --- a/gdf/readme.md +++ /dev/null @@ -1,86 +0,0 @@ -# Generic Diffusion Framework (GDF) - -# Basic usage -GDF is a simple framework for working with diffusion models. It implements most common diffusion frameworks (DDPM / DDIM -, EDM, Rectified Flows, etc.) and makes it very easy to switch between them or combine different parts of different -frameworks - -Using GDF is very straighforward, first of all just define an instance of the GDF class: - -```python -from gdf import GDF -from gdf import CosineSchedule -from gdf import VPScaler, EpsilonTarget, CosineTNoiseCond, P2LossWeight - -gdf = GDF( - schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), - input_scaler=VPScaler(), target=EpsilonTarget(), - noise_cond=CosineTNoiseCond(), - loss_weight=P2LossWeight(), -) -``` - -You need to define the following components: -* **Train Schedule**: This will return the logSNR schedule that will be used during training, some of the schedulers can be configured. A train schedule will then be called with a batch size and will randomly sample some values from the defined distribution. -* **Sample Schedule**: This is the schedule that will be used later on when sampling. It might be different from the training schedule. -* **Input Scaler**: If you want to use Variance Preserving or LERP (rectified flows) -* **Target**: What the target is during training, usually: epsilon, x0 or v -* **Noise Conditioning**: You could directly pass the logSNR to your model but usually a normalized value is used instead, for example the EDM framework proposes to use `-logSNR/8` -* **Loss Weight**: There are many proposed loss weighting strategies, here you define which one you'll use - -All of those classes are actually very simple logSNR centric definitions, for example the VPScaler is defined as just: -```python -class VPScaler(): - def __call__(self, logSNR): - a_squared = logSNR.sigmoid() - a = a_squared.sqrt() - b = (1-a_squared).sqrt() - return a, b - -``` - -So it's very easy to extend this framework with custom schedulers, scalers, targets, loss weights, etc... - -### Training - -When you define your training loop you can get all you need by just doing: -```python -shift, loss_shift = 1, 1 # this can be set to higher values as per what the Simple Diffusion paper sugested for high resolution -for inputs, extra_conditions in dataloader_iterator: - noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(inputs, shift=shift, loss_shift=loss_shift) - pred = diffusion_model(noised, noise_cond, extra_conditions) - - loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) - loss_adjusted = (loss * loss_weight).mean() - - loss_adjusted.backward() - optimizer.step() - optimizer.zero_grad(set_to_none=True) -``` - -And that's all, you have a diffusion model training, where it's very easy to customize the different elements of the -training from the GDF class. - -### Sampling - -The other important part is sampling, when you want to use this framework to sample you can just do the following: - -```python -from gdf import DDPMSampler - -shift = 1 -sampling_configs = { - "timesteps": 30, "cfg": 7, "sampler": DDPMSampler(gdf), "shift": shift, - "schedule": CosineSchedule(clamp_range=[0.0001, 0.9999]) -} - -*_, (sampled, _, _) = gdf.sample( - diffusion_model, {"cond": extra_conditions}, latents.shape, - unconditional_inputs= {"cond": torch.zeros_like(extra_conditions)}, - device=device, **sampling_configs -) -``` - -# Available modules - -TODO diff --git a/gdf/samplers.py b/gdf/samplers.py deleted file mode 100644 index b6048c86a261d53d0440a3b2c1591a03d9978c4f..0000000000000000000000000000000000000000 --- a/gdf/samplers.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch - -class SimpleSampler(): - def __init__(self, gdf): - self.gdf = gdf - self.current_step = -1 - - def __call__(self, *args, **kwargs): - self.current_step += 1 - return self.step(*args, **kwargs) - - def init_x(self, shape): - return torch.randn(*shape) - - def step(self, x, x0, epsilon, logSNR, logSNR_prev): - raise NotImplementedError("You should override the 'apply' function.") - -class DDIMSampler(SimpleSampler): - def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0): - a, b = self.gdf.input_scaler(logSNR) - if len(a.shape) == 1: - a, b = a.view(-1, *[1]*(len(x0.shape)-1)), b.view(-1, *[1]*(len(x0.shape)-1)) - - a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) - if len(a_prev.shape) == 1: - a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) - - sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0 - # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) - x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0) - return x - -class DDPMSampler(DDIMSampler): - def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1): - return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta) - -class LCMSampler(SimpleSampler): - def step(self, x, x0, epsilon, logSNR, logSNR_prev): - a_prev, b_prev = self.gdf.input_scaler(logSNR_prev) - if len(a_prev.shape) == 1: - a_prev, b_prev = a_prev.view(-1, *[1]*(len(x0.shape)-1)), b_prev.view(-1, *[1]*(len(x0.shape)-1)) - return x0 * a_prev + torch.randn_like(epsilon) * b_prev - \ No newline at end of file diff --git a/gdf/scalers.py b/gdf/scalers.py deleted file mode 100644 index b1adb8b0269667f3d006c7d7d17cbf2b7ef56ca9..0000000000000000000000000000000000000000 --- a/gdf/scalers.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch - -class BaseScaler(): - def __init__(self): - self.stretched_limits = None - - def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1): - min_logSNR = schedule(torch.ones(1), shift=shift) - max_logSNR = schedule(torch.zeros(1), shift=shift) - - min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1] - max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0] - self.stretched_limits = [min_a, max_a, min_b, max_b] - return self.stretched_limits - - def stretch_limits(self, a, b): - min_a, max_a, min_b, max_b = self.stretched_limits - return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b) - - def scalers(self, logSNR): - raise NotImplementedError("this method needs to be overridden") - - def __call__(self, logSNR): - a, b = self.scalers(logSNR) - if self.stretched_limits is not None: - a, b = self.stretch_limits(a, b) - return a, b - -class VPScaler(BaseScaler): - def scalers(self, logSNR): - a_squared = logSNR.sigmoid() - a = a_squared.sqrt() - b = (1-a_squared).sqrt() - return a, b - -class LERPScaler(BaseScaler): - def scalers(self, logSNR): - _a = logSNR.exp() - 1 - _a[_a == 0] = 1e-3 # Avoid division by zero - a = 1 + (2-(2**2 + 4*_a)**0.5) / (2*_a) - b = 1-a - return a, b diff --git a/gdf/schedulers.py b/gdf/schedulers.py deleted file mode 100644 index caa6e174da1d766ea5828616bb8113865106b628..0000000000000000000000000000000000000000 --- a/gdf/schedulers.py +++ /dev/null @@ -1,200 +0,0 @@ -import torch -import numpy as np - -class BaseSchedule(): - def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs): - self.setup(*args, **kwargs) - self.limits = None - self.discrete_steps = discrete_steps - self.shift = shift - if force_limits: - self.reset_limits() - - def reset_limits(self, shift=1, disable=False): - try: - self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist() # min, max - return self.limits - except Exception: - print("WARNING: this schedule doesn't support t and will be unbounded") - return None - - def setup(self, *args, **kwargs): - raise NotImplementedError("this method needs to be overriden") - - def schedule(self, *args, **kwargs): - raise NotImplementedError("this method needs to be overriden") - - def __call__(self, t, *args, shift=1, **kwargs): - if isinstance(t, torch.Tensor): - batch_size = None - if self.discrete_steps is not None: - if t.dtype != torch.long: - t = (t * (self.discrete_steps-1)).round().long() - t = t / (self.discrete_steps-1) - t = t.clamp(0, 1) - else: - batch_size = t - t = None - logSNR = self.schedule(t, batch_size, *args, **kwargs) - if shift*self.shift != 1: - logSNR += 2 * np.log(1/(shift*self.shift)) - if self.limits is not None: - logSNR = logSNR.clamp(*self.limits) - return logSNR - -class CosineSchedule(BaseSchedule): - def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False): - self.s = torch.tensor([s]) - self.clamp_range = clamp_range - self.norm_instead = norm_instead - self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 - - def schedule(self, t, batch_size): - if t is None: - t = (1-torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0) - s, min_var = self.s.to(t.device), self.min_var.to(t.device) - var = torch.cos((s + t)/(1+s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var - if self.norm_instead: - var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] - else: - var = var.clamp(*self.clamp_range) - logSNR = (var/(1-var)).log() - return logSNR - -class CosineSchedule2(BaseSchedule): - def setup(self, logsnr_range=[-15, 15]): - self.t_min = np.arctan(np.exp(-0.5 * logsnr_range[1])) - self.t_max = np.arctan(np.exp(-0.5 * logsnr_range[0])) - - def schedule(self, t, batch_size): - if t is None: - t = 1-torch.rand(batch_size) - return -2 * (self.t_min + t*(self.t_max-self.t_min)).tan().log() - -class SqrtSchedule(BaseSchedule): - def setup(self, s=1e-4, clamp_range=[0.0001, 0.9999], norm_instead=False): - self.s = s - self.clamp_range = clamp_range - self.norm_instead = norm_instead - - def schedule(self, t, batch_size): - if t is None: - t = 1-torch.rand(batch_size) - var = 1 - (t + self.s)**0.5 - if self.norm_instead: - var = var * (self.clamp_range[1]-self.clamp_range[0]) + self.clamp_range[0] - else: - var = var.clamp(*self.clamp_range) - logSNR = (var/(1-var)).log() - return logSNR - -class RectifiedFlowsSchedule(BaseSchedule): - def setup(self, logsnr_range=[-15, 15]): - self.logsnr_range = logsnr_range - - def schedule(self, t, batch_size): - if t is None: - t = 1-torch.rand(batch_size) - logSNR = (((1-t)**2)/(t**2)).log() - logSNR = logSNR.clamp(*self.logsnr_range) - return logSNR - -class EDMSampleSchedule(BaseSchedule): - def setup(self, sigma_range=[0.002, 80], p=7): - self.sigma_range = sigma_range - self.p = p - - def schedule(self, t, batch_size): - if t is None: - t = 1-torch.rand(batch_size) - smin, smax, p = *self.sigma_range, self.p - sigma = (smax ** (1/p) + (1-t) * (smin ** (1/p) - smax ** (1/p))) ** p - logSNR = (1/sigma**2).log() - return logSNR - -class EDMTrainSchedule(BaseSchedule): - def setup(self, mu=-1.2, std=1.2): - self.mu = mu - self.std = std - - def schedule(self, t, batch_size): - if t is not None: - raise Exception("EDMTrainSchedule doesn't support passing timesteps: t") - logSNR = -2*(torch.randn(batch_size) * self.std - self.mu) - return logSNR - -class LinearSchedule(BaseSchedule): - def setup(self, logsnr_range=[-10, 10]): - self.logsnr_range = logsnr_range - - def schedule(self, t, batch_size): - if t is None: - t = 1-torch.rand(batch_size) - logSNR = t * (self.logsnr_range[0]-self.logsnr_range[1]) + self.logsnr_range[1] - return logSNR - -# Any schedule that cannot be described easily as a continuous function of t -# It needs to define self.x and self.y in the setup() method -class PiecewiseLinearSchedule(BaseSchedule): - def setup(self): - self.x = None - self.y = None - - def piecewise_linear(self, x, xs, ys): - indices = torch.searchsorted(xs[:-1], x) - 1 - x_min, x_max = xs[indices], xs[indices+1] - y_min, y_max = ys[indices], ys[indices+1] - var = y_min + (y_max - y_min) * (x - x_min) / (x_max - x_min) - return var - - def schedule(self, t, batch_size): - if t is None: - t = 1-torch.rand(batch_size) - var = self.piecewise_linear(t, self.x.to(t.device), self.y.to(t.device)) - logSNR = (var/(1-var)).log() - return logSNR - -class StableDiffusionSchedule(PiecewiseLinearSchedule): - def setup(self, linear_range=[0.00085, 0.012], total_steps=1000): - linear_range_sqrt = [r**0.5 for r in linear_range] - self.x = torch.linspace(0, 1, total_steps+1) - - alphas = 1-(linear_range_sqrt[0]*(1-self.x) + linear_range_sqrt[1]*self.x)**2 - self.y = alphas.cumprod(dim=-1) - -class AdaptiveTrainSchedule(BaseSchedule): - def setup(self, logsnr_range=[-10, 10], buckets=100, min_probs=0.0): - th = torch.linspace(logsnr_range[0], logsnr_range[1], buckets+1) - self.bucket_ranges = torch.tensor([(th[i], th[i+1]) for i in range(buckets)]) - self.bucket_probs = torch.ones(buckets) - self.min_probs = min_probs - - def schedule(self, t, batch_size): - if t is not None: - raise Exception("AdaptiveTrainSchedule doesn't support passing timesteps: t") - norm_probs = ((self.bucket_probs+self.min_probs) / (self.bucket_probs+self.min_probs).sum()) - buckets = torch.multinomial(norm_probs, batch_size, replacement=True) - ranges = self.bucket_ranges[buckets] - logSNR = torch.rand(batch_size) * (ranges[:, 1]-ranges[:, 0]) + ranges[:, 0] - return logSNR - - def update_buckets(self, logSNR, loss, beta=0.99): - range_mtx = self.bucket_ranges.unsqueeze(0).expand(logSNR.size(0), -1, -1).to(logSNR.device) - range_mask = (range_mtx[:, :, 0] <= logSNR[:, None]) * (range_mtx[:, :, 1] > logSNR[:, None]).float() - range_idx = range_mask.argmax(-1).cpu() - self.bucket_probs[range_idx] = self.bucket_probs[range_idx] * beta + loss.detach().cpu() * (1-beta) - -class InterpolatedSchedule(BaseSchedule): - def setup(self, scheduler1, scheduler2, shifts=[1.0, 1.0]): - self.scheduler1 = scheduler1 - self.scheduler2 = scheduler2 - self.shifts = shifts - - def schedule(self, t, batch_size): - if t is None: - t = 1-torch.rand(batch_size) - t = t.clamp(1e-7, 1-1e-7) # avoid infinities multiplied by 0 which cause nan - low_logSNR = self.scheduler1(t, shift=self.shifts[0]) - high_logSNR = self.scheduler2(t, shift=self.shifts[1]) - return low_logSNR * t + high_logSNR * (1-t) - diff --git a/gdf/targets.py b/gdf/targets.py deleted file mode 100644 index 115062b6001f93082fa836e1f3742723e5972efe..0000000000000000000000000000000000000000 --- a/gdf/targets.py +++ /dev/null @@ -1,46 +0,0 @@ -class EpsilonTarget(): - def __call__(self, x0, epsilon, logSNR, a, b): - return epsilon - - def x0(self, noised, pred, logSNR, a, b): - return (noised - pred * b) / a - - def epsilon(self, noised, pred, logSNR, a, b): - return pred - def noise_givenx0_noised(self, x0, noised , logSNR, a, b): - return (noised - a * x0) / b - def xt(self, x0, noise, logSNR, a, b): - - return x0 * a + noise*b -class X0Target(): - def __call__(self, x0, epsilon, logSNR, a, b): - return x0 - - def x0(self, noised, pred, logSNR, a, b): - return pred - - def epsilon(self, noised, pred, logSNR, a, b): - return (noised - pred * a) / b - -class VTarget(): - def __call__(self, x0, epsilon, logSNR, a, b): - return a * epsilon - b * x0 - - def x0(self, noised, pred, logSNR, a, b): - squared_sum = a**2 + b**2 - return a/squared_sum * noised - b/squared_sum * pred - - def epsilon(self, noised, pred, logSNR, a, b): - squared_sum = a**2 + b**2 - return b/squared_sum * noised + a/squared_sum * pred - -class RectifiedFlowsTarget(): - def __call__(self, x0, epsilon, logSNR, a, b): - return epsilon - x0 - - def x0(self, noised, pred, logSNR, a, b): - return noised - pred * b - - def epsilon(self, noised, pred, logSNR, a, b): - return noised + pred * a - \ No newline at end of file diff --git a/inference/__init__.py b/inference/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/inference/test_controlnet.py b/inference/test_controlnet.py deleted file mode 100644 index 250578262d2a118ece8b5a706aba1cd8115c62f5..0000000000000000000000000000000000000000 --- a/inference/test_controlnet.py +++ /dev/null @@ -1,166 +0,0 @@ -import os -import yaml -import torch -import torchvision -from tqdm import tqdm -import sys -sys.path.append(os.path.abspath('./')) - -from inference.utils import * -from core.utils import load_or_fail -from train import WurstCore_control_lrguide, WurstCoreB -from PIL import Image -from core.utils import load_or_fail -import math -import argparse -import time -import random -import numpy as np -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( '--height', type=int, default=3840, help='image height') - parser.add_argument('--width', type=int, default=2160, help='image width') - parser.add_argument('--control_weight', type=float, default=0.70, help='[ 0.3, 0.8]') - parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') - parser.add_argument('--seed', type=int, default=123, help='random seed') - parser.add_argument('--config_c', type=str, - default='configs/training/cfg_control_lr.yaml' ,help='config file for stage c, latent generation') - parser.add_argument('--config_b', type=str, - default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') - parser.add_argument( '--prompt', type=str, - default='A peaceful lake surrounded by mountain, white cloud in the sky, high quality,', help='text prompt') - parser.add_argument( '--num_image', type=int, default=4, help='how many images generated') - parser.add_argument( '--output_dir', type=str, default='figures/controlnet_results/', help='output directory for generated image') - parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') - parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') - parser.add_argument( '--canny_source_url', type=str, default="figures/California_000490.jpg", help='image used to extract canny edge map') - - args = parser.parse_args() - return args - - -if __name__ == "__main__": - - args = parse_args() - width = args.width - height = args.height - torch.manual_seed(args.seed) - random.seed(args.seed) - np.random.seed(args.seed) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float - - - # SETUP STAGE C - with open(args.config_c, "r", encoding="utf-8") as file: - loaded_config = yaml.safe_load(file) - core = WurstCore_control_lrguide(config_dict=loaded_config, device=device, training=False) - - # SETUP STAGE B - with open(args.config_b, "r", encoding="utf-8") as file: - config_file_b = yaml.safe_load(file) - - core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) - - extras = core.setup_extras_pre() - models = core.setup_models(extras) - models.generator.eval().requires_grad_(False) - print("CONTROLNET READY") - - extras_b = core_b.setup_extras_pre() - models_b = core_b.setup_models(extras_b, skip_clip=True) - models_b = WurstCoreB.Models( - **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} - ) - models_b.generator.eval().requires_grad_(False) - print("STAGE B READY") - - batch_size = 1 - save_dir = args.output_dir - url = args.canny_source_url - images = resize_image(Image.open(url).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1) - batch = {'images': images} - - - - - - - cnet_multiplier = args.control_weight # 0.8 0.6 0.3 control strength - caption_list = [args.prompt] * args.num_image - height_lr, width_lr = get_target_lr_size(height / width, std_size=32) - stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) - stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) - - - - - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - - sdd = torch.load(args.pretrained_path, map_location='cpu') - collect_sd = {} - for k, v in sdd.items(): - collect_sd[k[7:]] = v - models.train_norm.load_state_dict(collect_sd, strict=True) - - - - - models.controlnet.load_state_dict(load_or_fail(core.config.controlnet_checkpoint_path), strict=True) - # Stage C Parameters - extras.sampling_configs['cfg'] = 1 - extras.sampling_configs['shift'] = 2 - extras.sampling_configs['timesteps'] = 20 - extras.sampling_configs['t_start'] = 1.0 - - # Stage B Parameters - extras_b.sampling_configs['cfg'] = 1.1 - extras_b.sampling_configs['shift'] = 1 - extras_b.sampling_configs['timesteps'] = 10 - extras_b.sampling_configs['t_start'] = 1.0 - - # PREPARE CONDITIONS - - - - - for out_cnt, caption in enumerate(caption_list): - with torch.no_grad(): - - batch['captions'] = [caption + ' high quality'] * batch_size - conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - cnet, cnet_input = core.get_cnet(batch, models, extras) - cnet_uncond = cnet - conditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet] - unconditions['cnet'] = [c.clone() * cnet_multiplier if c is not None else c for c in cnet_uncond] - edge_images = show_images(cnet_input) - models.generator.cuda() - for idx, img in enumerate(edge_images): - img.save(os.path.join(save_dir, f"edge_{url.split('/')[-1]}")) - - - print('STAGE C GENERATION***************************') - with torch.cuda.amp.autocast(dtype=dtype): - sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions, unconditions) - models.generator.cpu() - torch.cuda.empty_cache() - - conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) - unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - - conditions_b['effnet'] = sampled_c - unconditions_b['effnet'] = torch.zeros_like(sampled_c) - print('STAGE B + A DECODING***************************') - with torch.cuda.amp.autocast(dtype=dtype): - sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) - - torch.cuda.empty_cache() - imgs = show_images(sampled) - - for idx, img in enumerate(imgs): - img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(out_cnt).zfill(5) + '.jpg')) - print('finished! Results at ', save_dir ) diff --git a/inference/test_personalized.py b/inference/test_personalized.py deleted file mode 100644 index 840d52d0ef3b026e73c34f715b7b18ec3537e62a..0000000000000000000000000000000000000000 --- a/inference/test_personalized.py +++ /dev/null @@ -1,180 +0,0 @@ - -import os -import yaml -import torch -from tqdm import tqdm -import sys -sys.path.append(os.path.abspath('./')) -from inference.utils import * -from train import WurstCoreB -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from train import WurstCore_personalized as WurstCoreC -import torch.nn.functional as F -import numpy as np -import random -import math -import argparse - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( '--height', type=int, default=3072, help='image height') - parser.add_argument('--width', type=int, default=4096, help='image width') - parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') - parser.add_argument('--seed', type=int, default=23, help='random seed') - parser.add_argument('--config_c', type=str, - default="configs/training/lora_personalization.yaml" ,help='config file for stage c, latent generation') - parser.add_argument('--config_b', type=str, - default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') - parser.add_argument( '--prompt', type=str, - default='A photo of cat [roubaobao] with sunglasses, Time Square in the background, high quality, detail rich, 8k', help='text prompt') - parser.add_argument( '--num_image', type=int, default=4, help='how many images generated') - parser.add_argument( '--output_dir', type=str, default='figures/personalized/', help='output directory for generated image') - parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') - parser.add_argument( '--pretrained_path_lora', type=str, default='models/lora_cat.safetensors',help='pretrained path of personalized lora parameter') - parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') - args = parser.parse_args() - return args - -if __name__ == "__main__": - args = parse_args() - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - torch.manual_seed(args.seed) - random.seed(args.seed) - np.random.seed(args.seed) - dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float - - - # SETUP STAGE C - with open(args.config_c, "r", encoding="utf-8") as file: - loaded_config = yaml.safe_load(file) - core = WurstCoreC(config_dict=loaded_config, device=device, training=False) - - # SETUP STAGE B - with open(args.config_b, "r", encoding="utf-8") as file: - config_file_b = yaml.safe_load(file) - core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) - - extras = core.setup_extras_pre() - models = core.setup_models(extras) - models.generator.eval().requires_grad_(False) - print("STAGE C READY") - - extras_b = core_b.setup_extras_pre() - models_b = core_b.setup_models(extras_b, skip_clip=True) - models_b = WurstCoreB.Models( - **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} - ) - models_b.generator.bfloat16().eval().requires_grad_(False) - print("STAGE B READY") - - - batch_size = 1 - captions = [args.prompt] * args.num_image - height, width = args.height, args.width - save_dir = args.output_dir - - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - - pretrained_pth = args.pretrained_path - sdd = torch.load(pretrained_pth, map_location='cpu') - collect_sd = {} - for k, v in sdd.items(): - collect_sd[k[7:]] = v - - models.train_norm.load_state_dict(collect_sd) - - - pretrained_pth_lora = args.pretrained_path_lora - sdd = torch.load(pretrained_pth_lora, map_location='cpu') - collect_sd = {} - for k, v in sdd.items(): - collect_sd[k[7:]] = v - - models.train_lora.load_state_dict(collect_sd) - - - models.generator.eval() - models.train_norm.eval() - - - height_lr, width_lr = get_target_lr_size(height / width, std_size=32) - stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) - stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) - - # Stage C Parameters - - extras.sampling_configs['cfg'] = 4 - extras.sampling_configs['shift'] = 1 - extras.sampling_configs['timesteps'] = 20 - extras.sampling_configs['t_start'] = 1.0 - extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf) - - - - # Stage B Parameters - - extras_b.sampling_configs['cfg'] = 1.1 - extras_b.sampling_configs['shift'] = 1 - extras_b.sampling_configs['timesteps'] = 10 - extras_b.sampling_configs['t_start'] = 1.0 - - - for cnt, caption in enumerate(captions): - - batch = {'captions': [caption] * batch_size} - conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) - unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - - - - - for cnt, caption in enumerate(captions): - - - batch = {'captions': [caption] * batch_size} - conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) - unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - - - with torch.no_grad(): - - - models.generator.cuda() - print('STAGE C GENERATION***************************') - with torch.cuda.amp.autocast(dtype=dtype): - sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) - - - - models.generator.cpu() - torch.cuda.empty_cache() - - conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) - unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - conditions_b['effnet'] = sampled_c - unconditions_b['effnet'] = torch.zeros_like(sampled_c) - print('STAGE B + A DECODING***************************') - - with torch.cuda.amp.autocast(dtype=dtype): - sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) - - torch.cuda.empty_cache() - imgs = show_images(sampled) - for idx, img in enumerate(imgs): - print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx) - img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg')) - - - print('finished! Results at ', save_dir ) - - - diff --git a/inference/test_t2i.py b/inference/test_t2i.py deleted file mode 100644 index 3478f95e4c706d88a8c73688ed4e990adc9ea8d4..0000000000000000000000000000000000000000 --- a/inference/test_t2i.py +++ /dev/null @@ -1,170 +0,0 @@ - -import os -import yaml -import torch -from tqdm import tqdm -import sys -sys.path.append(os.path.abspath('./')) -from inference.utils import * -from core.utils import load_or_fail -from train import WurstCoreB -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from train import WurstCore_t2i as WurstCoreC -import torch.nn.functional as F -from core.utils import load_or_fail -import numpy as np -import random -import math -import argparse -from einops import rearrange -import math -#inrfft_3b_strc_WurstCore -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( '--height', type=int, default=2560, help='image height') - parser.add_argument('--width', type=int, default=5120, help='image width') - parser.add_argument('--seed', type=int, default=123, help='random seed') - parser.add_argument('--dtype', type=str, default='bf16', help=' if bf16 does not work, change it to float32 ') - parser.add_argument('--config_c', type=str, - default='configs/training/t2i.yaml' ,help='config file for stage c, latent generation') - parser.add_argument('--config_b', type=str, - default='configs/inference/stage_b_1b.yaml' ,help='config file for stage b, latent decoding') - parser.add_argument( '--prompt', type=str, - default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K', help='text prompt') - parser.add_argument( '--num_image', type=int, default=10, help='how many images generated') - parser.add_argument( '--output_dir', type=str, default='figures/output_results/', help='output directory for generated image') - parser.add_argument( '--stage_a_tiled', action='store_true', help='whther or nor to use tiled decoding for stage a to save memory') - parser.add_argument( '--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added paramter of UltraPixel') - args = parser.parse_args() - return args - - - -if __name__ == "__main__": - - args = parse_args() - print(args) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - print(device) - torch.manual_seed(args.seed) - random.seed(args.seed) - np.random.seed(args.seed) - dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float - #gdf = gdf_refine( - # schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), - # input_scaler=VPScaler(), target=EpsilonTarget(), - # noise_cond=CosineTNoiseCond(), - # loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), - # ) - # SETUP STAGE C - config_file = args.config_c - with open(config_file, "r", encoding="utf-8") as file: - loaded_config = yaml.safe_load(file) - - core = WurstCoreC(config_dict=loaded_config, device=device, training=False) - - # SETUP STAGE B - config_file_b = args.config_b - with open(config_file_b, "r", encoding="utf-8") as file: - config_file_b = yaml.safe_load(file) - - core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False) - - extras = core.setup_extras_pre() - models = core.setup_models(extras) - models.generator.eval().requires_grad_(False) - print("STAGE C READY") - - extras_b = core_b.setup_extras_pre() - models_b = core_b.setup_models(extras_b, skip_clip=True) - models_b = WurstCoreB.Models( - **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model} - ) - models_b.generator.bfloat16().eval().requires_grad_(False) - print("STAGE B READY") - - captions = [args.prompt] * args.num_image - - - height, width = args.height, args.width - save_dir = args.output_dir - - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - pretrained_path = args.pretrained_path - sdd = torch.load(pretrained_path, map_location='cpu') - collect_sd = {} - for k, v in sdd.items(): - collect_sd[k[7:]] = v - - models.train_norm.load_state_dict(collect_sd) - - - models.generator.eval() - models.train_norm.eval() - - batch_size=1 - height_lr, width_lr = get_target_lr_size(height / width, std_size=32) - stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size) - stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size) - - # Stage C Parameters - extras.sampling_configs['cfg'] = 4 - extras.sampling_configs['shift'] = 1 - extras.sampling_configs['timesteps'] = 20 - extras.sampling_configs['t_start'] = 1.0 - extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf) - - - - # Stage B Parameters - extras_b.sampling_configs['cfg'] = 1.1 - extras_b.sampling_configs['shift'] = 1 - extras_b.sampling_configs['timesteps'] = 10 - extras_b.sampling_configs['t_start'] = 1.0 - - - - - for cnt, caption in enumerate(captions): - - - batch = {'captions': [caption] * batch_size} - conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) - unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - - - with torch.no_grad(): - - - models.generator.cuda() - print('STAGE C GENERATION***************************') - with torch.cuda.amp.autocast(dtype=dtype): - sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device) - - - - models.generator.cpu() - torch.cuda.empty_cache() - - conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False) - unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True) - conditions_b['effnet'] = sampled_c - unconditions_b['effnet'] = torch.zeros_like(sampled_c) - print('STAGE B + A DECODING***************************') - - with torch.cuda.amp.autocast(dtype=dtype): - sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled) - - torch.cuda.empty_cache() - imgs = show_images(sampled) - for idx, img in enumerate(imgs): - print(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg'), idx) - img.save(os.path.join(save_dir, args.prompt[:20]+'_' + str(cnt).zfill(5) + '.jpg')) - - - print('finished! Results at ', save_dir ) diff --git a/inference/utils.py b/inference/utils.py deleted file mode 100644 index ab5af277069ec7803d53ff8f5fa29bed41fde29b..0000000000000000000000000000000000000000 --- a/inference/utils.py +++ /dev/null @@ -1,131 +0,0 @@ -import PIL -import torch -import requests -import torchvision -from math import ceil -from io import BytesIO -import matplotlib.pyplot as plt -import torchvision.transforms.functional as F -import math -from tqdm import tqdm -def download_image(url): - return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB") - - -def resize_image(image, size=768): - tensor_image = F.to_tensor(image) - resized_image = F.resize(tensor_image, size, antialias=True) - return resized_image - - -def downscale_images(images, factor=3/4): - scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32) - scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) - return scaled_image - - - -def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0): - resolution_multiple = 42.67 - latent_height = ceil(height / compression_factor_b) - latent_width = ceil(width / compression_factor_b) - stage_c_latent_shape = (batch_size, 16, latent_height, latent_width) - - latent_height = ceil(height / compression_factor_a) - latent_width = ceil(width / compression_factor_a) - stage_b_latent_shape = (batch_size, 4, latent_height, latent_width) - - return stage_c_latent_shape, stage_b_latent_shape - - -def get_views(H, W, window_size=64, stride=16): - ''' - - H, W: height and width of the latent - ''' - num_blocks_height = (H - window_size) // stride + 1 - num_blocks_width = (W - window_size) // stride + 1 - total_num_blocks = int(num_blocks_height * num_blocks_width) - views = [] - for i in range(total_num_blocks): - h_start = int((i // num_blocks_width) * stride) - h_end = h_start + window_size - w_start = int((i % num_blocks_width) * stride) - w_end = w_start + window_size - views.append((h_start, h_end, w_start, w_end)) - return views - - - -def show_images(images, rows=None, cols=None, **kwargs): - if images.size(1) == 1: - images = images.repeat(1, 3, 1, 1) - elif images.size(1) > 3: - images = images[:, :3] - - if rows is None: - rows = 1 - if cols is None: - cols = images.size(0) // rows - - _, _, h, w = images.shape - - imgs = [] - for i, img in enumerate(images): - imgs.append( torchvision.transforms.functional.to_pil_image(img.clamp(0, 1))) - - return imgs - - - -def decode_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device, \ - stage_a_tiled=False, num_instance=4, patch_size=256, stride=24): - - - sampling_b = extras_b.gdf.sample( - models_b.generator.half(), conditions_b, bshape, - unconditions_b, device=device, - **extras_b.sampling_configs, - ) - models_b.generator.cuda() - for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']): - sampled_b = sampled_b - models_b.generator.cpu() - torch.cuda.empty_cache() - if stage_a_tiled: - with torch.cuda.amp.autocast(dtype=torch.float16): - padding = (stride*2, stride*2, stride*2, stride*2) - sampled_b = torch.nn.functional.pad(sampled_b, padding, mode='reflect') - count = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device) - sampled = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device) - views = get_views(sampled_b.shape[-2], sampled_b.shape[-1], window_size=patch_size, stride=stride) - - for view_idx, (h_start, h_end, w_start, w_end) in enumerate(tqdm(views, total=len(views))): - - sampled[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += models_b.stage_a.decode(sampled_b[:, :, h_start:h_end, w_start:w_end]).float() - count[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += 1 - sampled /= count - sampled = sampled[:, :, stride*4*2:-stride*4*2, stride*4*2:-stride*4*2] - else: - - sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled) - - return sampled.float() - - -def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None): - if conditions is None: - conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - if unconditions is None: - unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - sampling_c = extras.gdf.sample( - models.generator, conditions, stage_c_latent_shape, stage_c_latent_shape_lr, - unconditions, device=device, **extras.sampling_configs, - ) - for idx, (sampled_c, sampled_c_curr, _, _) in enumerate(tqdm(sampling_c, total=extras.sampling_configs['timesteps'])): - sampled_c = sampled_c - return sampled_c - -def get_target_lr_size(ratio, std_size=24): - w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) - return (h * 32 , w *32 ) - diff --git a/models/models_checklist.txt b/models/models_checklist.txt deleted file mode 100644 index 2fdec27a72db473c51893abc64826514b1d9d065..0000000000000000000000000000000000000000 --- a/models/models_checklist.txt +++ /dev/null @@ -1,7 +0,0 @@ -https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_a.safetensors -https://huggingface.co/stabilityai/StableWurst/resolve/main/previewer.safetensors -https://huggingface.co/stabilityai/StableWurst/resolve/main/effnet_encoder.safetensors -https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_b_lite_bf16.safetensors -https://huggingface.co/stabilityai/StableWurst/resolve/main/stage_c_bf16.safetensors -https://huggingface.co/roubaofeipi/UltraPixel/blob/main/ultrapixel_t2i.safetensors -https://huggingface.co/roubaofeipi/UltraPixel/blob/main/lora_cat.safetensors (only required for personalization) \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py deleted file mode 100644 index a6fcf5aa2a39061c3f4f82dde6ff063411223cb3..0000000000000000000000000000000000000000 --- a/modules/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .effnet import EfficientNetEncoder -from .stage_c import StageC -from .stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock -from .previewer import Previewer -from .controlnet import ControlNet, ControlNetDeliverer -from . import controlnet as controlnet_filters diff --git a/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc b/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc deleted file mode 100644 index 8c74bb92cb0db0876acda8aa3d102141526fd428..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/face_id/__pycache__/arcface.cpython-310.pyc and /dev/null differ diff --git a/modules/cnet_modules/face_id/arcface.py b/modules/cnet_modules/face_id/arcface.py deleted file mode 100644 index 64e918bb90437f6f193a7ec384bea1fcd73c7abb..0000000000000000000000000000000000000000 --- a/modules/cnet_modules/face_id/arcface.py +++ /dev/null @@ -1,276 +0,0 @@ -import numpy as np -import onnx, onnx2torch, cv2 -import torch -from insightface.utils import face_align - - -class ArcFaceRecognizer: - def __init__(self, model_file=None, device='cpu', dtype=torch.float32): - assert model_file is not None - self.model_file = model_file - - self.device = device - self.dtype = dtype - self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) - for param in self.model.parameters(): - param.requires_grad = False - self.model.eval() - - self.input_mean = 127.5 - self.input_std = 127.5 - self.input_size = (112, 112) - self.input_shape = ['None', 3, 112, 112] - - def get(self, img, face): - aimg = face_align.norm_crop(img, landmark=face.kps, image_size=self.input_size[0]) - face.embedding = self.get_feat(aimg).flatten() - return face.embedding - - def compute_sim(self, feat1, feat2): - from numpy.linalg import norm - feat1 = feat1.ravel() - feat2 = feat2.ravel() - sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2)) - return sim - - def get_feat(self, imgs): - if not isinstance(imgs, list): - imgs = [imgs] - input_size = self.input_size - - blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size, - (self.input_mean, self.input_mean, self.input_mean), swapRB=True) - - blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) - net_out = self.model(blob_torch) - return net_out[0].float().cpu() - - -def distance2bbox(points, distance, max_shape=None): - """Decode distance prediction to bounding box. - - Args: - points (Tensor): Shape (n, 2), [x, y]. - distance (Tensor): Distance from the given point to 4 - boundaries (left, top, right, bottom). - max_shape (tuple): Shape of the image. - - Returns: - Tensor: Decoded bboxes. - """ - x1 = points[:, 0] - distance[:, 0] - y1 = points[:, 1] - distance[:, 1] - x2 = points[:, 0] + distance[:, 2] - y2 = points[:, 1] + distance[:, 3] - if max_shape is not None: - x1 = x1.clamp(min=0, max=max_shape[1]) - y1 = y1.clamp(min=0, max=max_shape[0]) - x2 = x2.clamp(min=0, max=max_shape[1]) - y2 = y2.clamp(min=0, max=max_shape[0]) - return np.stack([x1, y1, x2, y2], axis=-1) - - -def distance2kps(points, distance, max_shape=None): - """Decode distance prediction to bounding box. - - Args: - points (Tensor): Shape (n, 2), [x, y]. - distance (Tensor): Distance from the given point to 4 - boundaries (left, top, right, bottom). - max_shape (tuple): Shape of the image. - - Returns: - Tensor: Decoded bboxes. - """ - preds = [] - for i in range(0, distance.shape[1], 2): - px = points[:, i % 2] + distance[:, i] - py = points[:, i % 2 + 1] + distance[:, i + 1] - if max_shape is not None: - px = px.clamp(min=0, max=max_shape[1]) - py = py.clamp(min=0, max=max_shape[0]) - preds.append(px) - preds.append(py) - return np.stack(preds, axis=-1) - - -class FaceDetector: - def __init__(self, model_file=None, dtype=torch.float32, device='cuda'): - self.model_file = model_file - self.taskname = 'detection' - self.center_cache = {} - self.nms_thresh = 0.4 - self.det_thresh = 0.5 - - self.device = device - self.dtype = dtype - self.model = onnx2torch.convert(onnx.load(model_file)).to(device=device, dtype=dtype) - for param in self.model.parameters(): - param.requires_grad = False - self.model.eval() - - input_shape = (320, 320) - self.input_size = input_shape - self.input_shape = input_shape - - self.input_mean = 127.5 - self.input_std = 128.0 - self._anchor_ratio = 1.0 - self._num_anchors = 1 - self.fmc = 3 - self._feat_stride_fpn = [8, 16, 32] - self._num_anchors = 2 - self.use_kps = True - - self.det_thresh = 0.5 - self.nms_thresh = 0.4 - - def forward(self, img, threshold): - scores_list = [] - bboxes_list = [] - kpss_list = [] - input_size = tuple(img.shape[0:2][::-1]) - blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size, - (self.input_mean, self.input_mean, self.input_mean), swapRB=True) - blob_torch = torch.tensor(blob).to(device=self.device, dtype=self.dtype) - net_outs_torch = self.model(blob_torch) - # print(list(map(lambda x: x.shape, net_outs_torch))) - net_outs = list(map(lambda x: x.float().cpu().numpy(), net_outs_torch)) - - input_height = blob.shape[2] - input_width = blob.shape[3] - fmc = self.fmc - for idx, stride in enumerate(self._feat_stride_fpn): - scores = net_outs[idx] - bbox_preds = net_outs[idx + fmc] - bbox_preds = bbox_preds * stride - if self.use_kps: - kps_preds = net_outs[idx + fmc * 2] * stride - height = input_height // stride - width = input_width // stride - K = height * width - key = (height, width, stride) - if key in self.center_cache: - anchor_centers = self.center_cache[key] - else: - # solution-1, c style: - # anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) - # for i in range(height): - # anchor_centers[i, :, 1] = i - # for i in range(width): - # anchor_centers[:, i, 0] = i - - # solution-2: - # ax = np.arange(width, dtype=np.float32) - # ay = np.arange(height, dtype=np.float32) - # xv, yv = np.meshgrid(np.arange(width), np.arange(height)) - # anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) - - # solution-3: - anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) - # print(anchor_centers.shape) - - anchor_centers = (anchor_centers * stride).reshape((-1, 2)) - if self._num_anchors > 1: - anchor_centers = np.stack([anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2)) - if len(self.center_cache) < 100: - self.center_cache[key] = anchor_centers - - pos_inds = np.where(scores >= threshold)[0] - bboxes = distance2bbox(anchor_centers, bbox_preds) - pos_scores = scores[pos_inds] - pos_bboxes = bboxes[pos_inds] - scores_list.append(pos_scores) - bboxes_list.append(pos_bboxes) - if self.use_kps: - kpss = distance2kps(anchor_centers, kps_preds) - # kpss = kps_preds - kpss = kpss.reshape((kpss.shape[0], -1, 2)) - pos_kpss = kpss[pos_inds] - kpss_list.append(pos_kpss) - return scores_list, bboxes_list, kpss_list - - def detect(self, img, input_size=None, max_num=0, metric='default'): - assert input_size is not None or self.input_size is not None - input_size = self.input_size if input_size is None else input_size - - im_ratio = float(img.shape[0]) / img.shape[1] - model_ratio = float(input_size[1]) / input_size[0] - if im_ratio > model_ratio: - new_height = input_size[1] - new_width = int(new_height / im_ratio) - else: - new_width = input_size[0] - new_height = int(new_width * im_ratio) - det_scale = float(new_height) / img.shape[0] - resized_img = cv2.resize(img, (new_width, new_height)) - det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8) - det_img[:new_height, :new_width, :] = resized_img - - scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) - - scores = np.vstack(scores_list) - scores_ravel = scores.ravel() - order = scores_ravel.argsort()[::-1] - bboxes = np.vstack(bboxes_list) / det_scale - if self.use_kps: - kpss = np.vstack(kpss_list) / det_scale - pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) - pre_det = pre_det[order, :] - keep = self.nms(pre_det) - det = pre_det[keep, :] - if self.use_kps: - kpss = kpss[order, :, :] - kpss = kpss[keep, :, :] - else: - kpss = None - if max_num > 0 and det.shape[0] > max_num: - area = (det[:, 2] - det[:, 0]) * (det[:, 3] - - det[:, 1]) - img_center = img.shape[0] // 2, img.shape[1] // 2 - offsets = np.vstack([ - (det[:, 0] + det[:, 2]) / 2 - img_center[1], - (det[:, 1] + det[:, 3]) / 2 - img_center[0] - ]) - offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) - if metric == 'max': - values = area - else: - values = area - offset_dist_squared * 2.0 # some extra weight on the centering - bindex = np.argsort( - values)[::-1] # some extra weight on the centering - bindex = bindex[0:max_num] - det = det[bindex, :] - if kpss is not None: - kpss = kpss[bindex, :] - return det, kpss - - def nms(self, dets): - thresh = self.nms_thresh - x1 = dets[:, 0] - y1 = dets[:, 1] - x2 = dets[:, 2] - y2 = dets[:, 3] - scores = dets[:, 4] - - areas = (x2 - x1 + 1) * (y2 - y1 + 1) - order = scores.argsort()[::-1] - - keep = [] - while order.size > 0: - i = order[0] - keep.append(i) - xx1 = np.maximum(x1[i], x1[order[1:]]) - yy1 = np.maximum(y1[i], y1[order[1:]]) - xx2 = np.minimum(x2[i], x2[order[1:]]) - yy2 = np.minimum(y2[i], y2[order[1:]]) - - w = np.maximum(0.0, xx2 - xx1 + 1) - h = np.maximum(0.0, yy2 - yy1 + 1) - inter = w * h - ovr = inter / (areas[i] + areas[order[1:]] - inter) - - inds = np.where(ovr <= thresh)[0] - order = order[inds + 1] - - return keep diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc deleted file mode 100644 index 8200104d6d66a1084685c76373c38d752ed9c3d4..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-310.pyc and /dev/null differ diff --git a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc b/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc deleted file mode 100644 index ca432e5c5eed7ba17fc6cafb06a3ebe16002f67e..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/inpainting/__pycache__/saliency_model.cpython-39.pyc and /dev/null differ diff --git a/modules/cnet_modules/inpainting/saliency_model.pt b/modules/cnet_modules/inpainting/saliency_model.pt deleted file mode 100644 index e1b02cc60b2999a8f9ff90557182e3dafab63db7..0000000000000000000000000000000000000000 --- a/modules/cnet_modules/inpainting/saliency_model.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:225a602e1f2a5d159424be011a63b27d83b56343a4379a90710eca9a26bab920 -size 451123 diff --git a/modules/cnet_modules/inpainting/saliency_model.py b/modules/cnet_modules/inpainting/saliency_model.py deleted file mode 100644 index 82355a02baead47f50fe643e57b81f8caca78f79..0000000000000000000000000000000000000000 --- a/modules/cnet_modules/inpainting/saliency_model.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -import torchvision -from torch import nn -from PIL import Image -import numpy as np -import os - - -# MICRO RESNET -class ResBlock(nn.Module): - def __init__(self, channels): - super(ResBlock, self).__init__() - - self.resblock = nn.Sequential( - nn.ReflectionPad2d(1), - nn.Conv2d(channels, channels, kernel_size=3), - nn.InstanceNorm2d(channels, affine=True), - nn.ReLU(), - nn.ReflectionPad2d(1), - nn.Conv2d(channels, channels, kernel_size=3), - nn.InstanceNorm2d(channels, affine=True), - ) - - def forward(self, x): - out = self.resblock(x) - return out + x - - -class Upsample2d(nn.Module): - def __init__(self, scale_factor): - super(Upsample2d, self).__init__() - - self.interp = nn.functional.interpolate - self.scale_factor = scale_factor - - def forward(self, x): - x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') - return x - - -class MicroResNet(nn.Module): - def __init__(self): - super(MicroResNet, self).__init__() - - self.downsampler = nn.Sequential( - nn.ReflectionPad2d(4), - nn.Conv2d(3, 8, kernel_size=9, stride=4), - nn.InstanceNorm2d(8, affine=True), - nn.ReLU(), - nn.ReflectionPad2d(1), - nn.Conv2d(8, 16, kernel_size=3, stride=2), - nn.InstanceNorm2d(16, affine=True), - nn.ReLU(), - nn.ReflectionPad2d(1), - nn.Conv2d(16, 32, kernel_size=3, stride=2), - nn.InstanceNorm2d(32, affine=True), - nn.ReLU(), - ) - - self.residual = nn.Sequential( - ResBlock(32), - nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), - ResBlock(64), - ) - - self.segmentator = nn.Sequential( - nn.ReflectionPad2d(1), - nn.Conv2d(64, 16, kernel_size=3), - nn.InstanceNorm2d(16, affine=True), - nn.ReLU(), - Upsample2d(scale_factor=2), - nn.ReflectionPad2d(4), - nn.Conv2d(16, 1, kernel_size=9), - nn.Sigmoid() - ) - - def forward(self, x): - out = self.downsampler(x) - out = self.residual(out) - out = self.segmentator(out) - return out diff --git a/modules/cnet_modules/pidinet/__init__.py b/modules/cnet_modules/pidinet/__init__.py deleted file mode 100644 index a2b4625bf915cc6c4053b7d7861a22ff371bc641..0000000000000000000000000000000000000000 --- a/modules/cnet_modules/pidinet/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# Pidinet -# https://github.com/hellozhuo/pidinet - -import os -import torch -import numpy as np -from einops import rearrange -from .model import pidinet -from .util import annotator_ckpts_path, safe_step - - -class PidiNetDetector: - def __init__(self, device): - remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth" - modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth") - if not os.path.exists(modelpath): - from basicsr.utils.download_util import load_file_from_url - load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) - self.netNetwork = pidinet() - self.netNetwork.load_state_dict( - {k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()}) - self.netNetwork.to(device).eval().requires_grad_(False) - - def __call__(self, input_image): # , safe=False): - return self.netNetwork(input_image)[-1] - # assert input_image.ndim == 3 - # input_image = input_image[:, :, ::-1].copy() - # with torch.no_grad(): - # image_pidi = torch.from_numpy(input_image).float().cuda() - # image_pidi = image_pidi / 255.0 - # image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') - # edge = self.netNetwork(image_pidi)[-1] - - # if safe: - # edge = safe_step(edge) - # edge = (edge * 255.0).clip(0, 255).astype(np.uint8) - # return edge[0][0] diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 07fca0abb9c90b7b40746b4044c4000ae69e00c7..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 5a060aa2baa87a3670aa0bf8276e2f34bafe9451..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/pidinet/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc deleted file mode 100644 index 2243c853d18e2a404ced3eb4ac6a95a7a9ee6874..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/pidinet/__pycache__/model.cpython-310.pyc and /dev/null differ diff --git a/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc deleted file mode 100644 index 7f70342fc64759bc7459abf0f7986ee3b7fd2126..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/pidinet/__pycache__/model.cpython-39.pyc and /dev/null differ diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc deleted file mode 100644 index b2e7ab031924860f1262f4d44bf2eaf57ca78edd..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/pidinet/__pycache__/util.cpython-310.pyc and /dev/null differ diff --git a/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc b/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc deleted file mode 100644 index 4da8564d03f99caa7a45d9ccb1358cb282cd2711..0000000000000000000000000000000000000000 Binary files a/modules/cnet_modules/pidinet/__pycache__/util.cpython-39.pyc and /dev/null differ diff --git a/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth b/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth deleted file mode 100644 index 1ceba1de87e7bb3c81961b80acbb3a106ca249c0..0000000000000000000000000000000000000000 --- a/modules/cnet_modules/pidinet/ckpts/table5_pidinet.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:80860ac267258b5f27486e0ef152a211d0b08120f62aeb185a050acc30da486c -size 2871148 diff --git a/modules/cnet_modules/pidinet/model.py b/modules/cnet_modules/pidinet/model.py deleted file mode 100644 index 26644c6f6174c3b5407bd10c914045758cbadefe..0000000000000000000000000000000000000000 --- a/modules/cnet_modules/pidinet/model.py +++ /dev/null @@ -1,654 +0,0 @@ -""" -Author: Zhuo Su, Wenzhe Liu -Date: Feb 18, 2021 -""" - -import math - -import cv2 -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -nets = { - 'baseline': { - 'layer0': 'cv', - 'layer1': 'cv', - 'layer2': 'cv', - 'layer3': 'cv', - 'layer4': 'cv', - 'layer5': 'cv', - 'layer6': 'cv', - 'layer7': 'cv', - 'layer8': 'cv', - 'layer9': 'cv', - 'layer10': 'cv', - 'layer11': 'cv', - 'layer12': 'cv', - 'layer13': 'cv', - 'layer14': 'cv', - 'layer15': 'cv', - }, - 'c-v15': { - 'layer0': 'cd', - 'layer1': 'cv', - 'layer2': 'cv', - 'layer3': 'cv', - 'layer4': 'cv', - 'layer5': 'cv', - 'layer6': 'cv', - 'layer7': 'cv', - 'layer8': 'cv', - 'layer9': 'cv', - 'layer10': 'cv', - 'layer11': 'cv', - 'layer12': 'cv', - 'layer13': 'cv', - 'layer14': 'cv', - 'layer15': 'cv', - }, - 'a-v15': { - 'layer0': 'ad', - 'layer1': 'cv', - 'layer2': 'cv', - 'layer3': 'cv', - 'layer4': 'cv', - 'layer5': 'cv', - 'layer6': 'cv', - 'layer7': 'cv', - 'layer8': 'cv', - 'layer9': 'cv', - 'layer10': 'cv', - 'layer11': 'cv', - 'layer12': 'cv', - 'layer13': 'cv', - 'layer14': 'cv', - 'layer15': 'cv', - }, - 'r-v15': { - 'layer0': 'rd', - 'layer1': 'cv', - 'layer2': 'cv', - 'layer3': 'cv', - 'layer4': 'cv', - 'layer5': 'cv', - 'layer6': 'cv', - 'layer7': 'cv', - 'layer8': 'cv', - 'layer9': 'cv', - 'layer10': 'cv', - 'layer11': 'cv', - 'layer12': 'cv', - 'layer13': 'cv', - 'layer14': 'cv', - 'layer15': 'cv', - }, - 'cvvv4': { - 'layer0': 'cd', - 'layer1': 'cv', - 'layer2': 'cv', - 'layer3': 'cv', - 'layer4': 'cd', - 'layer5': 'cv', - 'layer6': 'cv', - 'layer7': 'cv', - 'layer8': 'cd', - 'layer9': 'cv', - 'layer10': 'cv', - 'layer11': 'cv', - 'layer12': 'cd', - 'layer13': 'cv', - 'layer14': 'cv', - 'layer15': 'cv', - }, - 'avvv4': { - 'layer0': 'ad', - 'layer1': 'cv', - 'layer2': 'cv', - 'layer3': 'cv', - 'layer4': 'ad', - 'layer5': 'cv', - 'layer6': 'cv', - 'layer7': 'cv', - 'layer8': 'ad', - 'layer9': 'cv', - 'layer10': 'cv', - 'layer11': 'cv', - 'layer12': 'ad', - 'layer13': 'cv', - 'layer14': 'cv', - 'layer15': 'cv', - }, - 'rvvv4': { - 'layer0': 'rd', - 'layer1': 'cv', - 'layer2': 'cv', - 'layer3': 'cv', - 'layer4': 'rd', - 'layer5': 'cv', - 'layer6': 'cv', - 'layer7': 'cv', - 'layer8': 'rd', - 'layer9': 'cv', - 'layer10': 'cv', - 'layer11': 'cv', - 'layer12': 'rd', - 'layer13': 'cv', - 'layer14': 'cv', - 'layer15': 'cv', - }, - 'cccv4': { - 'layer0': 'cd', - 'layer1': 'cd', - 'layer2': 'cd', - 'layer3': 'cv', - 'layer4': 'cd', - 'layer5': 'cd', - 'layer6': 'cd', - 'layer7': 'cv', - 'layer8': 'cd', - 'layer9': 'cd', - 'layer10': 'cd', - 'layer11': 'cv', - 'layer12': 'cd', - 'layer13': 'cd', - 'layer14': 'cd', - 'layer15': 'cv', - }, - 'aaav4': { - 'layer0': 'ad', - 'layer1': 'ad', - 'layer2': 'ad', - 'layer3': 'cv', - 'layer4': 'ad', - 'layer5': 'ad', - 'layer6': 'ad', - 'layer7': 'cv', - 'layer8': 'ad', - 'layer9': 'ad', - 'layer10': 'ad', - 'layer11': 'cv', - 'layer12': 'ad', - 'layer13': 'ad', - 'layer14': 'ad', - 'layer15': 'cv', - }, - 'rrrv4': { - 'layer0': 'rd', - 'layer1': 'rd', - 'layer2': 'rd', - 'layer3': 'cv', - 'layer4': 'rd', - 'layer5': 'rd', - 'layer6': 'rd', - 'layer7': 'cv', - 'layer8': 'rd', - 'layer9': 'rd', - 'layer10': 'rd', - 'layer11': 'cv', - 'layer12': 'rd', - 'layer13': 'rd', - 'layer14': 'rd', - 'layer15': 'cv', - }, - 'c16': { - 'layer0': 'cd', - 'layer1': 'cd', - 'layer2': 'cd', - 'layer3': 'cd', - 'layer4': 'cd', - 'layer5': 'cd', - 'layer6': 'cd', - 'layer7': 'cd', - 'layer8': 'cd', - 'layer9': 'cd', - 'layer10': 'cd', - 'layer11': 'cd', - 'layer12': 'cd', - 'layer13': 'cd', - 'layer14': 'cd', - 'layer15': 'cd', - }, - 'a16': { - 'layer0': 'ad', - 'layer1': 'ad', - 'layer2': 'ad', - 'layer3': 'ad', - 'layer4': 'ad', - 'layer5': 'ad', - 'layer6': 'ad', - 'layer7': 'ad', - 'layer8': 'ad', - 'layer9': 'ad', - 'layer10': 'ad', - 'layer11': 'ad', - 'layer12': 'ad', - 'layer13': 'ad', - 'layer14': 'ad', - 'layer15': 'ad', - }, - 'r16': { - 'layer0': 'rd', - 'layer1': 'rd', - 'layer2': 'rd', - 'layer3': 'rd', - 'layer4': 'rd', - 'layer5': 'rd', - 'layer6': 'rd', - 'layer7': 'rd', - 'layer8': 'rd', - 'layer9': 'rd', - 'layer10': 'rd', - 'layer11': 'rd', - 'layer12': 'rd', - 'layer13': 'rd', - 'layer14': 'rd', - 'layer15': 'rd', - }, - 'carv4': { - 'layer0': 'cd', - 'layer1': 'ad', - 'layer2': 'rd', - 'layer3': 'cv', - 'layer4': 'cd', - 'layer5': 'ad', - 'layer6': 'rd', - 'layer7': 'cv', - 'layer8': 'cd', - 'layer9': 'ad', - 'layer10': 'rd', - 'layer11': 'cv', - 'layer12': 'cd', - 'layer13': 'ad', - 'layer14': 'rd', - 'layer15': 'cv', - }, -} - - -def createConvFunc(op_type): - assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) - if op_type == 'cv': - return F.conv2d - - if op_type == 'cd': - def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): - assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' - assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' - assert padding == dilation, 'padding for cd_conv set wrong' - - weights_c = weights.sum(dim=[2, 3], keepdim=True) - yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) - y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) - return y - yc - - return func - elif op_type == 'ad': - def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): - assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' - assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' - assert padding == dilation, 'padding for ad_conv set wrong' - - shape = weights.shape - weights = weights.view(shape[0], shape[1], -1) - weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise - y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) - return y - - return func - elif op_type == 'rd': - def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): - assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' - assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' - padding = 2 * dilation - - shape = weights.shape - if weights.is_cuda: - buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) - else: - buffer = torch.zeros(shape[0], shape[1], 5 * 5) - weights = weights.view(shape[0], shape[1], -1) - buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] - buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] - buffer[:, :, 12] = 0 - buffer = buffer.view(shape[0], shape[1], 5, 5) - y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) - return y - - return func - else: - print('impossible to be here unless you force that') - return None - - -class Conv2d(nn.Module): - def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, - bias=False): - super(Conv2d, self).__init__() - if in_channels % groups != 0: - raise ValueError('in_channels must be divisible by groups') - if out_channels % groups != 0: - raise ValueError('out_channels must be divisible by groups') - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.groups = groups - self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter('bias', None) - self.reset_parameters() - self.pdc = pdc - - def reset_parameters(self): - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(self.bias, -bound, bound) - - def forward(self, input): - - return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - - -class CSAM(nn.Module): - """ - Compact Spatial Attention Module - """ - - def __init__(self, channels): - super(CSAM, self).__init__() - - mid_channels = 4 - self.relu1 = nn.ReLU() - self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) - self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) - self.sigmoid = nn.Sigmoid() - nn.init.constant_(self.conv1.bias, 0) - - def forward(self, x): - y = self.relu1(x) - y = self.conv1(y) - y = self.conv2(y) - y = self.sigmoid(y) - - return x * y - - -class CDCM(nn.Module): - """ - Compact Dilation Convolution based Module - """ - - def __init__(self, in_channels, out_channels): - super(CDCM, self).__init__() - - self.relu1 = nn.ReLU() - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) - self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) - self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) - self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) - self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) - nn.init.constant_(self.conv1.bias, 0) - - def forward(self, x): - x = self.relu1(x) - x = self.conv1(x) - x1 = self.conv2_1(x) - x2 = self.conv2_2(x) - x3 = self.conv2_3(x) - x4 = self.conv2_4(x) - return x1 + x2 + x3 + x4 - - -class MapReduce(nn.Module): - """ - Reduce feature maps into a single edge map - """ - - def __init__(self, channels): - super(MapReduce, self).__init__() - self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) - nn.init.constant_(self.conv.bias, 0) - - def forward(self, x): - return self.conv(x) - - -class PDCBlock(nn.Module): - def __init__(self, pdc, inplane, ouplane, stride=1): - super(PDCBlock, self).__init__() - self.stride = stride - - self.stride = stride - if self.stride > 1: - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) - self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) - self.relu2 = nn.ReLU() - self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) - - def forward(self, x): - if self.stride > 1: - x = self.pool(x) - y = self.conv1(x) - y = self.relu2(y) - y = self.conv2(y) - if self.stride > 1: - x = self.shortcut(x) - y = y + x - return y - - -class PDCBlock_converted(nn.Module): - """ - CPDC, APDC can be converted to vanilla 3x3 convolution - RPDC can be converted to vanilla 5x5 convolution - """ - - def __init__(self, pdc, inplane, ouplane, stride=1): - super(PDCBlock_converted, self).__init__() - self.stride = stride - - if self.stride > 1: - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) - if pdc == 'rd': - self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) - else: - self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) - self.relu2 = nn.ReLU() - self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) - - def forward(self, x): - if self.stride > 1: - x = self.pool(x) - y = self.conv1(x) - y = self.relu2(y) - y = self.conv2(y) - if self.stride > 1: - x = self.shortcut(x) - y = y + x - return y - - -class PiDiNet(nn.Module): - def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): - super(PiDiNet, self).__init__() - self.sa = sa - if dil is not None: - assert isinstance(dil, int), 'dil should be an int' - self.dil = dil - - self.fuseplanes = [] - - self.inplane = inplane - if convert: - if pdcs[0] == 'rd': - init_kernel_size = 5 - init_padding = 2 - else: - init_kernel_size = 3 - init_padding = 1 - self.init_block = nn.Conv2d(3, self.inplane, - kernel_size=init_kernel_size, padding=init_padding, bias=False) - block_class = PDCBlock_converted - else: - self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) - block_class = PDCBlock - - self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) - self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) - self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) - self.fuseplanes.append(self.inplane) # C - - inplane = self.inplane - self.inplane = self.inplane * 2 - self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) - self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) - self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) - self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) - self.fuseplanes.append(self.inplane) # 2C - - inplane = self.inplane - self.inplane = self.inplane * 2 - self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) - self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) - self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) - self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) - self.fuseplanes.append(self.inplane) # 4C - - self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2) - self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) - self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) - self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) - self.fuseplanes.append(self.inplane) # 4C - - self.conv_reduces = nn.ModuleList() - if self.sa and self.dil is not None: - self.attentions = nn.ModuleList() - self.dilations = nn.ModuleList() - for i in range(4): - self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) - self.attentions.append(CSAM(self.dil)) - self.conv_reduces.append(MapReduce(self.dil)) - elif self.sa: - self.attentions = nn.ModuleList() - for i in range(4): - self.attentions.append(CSAM(self.fuseplanes[i])) - self.conv_reduces.append(MapReduce(self.fuseplanes[i])) - elif self.dil is not None: - self.dilations = nn.ModuleList() - for i in range(4): - self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) - self.conv_reduces.append(MapReduce(self.dil)) - else: - for i in range(4): - self.conv_reduces.append(MapReduce(self.fuseplanes[i])) - - self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias - nn.init.constant_(self.classifier.weight, 0.25) - nn.init.constant_(self.classifier.bias, 0) - - # print('initialization done') - - def get_weights(self): - conv_weights = [] - bn_weights = [] - relu_weights = [] - for pname, p in self.named_parameters(): - if 'bn' in pname: - bn_weights.append(p) - elif 'relu' in pname: - relu_weights.append(p) - else: - conv_weights.append(p) - - return conv_weights, bn_weights, relu_weights - - def forward(self, x): - H, W = x.size()[2:] - - x = self.init_block(x) - - x1 = self.block1_1(x) - x1 = self.block1_2(x1) - x1 = self.block1_3(x1) - - x2 = self.block2_1(x1) - x2 = self.block2_2(x2) - x2 = self.block2_3(x2) - x2 = self.block2_4(x2) - - x3 = self.block3_1(x2) - x3 = self.block3_2(x3) - x3 = self.block3_3(x3) - x3 = self.block3_4(x3) - - x4 = self.block4_1(x3) - x4 = self.block4_2(x4) - x4 = self.block4_3(x4) - x4 = self.block4_4(x4) - - x_fuses = [] - if self.sa and self.dil is not None: - for i, xi in enumerate([x1, x2, x3, x4]): - x_fuses.append(self.attentions[i](self.dilations[i](xi))) - elif self.sa: - for i, xi in enumerate([x1, x2, x3, x4]): - x_fuses.append(self.attentions[i](xi)) - elif self.dil is not None: - for i, xi in enumerate([x1, x2, x3, x4]): - x_fuses.append(self.dilations[i](xi)) - else: - x_fuses = [x1, x2, x3, x4] - - e1 = self.conv_reduces[0](x_fuses[0]) - e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) - - e2 = self.conv_reduces[1](x_fuses[1]) - e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) - - e3 = self.conv_reduces[2](x_fuses[2]) - e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) - - e4 = self.conv_reduces[3](x_fuses[3]) - e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) - - outputs = [e1, e2, e3, e4] - - output = self.classifier(torch.cat(outputs, dim=1)) - # if not self.training: - # return torch.sigmoid(output) - - outputs.append(output) - outputs = [torch.sigmoid(r) for r in outputs] - return outputs - - -def config_model(model): - model_options = list(nets.keys()) - assert model in model_options, \ - 'unrecognized model, please choose from %s' % str(model_options) - - # print(str(nets[model])) - - pdcs = [] - for i in range(16): - layer_name = 'layer%d' % i - op = nets[model][layer_name] - pdcs.append(createConvFunc(op)) - - return pdcs - - -def pidinet(): - pdcs = config_model('carv4') - dil = 24 # if args.dil else None - return PiDiNet(60, pdcs, dil=dil, sa=True) diff --git a/modules/cnet_modules/pidinet/util.py b/modules/cnet_modules/pidinet/util.py deleted file mode 100644 index aec00770c7706f95abf3a0b9b02dbe3232930596..0000000000000000000000000000000000000000 --- a/modules/cnet_modules/pidinet/util.py +++ /dev/null @@ -1,97 +0,0 @@ -import random - -import numpy as np -import cv2 -import os - -annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') - - -def HWC3(x): - assert x.dtype == np.uint8 - if x.ndim == 2: - x = x[:, :, None] - assert x.ndim == 3 - H, W, C = x.shape - assert C == 1 or C == 3 or C == 4 - if C == 3: - return x - if C == 1: - return np.concatenate([x, x, x], axis=2) - if C == 4: - color = x[:, :, 0:3].astype(np.float32) - alpha = x[:, :, 3:4].astype(np.float32) / 255.0 - y = color * alpha + 255.0 * (1.0 - alpha) - y = y.clip(0, 255).astype(np.uint8) - return y - - -def resize_image(input_image, resolution): - H, W, C = input_image.shape - H = float(H) - W = float(W) - k = float(resolution) / min(H, W) - H *= k - W *= k - H = int(np.round(H / 64.0)) * 64 - W = int(np.round(W / 64.0)) * 64 - img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) - return img - - -def nms(x, t, s): - x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) - - f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) - f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) - f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) - f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) - - y = np.zeros_like(x) - - for f in [f1, f2, f3, f4]: - np.putmask(y, cv2.dilate(x, kernel=f) == x, x) - - z = np.zeros_like(y, dtype=np.uint8) - z[y > t] = 255 - return z - - -def make_noise_disk(H, W, C, F): - noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) - noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) - noise = noise[F: F + H, F: F + W] - noise -= np.min(noise) - noise /= np.max(noise) - if C == 1: - noise = noise[:, :, None] - return noise - - -def min_max_norm(x): - x -= np.min(x) - x /= np.maximum(np.max(x), 1e-5) - return x - - -def safe_step(x, step=2): - y = x.astype(np.float32) * float(step + 1) - y = y.astype(np.int32).astype(np.float32) / float(step) - return y - - -def img2mask(img, H, W, low=10, high=90): - assert img.ndim == 3 or img.ndim == 2 - assert img.dtype == np.uint8 - - if img.ndim == 3: - y = img[:, :, random.randrange(0, img.shape[2])] - else: - y = img - - y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) - - if random.uniform(0, 1) < 0.5: - y = 255 - y - - return y < np.percentile(y, random.randrange(low, high)) diff --git a/modules/common.py b/modules/common.py deleted file mode 100644 index 5e4ad71649f60f2dd38947c9ebc23bc51db2b544..0000000000000000000000000000000000000000 --- a/modules/common.py +++ /dev/null @@ -1,131 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from einops import rearrange -import torch.fft as fft -class Linear(torch.nn.Linear): - def reset_parameters(self): - return None - -class Conv2d(torch.nn.Conv2d): - def reset_parameters(self): - return None - - - -class Attention2D(nn.Module): - def __init__(self, c, nhead, dropout=0.0): - super().__init__() - self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) - - def forward(self, x, kv, self_attn=False): - orig_shape = x.shape - x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 - if self_attn: - #print('in line 23 algong self att ', kv.shape, x.shape) - kv = torch.cat([x, kv], dim=1) - #if x.shape[1] >= 72 * 72: - # x = x * math.sqrt(math.log(64*64, 24*24)) - - x = self.attn(x, kv, kv, need_weights=False)[0] - x = x.permute(0, 2, 1).view(*orig_shape) - return x - - -class LayerNorm2d(nn.LayerNorm): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x): - return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - -class GlobalResponseNorm(nn.Module): - "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, x): - Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x - - -class ResBlock(nn.Module): - def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): - super().__init__() - self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - # self.depthwise = SAMBlock(c, num_heads, expansion) - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - Linear(c + c_skip, c * 4), - nn.GELU(), - GlobalResponseNorm(c * 4), - nn.Dropout(dropout), - Linear(c * 4, c) - ) - - def forward(self, x, x_skip=None): - x_res = x - x = self.norm(self.depthwise(x)) - if x_skip is not None: - x = torch.cat([x, x_skip], dim=1) - x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - return x + x_res - - -class AttnBlock(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) - self.kv_mapper = nn.Sequential( - nn.SiLU(), - Linear(c_cond, c) - ) - - def forward(self, x, kv): - kv = self.kv_mapper(kv) - res = self.attention(self.norm(x), kv, self_attn=self.self_attn) - - #print(torch.unique(res), torch.unique(x), self.self_attn) - #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24)) - x = x + res - - return x - -class FeedForwardBlock(nn.Module): - def __init__(self, c, dropout=0.0): - super().__init__() - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - Linear(c, c * 4), - nn.GELU(), - GlobalResponseNorm(c * 4), - nn.Dropout(dropout), - Linear(c * 4, c) - ) - - def forward(self, x): - x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - return x - - -class TimestepBlock(nn.Module): - def __init__(self, c, c_timestep, conds=['sca']): - super().__init__() - self.mapper = Linear(c_timestep, c * 2) - self.conds = conds - for cname in conds: - setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) - - def forward(self, x, t): - t = t.chunk(len(self.conds) + 1, dim=1) - a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) - for i, c in enumerate(self.conds): - ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) - a, b = a + ac, b + bc - return x * (1 + a) + b diff --git a/modules/common_ckpt.py b/modules/common_ckpt.py deleted file mode 100644 index f64cf11790bdd2a83ca0744629336d81464b3ed0..0000000000000000000000000000000000000000 --- a/modules/common_ckpt.py +++ /dev/null @@ -1,360 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from einops import rearrange -from modules.speed_util import checkpoint -class Linear(torch.nn.Linear): - def reset_parameters(self): - return None - -class Conv2d(torch.nn.Conv2d): - def reset_parameters(self): - return None - -class AttnBlock_lrfuse_backup(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) - self.kv_mapper = nn.Sequential( - nn.SiLU(), - Linear(c_cond, c) - ) - self.fuse_mapper = nn.Sequential( - nn.SiLU(), - Linear(c_cond, c) - ) - self.use_checkpoint = use_checkpoint - - def forward(self, hr, lr): - return checkpoint(self._forward, (hr, lr), self.paramters(), self.use_checkpoint) - def _forward(self, hr, lr): - res = hr - hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c')) - lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr - - lr_fuse = self.fuse_mapper(rearrange(lr_fuse, 'b c h w -> b (h w ) c')) - hr = self.attention(self.norm(res), lr_fuse, self_attn=False) + res - return hr - - -class AttnBlock_lrfuse(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, kernel_size=3, use_checkpoint=True): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) - self.kv_mapper = nn.Sequential( - nn.SiLU(), - Linear(c_cond, c) - ) - - - self.depthwise = Conv2d(c, c , kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - - self.channelwise = nn.Sequential( - Linear(c + c, c ), - nn.GELU(), - GlobalResponseNorm(c ), - nn.Dropout(dropout), - Linear(c , c) - ) - self.use_checkpoint = use_checkpoint - - - def forward(self, hr, lr): - return checkpoint(self._forward, (hr, lr), self.parameters(), self.use_checkpoint) - - def _forward(self, hr, lr): - res = hr - hr = self.kv_mapper(rearrange(hr, 'b c h w -> b (h w ) c')) - lr_fuse = self.attention(self.norm(lr), hr, self_attn=False) + lr - - lr_fuse = torch.nn.functional.interpolate(lr_fuse.float(), res.shape[2:]) - #print('in line 65', lr_fuse.shape, res.shape) - media = torch.cat((self.depthwise(lr_fuse), res), dim=1) - out = self.channelwise(media.permute(0,2,3,1)).permute(0,3,1,2) + res - - return out - - - - -class Attention2D(nn.Module): - def __init__(self, c, nhead, dropout=0.0): - super().__init__() - self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) - - def forward(self, x, kv, self_attn=False): - orig_shape = x.shape - x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 - if self_attn: - #print('in line 23 algong self att ', kv.shape, x.shape) - - kv = torch.cat([x, kv], dim=1) - #if x.shape[1] > 48 * 48 and not self.training: - # x = x * math.sqrt(math.log(x.shape[1] , 24*24)) - - x = self.attn(x, kv, kv, need_weights=False)[0] - x = x.permute(0, 2, 1).view(*orig_shape) - return x -class Attention2D_splitpatch(nn.Module): - def __init__(self, c, nhead, dropout=0.0): - super().__init__() - self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) - - def forward(self, x, kv, self_attn=False): - orig_shape = x.shape - - #x = rearrange(x, 'b c h w -> b c (nh wh) (nw ww)', wh=24, ww=24, nh=orig_shape[-2] // 24, nh=orig_shape[-1] // 24,) - x = rearrange(x, 'b c (nh wh) (nw ww) -> (b nh nw) (wh ww) c', wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24,) - #print('in line 168', x.shape) - #x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 - if self_attn: - #print('in line 23 algong self att ', kv.shape, x.shape) - num = (orig_shape[-2] // 24) * (orig_shape[-1] // 24) - kv = torch.cat([x, kv.repeat(num, 1, 1)], dim=1) - #if x.shape[1] > 48 * 48 and not self.training: - # x = x * math.sqrt(math.log(x.shape[1] / math.sqrt(16), 24*24)) - - x = self.attn(x, kv, kv, need_weights=False)[0] - x = rearrange(x, ' (b nh nw) (wh ww) c -> b c (nh wh) (nw ww)', b=orig_shape[0], wh=24, ww=24, nh=orig_shape[-2] // 24, nw=orig_shape[-1] // 24) - #x = x.permute(0, 2, 1).view(*orig_shape) - - return x -class Attention2D_extra(nn.Module): - def __init__(self, c, nhead, dropout=0.0): - super().__init__() - self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) - - def forward(self, x, kv, extra_emb=None, self_attn=False): - orig_shape = x.shape - x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 - num_x = x.shape[1] - - - if extra_emb is not None: - ori_extra_shape = extra_emb.shape - extra_emb = extra_emb.view(extra_emb.size(0), extra_emb.size(1), -1).permute(0, 2, 1) - x = torch.cat((x, extra_emb), dim=1) - if self_attn: - #print('in line 23 algong self att ', kv.shape, x.shape) - kv = torch.cat([x, kv], dim=1) - x = self.attn(x, kv, kv, need_weights=False)[0] - img = x[:, :num_x, :].permute(0, 2, 1).view(*orig_shape) - if extra_emb is not None: - fix = x[:, num_x:, :].permute(0, 2, 1).view(*ori_extra_shape) - return img, fix - else: - return img -class AttnBlock_extraq(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - #self.norm2 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D_extra(c, nhead, dropout) - self.kv_mapper = nn.Sequential( - nn.SiLU(), - Linear(c_cond, c) - ) - # norm2 initialization in generator in init extra parameter - def forward(self, x, kv, extra_emb=None): - #print('in line 84', x.shape, kv.shape, self.self_attn, extra_emb if extra_emb is None else extra_emb.shape) - #in line 84 torch.Size([1, 1536, 32, 32]) torch.Size([1, 85, 1536]) True None - #if extra_emb is not None: - - kv = self.kv_mapper(kv) - if extra_emb is not None: - res_x, res_extra = self.attention(self.norm(x), kv, extra_emb=self.norm2(extra_emb), self_attn=self.self_attn) - x = x + res_x - extra_emb = extra_emb + res_extra - return x, extra_emb - else: - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) - return x -class AttnBlock_latent2ex(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) - self.kv_mapper = nn.Sequential( - nn.SiLU(), - Linear(c_cond, c) - ) - - def forward(self, x, kv): - #print('in line 84', x.shape, kv.shape, self.self_attn) - kv = F.interpolate(kv.float(), x.shape[2:]) - kv = kv.view(kv.size(0), kv.size(1), -1).permute(0, 2, 1) - kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) - return x - -class LayerNorm2d(nn.LayerNorm): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x): - return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) -class AttnBlock_crossbranch(nn.Module): - def __init__(self, attnmodule, c, c_cond, nhead, self_attn=True, dropout=0.0): - super().__init__() - self.attn = AttnBlock(c, c_cond, nhead, self_attn, dropout) - #print('in line 108', attnmodule.device) - self.attn.load_state_dict(attnmodule.state_dict()) - self.norm1 = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - - self.channelwise1 = nn.Sequential( - Linear(c *2, c ), - nn.GELU(), - GlobalResponseNorm(c ), - nn.Dropout(dropout), - Linear(c, c) - ) - self.channelwise2 = nn.Sequential( - Linear(c *2, c ), - nn.GELU(), - GlobalResponseNorm(c ), - nn.Dropout(dropout), - Linear(c, c) - ) - self.c = c - def forward(self, x, kv, main_x): - #print('in line 84', x.shape, kv.shape, main_x.shape, self.c) - - x = self.channelwise1(torch.cat((x, F.interpolate(main_x.float(), x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x - x = self.attn(x, kv) - main_x = self.channelwise2(torch.cat((main_x, F.interpolate(x.float(), main_x.shape[2:])), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + main_x - return main_x, x - -class GlobalResponseNorm(nn.Module): - "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, x): - Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x - - -class ResBlock(nn.Module): - def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, use_checkpoint =True): # , num_heads=4, expansion=2): - super().__init__() - self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) - # self.depthwise = SAMBlock(c, num_heads, expansion) - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - Linear(c + c_skip, c * 4), - nn.GELU(), - GlobalResponseNorm(c * 4), - nn.Dropout(dropout), - Linear(c * 4, c) - ) - self.use_checkpoint = use_checkpoint - def forward(self, x, x_skip=None): - - if x_skip is not None: - return checkpoint(self._forward_skip, (x, x_skip), self.parameters(), self.use_checkpoint) - else: - #print('in line 298', x.shape) - return checkpoint(self._forward_woskip, (x, ), self.parameters(), self.use_checkpoint) - - - - def _forward_skip(self, x, x_skip): - x_res = x - x = self.norm(self.depthwise(x)) - if x_skip is not None: - x = torch.cat([x, x_skip], dim=1) - x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - return x + x_res - def _forward_woskip(self, x): - x_res = x - x = self.norm(self.depthwise(x)) - - x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - return x + x_res - -class AttnBlock(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, use_checkpoint=True): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) - self.kv_mapper = nn.Sequential( - nn.SiLU(), - Linear(c_cond, c) - ) - self.use_checkpoint = use_checkpoint - def forward(self, x, kv): - return checkpoint(self._forward, (x, kv), self.parameters(), self.use_checkpoint) - def _forward(self, x, kv): - kv = self.kv_mapper(kv) - res = self.attention(self.norm(x), kv, self_attn=self.self_attn) - - #print(torch.unique(res), torch.unique(x), self.self_attn) - #scale = math.sqrt(math.log(x.shape[-2] * x.shape[-1], 24*24)) - x = x + res - - return x -class AttnBlock_mytest(nn.Module): - def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): - super().__init__() - self.self_attn = self_attn - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.attention = Attention2D(c, nhead, dropout) - self.kv_mapper = nn.Sequential( - nn.SiLU(), - nn.Linear(c_cond, c) - ) - - def forward(self, x, kv): - kv = self.kv_mapper(kv) - x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) - return x - -class FeedForwardBlock(nn.Module): - def __init__(self, c, dropout=0.0): - super().__init__() - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - Linear(c, c * 4), - nn.GELU(), - GlobalResponseNorm(c * 4), - nn.Dropout(dropout), - Linear(c * 4, c) - ) - - def forward(self, x): - x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - return x - - -class TimestepBlock(nn.Module): - def __init__(self, c, c_timestep, conds=['sca'], use_checkpoint=True): - super().__init__() - self.mapper = Linear(c_timestep, c * 2) - self.conds = conds - for cname in conds: - setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) - - self.use_checkpoint = use_checkpoint - def forward(self, x, t): - return checkpoint(self._forward, (x, t), self.parameters(), self.use_checkpoint) - - def _forward(self, x, t): - #print('in line 284', x.shape, t.shape, self.conds) - #in line 284 torch.Size([4, 2048, 19, 29]) torch.Size([4, 192]) ['sca', 'crp'] - t = t.chunk(len(self.conds) + 1, dim=1) - a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) - for i, c in enumerate(self.conds): - ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) - a, b = a + ac, b + bc - return x * (1 + a) + b diff --git a/modules/controlnet.py b/modules/controlnet.py deleted file mode 100644 index c187aecb725e00e19924ae308e3aac401acfdf06..0000000000000000000000000000000000000000 --- a/modules/controlnet.py +++ /dev/null @@ -1,349 +0,0 @@ -import torchvision -import torch -from torch import nn -import numpy as np -import kornia -import cv2 -from core.utils import load_or_fail -#from insightface.app.common import Face -from .effnet import EfficientNetEncoder -from .cnet_modules.pidinet import PidiNetDetector -from .cnet_modules.inpainting.saliency_model import MicroResNet -#from .cnet_modules.face_id.arcface import FaceDetector, ArcFaceRecognizer -from .common import LayerNorm2d - - -class CNetResBlock(nn.Module): - def __init__(self, c): - super().__init__() - self.blocks = nn.Sequential( - LayerNorm2d(c), - nn.GELU(), - nn.Conv2d(c, c, kernel_size=3, padding=1), - LayerNorm2d(c), - nn.GELU(), - nn.Conv2d(c, c, kernel_size=3, padding=1), - ) - - def forward(self, x): - return x + self.blocks(x) - - -class ControlNet(nn.Module): - def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None): - super().__init__() - if bottleneck_mode is None: - bottleneck_mode = 'effnet' - self.proj_blocks = proj_blocks - if bottleneck_mode == 'effnet': - embd_channels = 1280 - #self.backbone = torchvision.models.efficientnet_v2_s(weights='DEFAULT').features.eval() - self.backbone = torchvision.models.efficientnet_v2_s().features.eval() - if c_in != 3: - in_weights = self.backbone[0][0].weight.data - self.backbone[0][0] = nn.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False) - if c_in > 3: - nn.init.constant_(self.backbone[0][0].weight, 0) - self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone() - else: - self.backbone[0][0].weight.data = in_weights[:, :c_in].clone() - elif bottleneck_mode == 'simple': - embd_channels = c_in - self.backbone = nn.Sequential( - nn.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1), - ) - elif bottleneck_mode == 'large': - self.backbone = nn.Sequential( - nn.Conv2d(c_in, 4096 * 4, kernel_size=1), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(4096 * 4, 1024, kernel_size=1), - *[CNetResBlock(1024) for _ in range(8)], - nn.Conv2d(1024, 1280, kernel_size=1), - ) - embd_channels = 1280 - else: - raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}') - self.projections = nn.ModuleList() - for _ in range(len(proj_blocks)): - self.projections.append(nn.Sequential( - nn.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False), - )) - nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection - - def forward(self, x): - x = self.backbone(x) - proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] - for i, idx in enumerate(self.proj_blocks): - proj_outputs[idx] = self.projections[i](x) - return proj_outputs - - -class ControlNetDeliverer(): - def __init__(self, controlnet_projections): - self.controlnet_projections = controlnet_projections - self.restart() - - def restart(self): - self.idx = 0 - return self - - def __call__(self): - if self.idx < len(self.controlnet_projections): - output = self.controlnet_projections[self.idx] - else: - output = None - self.idx += 1 - return output - - -# CONTROLNET FILTERS ---------------------------------------------------- - -class BaseFilter(): - def __init__(self, device): - self.device = device - - def num_channels(self): - return 3 - - def __call__(self, x): - return x - - -class CannyFilter(BaseFilter): - def __init__(self, device, resize=224): - super().__init__(device) - self.resize = resize - - def num_channels(self): - return 1 - - def __call__(self, x): - orig_size = x.shape[-2:] - if self.resize is not None: - x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') - edges = [cv2.Canny(x[i].mul(255).permute(1, 2, 0).cpu().numpy().astype(np.uint8), 100, 200) for i in range(len(x))] - edges = torch.stack([torch.tensor(e).div(255).unsqueeze(0) for e in edges], dim=0) - if self.resize is not None: - edges = nn.functional.interpolate(edges, size=orig_size, mode='bilinear') - return edges - - -class QRFilter(BaseFilter): - def __init__(self, device, resize=224, blobify=True, dilation_kernels=[3, 5, 7], blur_kernels=[15]): - super().__init__(device) - self.resize = resize - self.blobify = blobify - self.dilation_kernels = dilation_kernels - self.blur_kernels = blur_kernels - - def num_channels(self): - return 1 - - def __call__(self, x): - x = x.to(self.device) - orig_size = x.shape[-2:] - if self.resize is not None: - x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') - - x = kornia.color.rgb_to_hsv(x)[:, -1:] - # blobify - if self.blobify: - d_kernel = np.random.choice(self.dilation_kernels) - d_blur = np.random.choice(self.blur_kernels) - if d_blur > 0: - x = torchvision.transforms.GaussianBlur(d_blur)(x) - if d_kernel > 0: - blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, - d_kernel).pow(2)[:, - None]) < 0.3).float().to(self.device) - x = kornia.morphology.dilation(x, blob_mask) - x = kornia.morphology.erosion(x, blob_mask) - # mask - vmax, vmin = x.amax(dim=[2, 3], keepdim=True)[0], x.amin(dim=[2, 3], keepdim=True)[0] - th = (vmax - vmin) * 0.33 - high_brightness, low_brightness = (x > (vmax - th)).float(), (x < (vmin + th)).float() - mask = (torch.ones_like(x) - low_brightness + high_brightness) * 0.5 - - if self.resize is not None: - mask = nn.functional.interpolate(mask, size=orig_size, mode='bilinear') - return mask.cpu() - - -class PidiFilter(BaseFilter): - def __init__(self, device, resize=224, dilation_kernels=[0, 3, 5, 7, 9], binarize=True): - super().__init__(device) - self.resize = resize - self.model = PidiNetDetector(device) - self.dilation_kernels = dilation_kernels - self.binarize = binarize - - def num_channels(self): - return 1 - - def __call__(self, x): - x = x.to(self.device) - orig_size = x.shape[-2:] - if self.resize is not None: - x = nn.functional.interpolate(x, size=(self.resize, self.resize), mode='bilinear') - - x = self.model(x) - d_kernel = np.random.choice(self.dilation_kernels) - if d_kernel > 0: - blob_mask = ((torch.linspace(-0.5, 0.5, d_kernel).pow(2)[None] + torch.linspace(-0.5, 0.5, d_kernel).pow(2)[ - :, None]) < 0.3).float().to(self.device) - x = kornia.morphology.dilation(x, blob_mask) - if self.binarize: - th = np.random.uniform(0.05, 0.7) - x = (x > th).float() - - if self.resize is not None: - x = nn.functional.interpolate(x, size=orig_size, mode='bilinear') - return x.cpu() - - -class SRFilter(BaseFilter): - def __init__(self, device, scale_factor=1 / 4): - super().__init__(device) - self.scale_factor = scale_factor - - def num_channels(self): - return 3 - - def __call__(self, x): - x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest") - return torch.nn.functional.interpolate(x, scale_factor=1 / self.scale_factor, mode="nearest") - - -class SREffnetFilter(BaseFilter): - def __init__(self, device, scale_factor=1/2): - super().__init__(device) - self.scale_factor = scale_factor - - self.effnet_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) - ]) - - self.effnet = EfficientNetEncoder().to(self.device) - effnet_checkpoint = load_or_fail("models/effnet_encoder.safetensors") - self.effnet.load_state_dict(effnet_checkpoint) - self.effnet.eval().requires_grad_(False) - - def num_channels(self): - return 16 - - def __call__(self, x): - x = torch.nn.functional.interpolate(x.clone(), scale_factor=self.scale_factor, mode="nearest") - with torch.no_grad(): - effnet_embedding = self.effnet(self.effnet_preprocess(x.to(self.device))).cpu() - effnet_embedding = torch.nn.functional.interpolate(effnet_embedding, scale_factor=1/self.scale_factor, mode="nearest") - upscaled_image = torch.nn.functional.interpolate(x, scale_factor=1/self.scale_factor, mode="nearest") - return effnet_embedding, upscaled_image - - -class InpaintFilter(BaseFilter): - def __init__(self, device, thresold=[0.04, 0.4], p_outpaint=0.4): - super().__init__(device) - self.saliency_model = MicroResNet().eval().requires_grad_(False).to(device) - self.saliency_model.load_state_dict(load_or_fail("modules/cnet_modules/inpainting/saliency_model.pt")) - self.thresold = thresold - self.p_outpaint = p_outpaint - - def num_channels(self): - return 4 - - def __call__(self, x, mask=None, threshold=None, outpaint=None): - x = x.to(self.device) - resized_x = torchvision.transforms.functional.resize(x, 240, antialias=True) - if threshold is None: - threshold = np.random.uniform(self.thresold[0], self.thresold[1]) - if mask is None: - saliency_map = self.saliency_model(resized_x) > threshold - if outpaint is None: - if np.random.rand() < self.p_outpaint: - saliency_map = ~saliency_map - else: - if outpaint: - saliency_map = ~saliency_map - interpolated_saliency_map = torch.nn.functional.interpolate(saliency_map.float(), size=x.shape[2:], mode="nearest") - saliency_map = torchvision.transforms.functional.gaussian_blur(interpolated_saliency_map, 141) > 0.5 - inpainted_images = torch.where(saliency_map, torch.ones_like(x), x) - mask = torch.nn.functional.interpolate(saliency_map.float(), size=inpainted_images.shape[2:], mode="nearest") - else: - mask = mask.to(self.device) - inpainted_images = torch.where(mask, torch.ones_like(x), x) - c_inpaint = torch.cat([inpainted_images, mask], dim=1) - return c_inpaint.cpu() - - -# IDENTITY -''' -class IdentityFilter(BaseFilter): - def __init__(self, device, max_faces=4, p_drop=0.05, p_full=0.3): - detector_path = 'modules/cnet_modules/face_id/models/buffalo_l/det_10g.onnx' - recognizer_path = 'modules/cnet_modules/face_id/models/buffalo_l/w600k_r50.onnx' - - super().__init__(device) - self.max_faces = max_faces - self.p_drop = p_drop - self.p_full = p_full - - self.detector = FaceDetector(detector_path, device=device) - self.recognizer = ArcFaceRecognizer(recognizer_path, device=device) - - self.id_colors = torch.tensor([ - [1.0, 0.0, 0.0], # RED - [0.0, 1.0, 0.0], # GREEN - [0.0, 0.0, 1.0], # BLUE - [1.0, 0.0, 1.0], # PURPLE - [0.0, 1.0, 1.0], # CYAN - [1.0, 1.0, 0.0], # YELLOW - [0.5, 0.0, 0.0], # DARK RED - [0.0, 0.5, 0.0], # DARK GREEN - [0.0, 0.0, 0.5], # DARK BLUE - [0.5, 0.0, 0.5], # DARK PURPLE - [0.0, 0.5, 0.5], # DARK CYAN - [0.5, 0.5, 0.0], # DARK YELLOW - ]) - - def num_channels(self): - return 512 - - def get_faces(self, image): - npimg = image.permute(1, 2, 0).mul(255).to(device="cpu", dtype=torch.uint8).cpu().numpy() - bgr = cv2.cvtColor(npimg, cv2.COLOR_RGB2BGR) - bboxes, kpss = self.detector.detect(bgr, max_num=self.max_faces) - N = len(bboxes) - ids = torch.zeros((N, 512), dtype=torch.float32) - for i in range(N): - face = Face(bbox=bboxes[i, :4], kps=kpss[i], det_score=bboxes[i, 4]) - ids[i, :] = self.recognizer.get(bgr, face) - tbboxes = torch.tensor(bboxes[:, :4], dtype=torch.int) - - ids = ids / torch.linalg.norm(ids, dim=1, keepdim=True) - return tbboxes, ids # returns bounding boxes (N x 4) and ID vectors (N x 512) - - def __call__(self, x): - visual_aid = x.clone().cpu() - face_mtx = torch.zeros(x.size(0), 512, x.size(-2) // 32, x.size(-1) // 32) - - for i in range(x.size(0)): - bounding_boxes, ids = self.get_faces(x[i]) - for j in range(bounding_boxes.size(0)): - if np.random.rand() > self.p_drop: - sx, sy, ex, ey = (bounding_boxes[j] / 32).clamp(min=0).round().int().tolist() - ex, ey = max(ex, sx + 1), max(ey, sy + 1) - if bounding_boxes.size(0) == 1 and np.random.rand() < self.p_full: - sx, sy, ex, ey = 0, 0, x.size(-1) // 32, x.size(-2) // 32 - face_mtx[i, :, sy:ey, sx:ex] = ids[j:j + 1, :, None, None] - visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] += self.id_colors[j % 13, :, - None, None] - visual_aid[i, :, int(sy * 32):int(ey * 32), int(sx * 32):int(ex * 32)] *= 0.5 - - return face_mtx.to(x.device), visual_aid.to(x.device) -''' diff --git a/modules/effnet.py b/modules/effnet.py deleted file mode 100644 index 0eb2690c2547c8c7553aec8a9f9e838241f8f61c..0000000000000000000000000000000000000000 --- a/modules/effnet.py +++ /dev/null @@ -1,17 +0,0 @@ -import torchvision -from torch import nn - - -# EfficientNet -class EfficientNetEncoder(nn.Module): - def __init__(self, c_latent=16): - super().__init__() - self.backbone = torchvision.models.efficientnet_v2_s().features.eval() - self.mapper = nn.Sequential( - nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), - nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 - ) - - def forward(self, x): - return self.mapper(self.backbone(x)) - diff --git a/modules/inr_fea_res_lite.py b/modules/inr_fea_res_lite.py deleted file mode 100644 index 41ddfb09937f26e2c7d0193b4a65607efabde5e5..0000000000000000000000000000000000000000 --- a/modules/inr_fea_res_lite.py +++ /dev/null @@ -1,435 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -import einops -import numpy as np -import models -from modules.common_ckpt import Linear, Conv2d, AttnBlock, ResBlock, LayerNorm2d -#from modules.common_ckpt import AttnBlock, -from einops import rearrange -import torch.fft as fft -from modules.speed_util import checkpoint -def batched_linear_mm(x, wb): - # x: (B, N, D1); wb: (B, D1 + 1, D2) or (D1 + 1, D2) - one = torch.ones(*x.shape[:-1], 1, device=x.device) - return torch.matmul(torch.cat([x, one], dim=-1), wb) -def make_coord_grid(shape, range, device=None): - """ - Args: - shape: tuple - range: [minv, maxv] or [[minv_1, maxv_1], ..., [minv_d, maxv_d]] for each dim - Returns: - grid: shape (*shape, ) - """ - l_lst = [] - for i, s in enumerate(shape): - l = (0.5 + torch.arange(s, device=device)) / s - if isinstance(range[0], list) or isinstance(range[0], tuple): - minv, maxv = range[i] - else: - minv, maxv = range - l = minv + (maxv - minv) * l - l_lst.append(l) - grid = torch.meshgrid(*l_lst, indexing='ij') - grid = torch.stack(grid, dim=-1) - return grid -def init_wb(shape): - weight = torch.empty(shape[1], shape[0] - 1) - nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - - bias = torch.empty(shape[1], 1) - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) - bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 - nn.init.uniform_(bias, -bound, bound) - - return torch.cat([weight, bias], dim=1).t().detach() - -def init_wb_rewrite(shape): - weight = torch.empty(shape[1], shape[0] - 1) - - torch.nn.init.xavier_uniform_(weight) - - bias = torch.empty(shape[1], 1) - torch.nn.init.xavier_uniform_(bias) - - - return torch.cat([weight, bias], dim=1).t().detach() -class HypoMlp(nn.Module): - - def __init__(self, depth, in_dim, out_dim, hidden_dim, use_pe, pe_dim, out_bias=0, pe_sigma=1024): - super().__init__() - self.use_pe = use_pe - self.pe_dim = pe_dim - self.pe_sigma = pe_sigma - self.depth = depth - self.param_shapes = dict() - if use_pe: - last_dim = in_dim * pe_dim - else: - last_dim = in_dim - for i in range(depth): # for each layer the weight - cur_dim = hidden_dim if i < depth - 1 else out_dim - self.param_shapes[f'wb{i}'] = (last_dim + 1, cur_dim) - last_dim = cur_dim - self.relu = nn.ReLU() - self.params = None - self.out_bias = out_bias - - def set_params(self, params): - self.params = params - - def convert_posenc(self, x): - w = torch.exp(torch.linspace(0, np.log(self.pe_sigma), self.pe_dim // 2, device=x.device)) - x = torch.matmul(x.unsqueeze(-1), w.unsqueeze(0)).view(*x.shape[:-1], -1) - x = torch.cat([torch.cos(np.pi * x), torch.sin(np.pi * x)], dim=-1) - return x - - def forward(self, x): - B, query_shape = x.shape[0], x.shape[1: -1] - x = x.view(B, -1, x.shape[-1]) - if self.use_pe: - x = self.convert_posenc(x) - #print('in line 79 after pos embedding', x.shape) - for i in range(self.depth): - x = batched_linear_mm(x, self.params[f'wb{i}']) - if i < self.depth - 1: - x = self.relu(x) - else: - x = x + self.out_bias - x = x.view(B, *query_shape, -1) - return x - - - -class Attention(nn.Module): - - def __init__(self, dim, n_head, head_dim, dropout=0.): - super().__init__() - self.n_head = n_head - inner_dim = n_head * head_dim - self.to_q = nn.Sequential( - nn.SiLU(), - Linear(dim, inner_dim )) - self.to_kv = nn.Sequential( - nn.SiLU(), - Linear(dim, inner_dim * 2)) - self.scale = head_dim ** -0.5 - # self.to_out = nn.Sequential( - # Linear(inner_dim, dim), - # nn.Dropout(dropout), - # ) - - def forward(self, fr, to=None): - if to is None: - to = fr - q = self.to_q(fr) - k, v = self.to_kv(to).chunk(2, dim=-1) - q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> b h n d', h=self.n_head), [q, k, v]) - - dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - attn = F.softmax(dots, dim=-1) # b h n n - out = torch.matmul(attn, v) - out = einops.rearrange(out, 'b h n d -> b n (h d)') - return out - - -class FeedForward(nn.Module): - - def __init__(self, dim, ff_dim, dropout=0.): - super().__init__() - - self.net = nn.Sequential( - Linear(dim, ff_dim), - nn.GELU(), - #GlobalResponseNorm(ff_dim), - nn.Dropout(dropout), - Linear(ff_dim, dim) - ) - - def forward(self, x): - return self.net(x) - - -class PreNorm(nn.Module): - - def __init__(self, dim, fn): - super().__init__() - self.norm = nn.LayerNorm(dim) - self.fn = fn - - def forward(self, x): - return self.fn(self.norm(x)) - - -#TransInr(ind=2048, ch=256, n_head=16, head_dim=16, n_groups=64, f_dim=256, time_dim=self.c_r, t_conds = []) -class TransformerEncoder(nn.Module): - - def __init__(self, dim, depth, n_head, head_dim, ff_dim, dropout=0.): - super().__init__() - self.layers = nn.ModuleList() - for _ in range(depth): - self.layers.append(nn.ModuleList([ - PreNorm(dim, Attention(dim, n_head, head_dim, dropout=dropout)), - PreNorm(dim, FeedForward(dim, ff_dim, dropout=dropout)), - ])) - - def forward(self, x): - for norm_attn, norm_ff in self.layers: - x = x + norm_attn(x) - x = x + norm_ff(x) - return x -class ImgrecTokenizer(nn.Module): - - def __init__(self, input_size=32*32, patch_size=1, dim=768, padding=0, img_channels=16): - super().__init__() - - if isinstance(patch_size, int): - patch_size = (patch_size, patch_size) - if isinstance(padding, int): - padding = (padding, padding) - self.patch_size = patch_size - self.padding = padding - self.prefc = nn.Linear(patch_size[0] * patch_size[1] * img_channels, dim) - - self.posemb = nn.Parameter(torch.randn(input_size, dim)) - - def forward(self, x): - #print(x.shape) - p = self.patch_size - x = F.unfold(x, p, stride=p, padding=self.padding) # (B, C * p * p, L) - #print('in line 185 after unfoding', x.shape) - x = x.permute(0, 2, 1).contiguous() - ttt = self.prefc(x) - - x = self.prefc(x) + self.posemb[:x.shape[1]].unsqueeze(0) - return x - -class SpatialAttention(nn.Module): - def __init__(self, kernel_size=7): - super(SpatialAttention, self).__init__() - - self.conv1 = Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - avg_out = torch.mean(x, dim=1, keepdim=True) - max_out, _ = torch.max(x, dim=1, keepdim=True) - x = torch.cat([avg_out, max_out], dim=1) - x = self.conv1(x) - return self.sigmoid(x) - -class TimestepBlock_res(nn.Module): - def __init__(self, c, c_timestep, conds=['sca']): - super().__init__() - - self.mapper = Linear(c_timestep, c * 2) - self.conds = conds - for cname in conds: - setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) - - - - - def forward(self, x, t): - #print(x.shape, t.shape, self.conds, 'in line 269') - t = t.chunk(len(self.conds) + 1, dim=1) - a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) - - for i, c in enumerate(self.conds): - ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) - a, b = a + ac, b + bc - return x * (1 + a) + b - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - - -class ScaleNormalize_res(nn.Module): - def __init__(self, c, scale_c, conds=['sca']): - super().__init__() - self.c_r = scale_c - self.mapping = TimestepBlock_res(c, scale_c, conds=conds) - self.t_conds = conds - self.alpha = nn.Conv2d(c, c, kernel_size=1) - self.gamma = nn.Conv2d(c, c, kernel_size=1) - self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) - - - def gen_r_embedding(self, r, max_positions=10000): - r = r * max_positions - half_dim = self.c_r // 2 - emb = math.log(max_positions) / (half_dim - 1) - emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() - emb = r[:, None] * emb[None, :] - emb = torch.cat([emb.sin(), emb.cos()], dim=1) - if self.c_r % 2 == 1: # zero pad - emb = nn.functional.pad(emb, (0, 1), mode='constant') - return emb - def forward(self, x, std_size=24*24): - scale_val = math.sqrt(math.log(x.shape[-2] * x.shape[-1], std_size)) - scale_val = torch.ones(x.shape[0]).to(x.device)*scale_val - scale_val_f = self.gen_r_embedding(scale_val) - for c in self.t_conds: - t_cond = torch.zeros_like(scale_val) - scale_val_f = torch.cat([scale_val_f, self.gen_r_embedding(t_cond)], dim=1) - - f = self.mapping(x, scale_val_f) - - return f + x - - -class TransInr_withnorm(nn.Module): - - def __init__(self, ind=2048, ch=16, n_head=12, head_dim=64, n_groups=64, f_dim=768, time_dim=2048, t_conds=[]): - super().__init__() - self.input_layer= nn.Conv2d(ind, ch, 1) - self.tokenizer = ImgrecTokenizer(dim=ch, img_channels=ch) - #self.hyponet = HypoMlp(depth=12, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128) - #self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=12, n_head=n_head, head_dim=f_dim // n_head, ff_dim=3*f_dim, ) - - self.hyponet = HypoMlp(depth=2, in_dim=2, out_dim=ch, hidden_dim=f_dim, use_pe=True, pe_dim=128) - self.transformer_encoder = TransformerEncoder(dim=f_dim, depth=1, n_head=n_head, head_dim=f_dim // n_head, ff_dim=f_dim) - #self.transformer_encoder = TransInr( ch=ch, n_head=16, head_dim=16, n_groups=64, f_dim=ch, time_dim=time_dim, t_conds = []) - self.base_params = nn.ParameterDict() - n_wtokens = 0 - self.wtoken_postfc = nn.ModuleDict() - self.wtoken_rng = dict() - for name, shape in self.hyponet.param_shapes.items(): - self.base_params[name] = nn.Parameter(init_wb(shape)) - g = min(n_groups, shape[1]) - assert shape[1] % g == 0 - self.wtoken_postfc[name] = nn.Sequential( - nn.LayerNorm(f_dim), - nn.Linear(f_dim, shape[0] - 1), - ) - self.wtoken_rng[name] = (n_wtokens, n_wtokens + g) - n_wtokens += g - self.wtokens = nn.Parameter(torch.randn(n_wtokens, f_dim)) - self.output_layer= nn.Conv2d(ch, ind, 1) - - - self.mapp_t = TimestepBlock_res( ind, time_dim, conds = t_conds) - - - self.hr_norm = ScaleNormalize_res(ind, 64, conds=[]) - - self.normalize_final = nn.Sequential( - LayerNorm2d(ind, elementwise_affine=False, eps=1e-6), - ) - - self.toout = nn.Sequential( - Linear( ind*2, ind // 4), - nn.GELU(), - Linear( ind // 4, ind) - ) - self.apply(self._init_weights) - - mask = torch.zeros((1, 1, 32, 32)) - h, w = 32, 32 - center_h, center_w = h // 2, w // 2 - low_freq_h, low_freq_w = h // 4, w // 4 - mask[:, :, center_h-low_freq_h:center_h+low_freq_h, center_w-low_freq_w:center_w+low_freq_w] = 1 - self.mask = mask - - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - torch.nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - #nn.init.constant_(self.last.weight, 0) - def adain(self, feature_a, feature_b): - norm_mean = torch.mean(feature_a, dim=(2, 3), keepdim=True) - norm_std = torch.std(feature_a, dim=(2, 3), keepdim=True) - #feature_a = F.interpolate(feature_a, feature_b.shape[2:]) - feature_b = (feature_b - feature_b.mean(dim=(2, 3), keepdim=True)) / (1e-8 + feature_b.std(dim=(2, 3), keepdim=True)) * norm_std + norm_mean - return feature_b - def forward(self, target_shape, target, dtokens, t_emb): - #print(target.shape, dtokens.shape, 'in line 290') - hlr, wlr = dtokens.shape[2:] - original = dtokens - - dtokens = self.input_layer(dtokens) - dtokens = self.tokenizer(dtokens) - B = dtokens.shape[0] - wtokens = einops.repeat(self.wtokens, 'n d -> b n d', b=B) - #print(wtokens.shape, dtokens.shape) - trans_out = self.transformer_encoder(torch.cat([dtokens, wtokens], dim=1)) - trans_out = trans_out[:, -len(self.wtokens):, :] - - params = dict() - for name, shape in self.hyponet.param_shapes.items(): - wb = einops.repeat(self.base_params[name], 'n m -> b n m', b=B) - w, b = wb[:, :-1, :], wb[:, -1:, :] - - l, r = self.wtoken_rng[name] - x = self.wtoken_postfc[name](trans_out[:, l: r, :]) - x = x.transpose(-1, -2) # (B, shape[0] - 1, g) - w = F.normalize(w * x.repeat(1, 1, w.shape[2] // x.shape[2]), dim=1) - - wb = torch.cat([w, b], dim=1) - params[name] = wb - coord = make_coord_grid(target_shape[2:], (-1, 1), device=dtokens.device) - coord = einops.repeat(coord, 'h w d -> b h w d', b=dtokens.shape[0]) - self.hyponet.set_params(params) - ori_up = F.interpolate(original.float(), target_shape[2:]) - hr_rec = self.output_layer(rearrange(self.hyponet(coord), 'b h w c -> b c h w')) + ori_up - #print(hr_rec.shape, target.shape, torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1).shape, 'in line 537') - - output = self.toout(torch.cat((hr_rec, target), dim=1).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - #print(output.shape, 'in line 540') - #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)* 0.3 - output = self.mapp_t(output, t_emb) - output = self.normalize_final(output) - output = self.hr_norm(output) - #output = self.last(output.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - #output = self.mapp_t(output, t_emb) - #output = self.weight(output) * output - - return output - - - - - - -class LayerNorm2d(nn.LayerNorm): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x): - return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - -class GlobalResponseNorm(nn.Module): - "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, x): - Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x - - - -if __name__ == '__main__': - #ef __init__(self, ch, n_head, head_dim, n_groups): - trans_inr = TransInr(16, 24, 32, 64).cuda() - input = torch.randn((1, 16, 24, 24)).cuda() - source = torch.randn((1, 16, 16, 16)).cuda() - t = torch.randn((1, 128)).cuda() - output, hr = trans_inr(input, t, source) - - total_up = sum([ param.nelement() for param in trans_inr.parameters()]) - print(output.shape, hr.shape, total_up /1e6 ) - diff --git a/modules/lora.py b/modules/lora.py deleted file mode 100644 index bc0a2bd797f3669a465f6c2c4255b52fe1bda7a7..0000000000000000000000000000000000000000 --- a/modules/lora.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -from torch import nn - - -class LoRA(nn.Module): - def __init__(self, layer, name='weight', rank=16, alpha=1): - super().__init__() - weight = getattr(layer, name) - self.lora_down = nn.Parameter(torch.zeros((rank, weight.size(1)))) - self.lora_up = nn.Parameter(torch.zeros((weight.size(0), rank))) - nn.init.normal_(self.lora_up, mean=0, std=1) - - self.scale = alpha / rank - self.enabled = True - - def forward(self, original_weights): - if self.enabled: - lora_shape = list(original_weights.shape[:2]) + [1] * (len(original_weights.shape) - 2) - lora_weights = torch.matmul(self.lora_up.clone(), self.lora_down.clone()).view(*lora_shape) * self.scale - return original_weights + lora_weights - else: - return original_weights - - -def apply_lora(model, filters=None, rank=16): - def check_parameter(module, name): - return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( - getattr(module, name), nn.Parameter) - - for name, module in model.named_modules(): - if filters is None or any([f in name for f in filters]): - if check_parameter(module, "weight"): - device, dtype = module.weight.device, module.weight.dtype - torch.nn.utils.parametrize.register_parametrization(module, 'weight', LoRA(module, "weight", rank=rank).to(dtype).to(device)) - elif check_parameter(module, "in_proj_weight"): - device, dtype = module.in_proj_weight.device, module.in_proj_weight.dtype - torch.nn.utils.parametrize.register_parametrization(module, 'in_proj_weight', LoRA(module, "in_proj_weight", rank=rank).to(dtype).to(device)) - - -class ReToken(nn.Module): - def __init__(self, indices=None): - super().__init__() - assert indices is not None - self.embeddings = nn.Parameter(torch.zeros(len(indices), 1280)) - self.register_buffer('indices', torch.tensor(indices)) - self.enabled = True - - def forward(self, embeddings): - if self.enabled: - embeddings = embeddings.clone() - for i, idx in enumerate(self.indices): - embeddings[idx] += self.embeddings[i] - return embeddings - - -def apply_retoken(module, indices=None): - def check_parameter(module, name): - return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance( - getattr(module, name), nn.Parameter) - - if check_parameter(module, "weight"): - device, dtype = module.weight.device, module.weight.dtype - torch.nn.utils.parametrize.register_parametrization(module, 'weight', ReToken(indices=indices).to(dtype).to(device)) - - -def remove_lora(model, leave_parametrized=True): - for module in model.modules(): - if torch.nn.utils.parametrize.is_parametrized(module, "weight"): - nn.utils.parametrize.remove_parametrizations(module, "weight", leave_parametrized=leave_parametrized) - elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): - nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight", leave_parametrized=leave_parametrized) diff --git a/modules/model_4stage_lite.py b/modules/model_4stage_lite.py deleted file mode 100644 index e77cc5d73ccda882774f447f5a8bb86fe71fe755..0000000000000000000000000000000000000000 --- a/modules/model_4stage_lite.py +++ /dev/null @@ -1,458 +0,0 @@ -import torch -from torch import nn -import numpy as np -import math -from modules.common_ckpt import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock -from .controlnet import ControlNetDeliverer -import torch.nn.functional as F -from modules.inr_fea_res_lite import TransInr_withnorm as TransInr -from modules.inr_fea_res_lite import ScaleNormalize_res -from einops import rearrange -import torch.fft as fft -import random -class UpDownBlock2d(nn.Module): - def __init__(self, c_in, c_out, mode, enabled=True): - super().__init__() - assert mode in ['up', 'down'] - interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', - align_corners=True) if enabled else nn.Identity() - mapping = nn.Conv2d(c_in, c_out, kernel_size=1) - self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) - - def forward(self, x): - for block in self.blocks: - x = block(x.float()) - return x -def ada_in(a, b): - mean_a = torch.mean(a, dim=(2, 3), keepdim=True) - std_a = torch.std(a, dim=(2, 3), keepdim=True) - - mean_b = torch.mean(b, dim=(2, 3), keepdim=True) - std_b = torch.std(b, dim=(2, 3), keepdim=True) - - return (b - mean_b) / (1e-8 + std_b) * std_a + mean_a -def feature_dist_loss(x1, x2): - mu1 = torch.mean(x1, dim=(2, 3)) - mu2 = torch.mean(x2, dim=(2, 3)) - - std1 = torch.std(x1, dim=(2, 3)) - std2 = torch.std(x2, dim=(2, 3)) - std_loss = torch.mean(torch.abs(torch.log(std1+ 1e-8) - torch.log(std2+ 1e-8))) - mean_loss = torch.mean(torch.abs(mu1 - mu2)) - #print('in line 36', std_loss, mean_loss) - return std_loss + mean_loss*0.1 -class StageC(nn.Module): - def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], - blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], - c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, - dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], - lr_h=24, lr_w=24): - super().__init__() - - self.lr_h, self.lr_w = lr_h, lr_w - self.block_repeat = block_repeat - self.c_in = c_in - self.c_cond = c_cond - self.patch_size = patch_size - self.c_hidden = c_hidden - self.nhead = nhead - self.blocks = blocks - self.level_config = level_config - self.kernel_size = kernel_size - self.c_r = c_r - self.t_conds = t_conds - self.c_clip_seq = c_clip_seq - if not isinstance(dropout, list): - dropout = [dropout] * len(c_hidden) - if not isinstance(self_attn, list): - self_attn = [self_attn] * len(c_hidden) - self.self_attn = self_attn - self.dropout = dropout - self.switch_level = switch_level - # CONDITIONING - self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) - self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) - self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) - self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) - - self.embedding = nn.Sequential( - nn.PixelUnshuffle(patch_size), - nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) - ) - - def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): - if block_type == 'C': - return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) - elif block_type == 'A': - return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) - elif block_type == 'F': - return FeedForwardBlock(c_hidden, dropout=dropout) - elif block_type == 'T': - return TimestepBlock(c_hidden, c_r, conds=t_conds) - else: - raise Exception(f'Block type {block_type} not supported') - - # BLOCKS - # -- down blocks - self.down_blocks = nn.ModuleList() - self.down_downscalers = nn.ModuleList() - self.down_repeat_mappers = nn.ModuleList() - for i in range(len(c_hidden)): - if i > 0: - self.down_downscalers.append(nn.Sequential( - LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), - UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1]) - )) - else: - self.down_downscalers.append(nn.Identity()) - down_block = nn.ModuleList() - for _ in range(blocks[0][i]): - for block_type in level_config[i]: - block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) - down_block.append(block) - self.down_blocks.append(down_block) - if block_repeat is not None: - block_repeat_mappers = nn.ModuleList() - for _ in range(block_repeat[0][i] - 1): - block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) - self.down_repeat_mappers.append(block_repeat_mappers) - - - - #extra down blocks - - - # -- up blocks - self.up_blocks = nn.ModuleList() - self.up_upscalers = nn.ModuleList() - self.up_repeat_mappers = nn.ModuleList() - for i in reversed(range(len(c_hidden))): - if i > 0: - self.up_upscalers.append(nn.Sequential( - LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), - UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1]) - )) - else: - self.up_upscalers.append(nn.Identity()) - up_block = nn.ModuleList() - for j in range(blocks[1][::-1][i]): - for k, block_type in enumerate(level_config[i]): - c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 - block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], - self_attn=self_attn[i]) - up_block.append(block) - self.up_blocks.append(up_block) - if block_repeat is not None: - block_repeat_mappers = nn.ModuleList() - for _ in range(block_repeat[1][::-1][i] - 1): - block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) - self.up_repeat_mappers.append(block_repeat_mappers) - - # OUTPUT - self.clf = nn.Sequential( - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), - nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), - nn.PixelShuffle(patch_size), - ) - - # --- WEIGHT INIT --- - self.apply(self._init_weights) # General init - nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings - nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings - nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings - torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs - nn.init.constant_(self.clf[1].weight, 0) # outputs - - # blocks - for level_block in self.down_blocks + self.up_blocks: - for block in level_block: - if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): - block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) - elif isinstance(block, TimestepBlock): - for layer in block.modules(): - if isinstance(layer, nn.Linear): - nn.init.constant_(layer.weight, 0) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - torch.nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - - def _init_extra_parameter(self): - - - - self.agg_net = nn.ModuleList() - for _ in range(2): - - self.agg_net.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds = [])) # - - self.agg_net_up = nn.ModuleList() - for _ in range(2): - - self.agg_net_up.append(TransInr(ind=2048, ch=1024, n_head=32, head_dim=32, n_groups=64, f_dim=1024, time_dim=self.c_r, t_conds = [])) # - - - - - - self.norm_down_blocks = nn.ModuleList() - for i in range(len(self.c_hidden)): - - up_blocks = nn.ModuleList() - for j in range(self.blocks[0][i]): - if j % 4 == 0: - up_blocks.append( - ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[])) - self.norm_down_blocks.append(up_blocks) - - - self.norm_up_blocks = nn.ModuleList() - for i in reversed(range(len(self.c_hidden))): - - up_block = nn.ModuleList() - for j in range(self.blocks[1][::-1][i]): - if j % 4 == 0: - up_block.append(ScaleNormalize_res(self.c_hidden[0], self.c_r, conds=[])) - self.norm_up_blocks.append(up_block) - - - - - self.agg_net.apply(self._init_weights) - self.agg_net_up.apply(self._init_weights) - self.norm_up_blocks.apply(self._init_weights) - self.norm_down_blocks.apply(self._init_weights) - for block in self.agg_net + self.agg_net_up: - #for block in level_block: - if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): - block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) - elif isinstance(block, TimestepBlock): - for layer in block.modules(): - if isinstance(layer, nn.Linear): - nn.init.constant_(layer.weight, 0) - - - - - - def gen_r_embedding(self, r, max_positions=10000): - r = r * max_positions - half_dim = self.c_r // 2 - emb = math.log(max_positions) / (half_dim - 1) - emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() - emb = r[:, None] * emb[None, :] - emb = torch.cat([emb.sin(), emb.cos()], dim=1) - if self.c_r % 2 == 1: # zero pad - emb = nn.functional.pad(emb, (0, 1), mode='constant') - return emb - - def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): - clip_txt = self.clip_txt_mapper(clip_txt) - if len(clip_txt_pooled.shape) == 2: - clip_txt_pool = clip_txt_pooled.unsqueeze(1) - if len(clip_img.shape) == 2: - clip_img = clip_img.unsqueeze(1) - clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) - clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) - clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) - clip = self.clip_norm(clip) - return clip - - def _down_encode(self, x, r_embed, clip, cnet=None, require_q=False, lr_guide=None, r_emb_lite=None, guide_weight=1): - level_outputs = [] - if require_q: - qs = [] - block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) - for stage_cnt, (down_block, downscaler, repmap) in enumerate(block_group): - x = downscaler(x) - for i in range(len(repmap) + 1): - for inner_cnt, block in enumerate(down_block): - - - if isinstance(block, ResBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - ResBlock)): - if cnet is not None and lr_guide is None: - #if cnet is not None : - next_cnet = cnet() - if next_cnet is not None: - - x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear', - align_corners=True) - x = block(x) - elif isinstance(block, AttnBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - AttnBlock)): - - x = block(x, clip) - if require_q and (inner_cnt == 2 ): - qs.append(x.clone()) - if lr_guide is not None and (inner_cnt == 2 ) : - - guide = self.agg_net[stage_cnt](x.shape, x, lr_guide[stage_cnt], r_emb_lite) - x = x + guide - - elif isinstance(block, TimestepBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - TimestepBlock)): - x = block(x, r_embed) - else: - x = block(x) - if i < len(repmap): - x = repmap[i](x) - level_outputs.insert(0, x) # 0 indicate last output - if require_q: - return level_outputs, qs - return level_outputs - - - def _up_decode(self, level_outputs, r_embed, clip, cnet=None, require_ff=False, agg_f=None, r_emb_lite=None, guide_weight=1): - if require_ff: - agg_feas = [] - x = level_outputs[0] - block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) - for i, (up_block, upscaler, repmap) in enumerate(block_group): - for j in range(len(repmap) + 1): - for k, block in enumerate(up_block): - - if isinstance(block, ResBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - ResBlock)): - skip = level_outputs[i] if k == 0 and i > 0 else None - - - if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): - x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', - align_corners=True) - - if cnet is not None and agg_f is None: - next_cnet = cnet() - if next_cnet is not None: - - x = x + nn.functional.interpolate(next_cnet.float(), size=x.shape[-2:], mode='bilinear', - align_corners=True) - - - x = block(x, skip) - elif isinstance(block, AttnBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - AttnBlock)): - - - x = block(x, clip) - if require_ff and (k == 2 ): - agg_feas.append(x.clone()) - if agg_f is not None and (k == 2 ) : - - guide = self.agg_net_up[i](x.shape, x, agg_f[i], r_emb_lite) # training 1 test 4k 0.8 2k 0.7 - if not self.training: - hw = x.shape[-2] * x.shape[-1] - if hw >= 96*96: - guide = 0.7*guide - - else: - - if hw >= 72*72: - guide = 0.5* guide - else: - - guide = 0.3* guide - - x = x + guide - - - elif isinstance(block, TimestepBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - TimestepBlock)): - x = block(x, r_embed) - #if require_ff: - # agg_feas.append(x.clone()) - else: - x = block(x) - if j < len(repmap): - x = repmap[j](x) - x = upscaler(x) - - - if require_ff: - return x, agg_feas - - return x - - - - - def forward(self, x, r, clip_text, clip_text_pooled, clip_img, lr_guide=None, reuire_f=False, cnet=None, require_t=False, guide_weight=0.5, **kwargs): - - r_embed = self.gen_r_embedding(r) - - for c in self.t_conds: - t_cond = kwargs.get(c, torch.zeros_like(r)) - r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) - clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) - - # Model Blocks - - x = self.embedding(x) - - - - if cnet is not None: - cnet = ControlNetDeliverer(cnet) - - if not reuire_f: - level_outputs = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, \ - require_q=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight) - x = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, \ - require_ff=reuire_f, r_emb_lite=self.gen_r_embedding(r), guide_weight=guide_weight) - else: - level_outputs, lr_enc = self._down_encode(x, r_embed, clip, cnet, lr_guide= lr_guide[0] if lr_guide is not None else None, require_q=True) - x, lr_dec = self._up_decode(level_outputs, r_embed, clip, cnet, agg_f=lr_guide[1] if lr_guide is not None else None, require_ff=True) - - if reuire_f and require_t: - return self.clf(x), r_embed, lr_enc, lr_dec - if reuire_f: - return self.clf(x), lr_enc, lr_dec - if require_t: - return self.clf(x), r_embed - return self.clf(x) - - - def update_weights_ema(self, src_model, beta=0.999): - for self_params, src_params in zip(self.parameters(), src_model.parameters()): - self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) - for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): - self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) - - - -if __name__ == '__main__': - generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - total_ori = sum([ param.nelement() for param in generator.parameters()]) - generator._init_extra_parameter() - generator = generator.cuda() - total = sum([ param.nelement() for param in generator.parameters()]) - total_down = sum([ param.nelement() for param in generator.down_blocks.parameters()]) - - total_up = sum([ param.nelement() for param in generator.up_blocks.parameters()]) - total_pro = sum([ param.nelement() for param in generator.project.parameters()]) - - - print(total_ori / 1e6, total / 1e6, total_up / 1e6, total_down / 1e6, total_pro / 1e6) - - # for name, module in generator.down_blocks.named_modules(): - # print(name, module) - output, out_lr = generator( - x=torch.randn(1, 16, 24, 24).cuda(), - x_lr=torch.randn(1, 16, 16, 16).cuda(), - r=torch.tensor([0.7056]).cuda(), - clip_text=torch.randn(1, 77, 1280).cuda(), - clip_text_pooled = torch.randn(1, 1, 1280).cuda(), - clip_img = torch.randn(1, 1, 768).cuda() - ) - print(output.shape, out_lr.shape) - # cnt diff --git a/modules/previewer.py b/modules/previewer.py deleted file mode 100644 index 51ab24292d8ac0da8d24b17d8fc0ac9e1419a3d7..0000000000000000000000000000000000000000 --- a/modules/previewer.py +++ /dev/null @@ -1,45 +0,0 @@ -from torch import nn - - -# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 -class Previewer(nn.Module): - def __init__(self, c_in=16, c_hidden=512, c_out=3): - super().__init__() - self.blocks = nn.Sequential( - nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels - nn.GELU(), - nn.BatchNorm2d(c_hidden), - - nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), - nn.GELU(), - nn.BatchNorm2d(c_hidden), - - nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 - nn.GELU(), - nn.BatchNorm2d(c_hidden // 2), - - nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), - nn.GELU(), - nn.BatchNorm2d(c_hidden // 2), - - nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 - nn.GELU(), - nn.BatchNorm2d(c_hidden // 4), - - nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), - nn.GELU(), - nn.BatchNorm2d(c_hidden // 4), - - nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 - nn.GELU(), - nn.BatchNorm2d(c_hidden // 4), - - nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), - nn.GELU(), - nn.BatchNorm2d(c_hidden // 4), - - nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), - ) - - def forward(self, x): - return self.blocks(x) diff --git a/modules/resnet.py b/modules/resnet.py deleted file mode 100644 index 460a808942be147d76b8b1f3baf29fec1e2a7b8d..0000000000000000000000000000000000000000 --- a/modules/resnet.py +++ /dev/null @@ -1,415 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -#import fvcore.nn.weight_init as weight_init - -""" -Functions for building the BottleneckBlock from Detectron2. -# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py -""" - -def get_norm(norm, out_channels, num_norm_groups=32): - """ - Args: - norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; - or a callable that takes a channel number and returns - the normalization layer as a nn.Module. - Returns: - nn.Module or None: the normalization layer - """ - if norm is None: - return None - if isinstance(norm, str): - if len(norm) == 0: - return None - norm = { - "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels), - }[norm] - return norm(out_channels) - -class Conv2d(nn.Conv2d): - """ - A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. - """ - - def __init__(self, *args, **kwargs): - """ - Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: - Args: - norm (nn.Module, optional): a normalization layer - activation (callable(Tensor) -> Tensor): a callable activation function - It assumes that norm layer is used before activation. - """ - norm = kwargs.pop("norm", None) - activation = kwargs.pop("activation", None) - super().__init__(*args, **kwargs) - - self.norm = norm - self.activation = activation - - def forward(self, x): - x = F.conv2d( - x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups - ) - if self.norm is not None: - x = self.norm(x) - if self.activation is not None: - x = self.activation(x) - return x - -class CNNBlockBase(nn.Module): - """ - A CNN block is assumed to have input channels, output channels and a stride. - The input and output of `forward()` method must be NCHW tensors. - The method can perform arbitrary computation but must match the given - channels and stride specification. - Attribute: - in_channels (int): - out_channels (int): - stride (int): - """ - - def __init__(self, in_channels, out_channels, stride): - """ - The `__init__` method of any subclass should also contain these arguments. - Args: - in_channels (int): - out_channels (int): - stride (int): - """ - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.stride = stride - -class BottleneckBlock(CNNBlockBase): - """ - The standard bottleneck residual block used by ResNet-50, 101 and 152 - defined in :paper:`ResNet`. It contains 3 conv layers with kernels - 1x1, 3x3, 1x1, and a projection shortcut if needed. - """ - - def __init__( - self, - in_channels, - out_channels, - *, - bottleneck_channels, - stride=1, - num_groups=1, - norm="GN", - stride_in_1x1=False, - dilation=1, - num_norm_groups=32 - ): - """ - Args: - bottleneck_channels (int): number of output channels for the 3x3 - "bottleneck" conv layers. - num_groups (int): number of groups for the 3x3 conv layer. - norm (str or callable): normalization for all conv layers. - See :func:`layers.get_norm` for supported format. - stride_in_1x1 (bool): when stride>1, whether to put stride in the - first 1x1 convolution or the bottleneck 3x3 convolution. - dilation (int): the dilation rate of the 3x3 conv layer. - """ - super().__init__(in_channels, out_channels, stride) - - if in_channels != out_channels: - self.shortcut = Conv2d( - in_channels, - out_channels, - kernel_size=1, - stride=stride, - bias=False, - norm=get_norm(norm, out_channels, num_norm_groups), - ) - else: - self.shortcut = None - - # The original MSRA ResNet models have stride in the first 1x1 conv - # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have - # stride in the 3x3 conv - stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) - - self.conv1 = Conv2d( - in_channels, - bottleneck_channels, - kernel_size=1, - stride=stride_1x1, - bias=False, - norm=get_norm(norm, bottleneck_channels, num_norm_groups), - ) - - self.conv2 = Conv2d( - bottleneck_channels, - bottleneck_channels, - kernel_size=3, - stride=stride_3x3, - padding=1 * dilation, - bias=False, - groups=num_groups, - dilation=dilation, - norm=get_norm(norm, bottleneck_channels, num_norm_groups), - ) - - self.conv3 = Conv2d( - bottleneck_channels, - out_channels, - kernel_size=1, - bias=False, - norm=get_norm(norm, out_channels, num_norm_groups), - ) - - #for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: - # if layer is not None: # shortcut can be None - # weight_init.c2_msra_fill(layer) - - # Zero-initialize the last normalization in each residual branch, - # so that at the beginning, the residual branch starts with zeros, - # and each residual block behaves like an identity. - # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": - # "For BN layers, the learnable scaling coefficient ¦Ã is initialized - # to be 1, except for each residual block's last BN - # where ¦Ã is initialized to be 0." - - # nn.init.constant_(self.conv3.norm.weight, 0) - # TODO this somehow hurts performance when training GN models from scratch. - # Add it as an option when we need to use this code to train a backbone. - - def forward(self, x): - out = self.conv1(x) - out = F.relu_(out) - - out = self.conv2(out) - out = F.relu_(out) - - out = self.conv3(out) - - if self.shortcut is not None: - shortcut = self.shortcut(x) - else: - shortcut = x - - out += shortcut - out = F.relu_(out) - return out - -class ResNet(nn.Module): - """ - Implement :paper:`ResNet`. - """ - - def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0): - """ - Args: - stem (nn.Module): a stem module - stages (list[list[CNNBlockBase]]): several (typically 4) stages, - each contains multiple :class:`CNNBlockBase`. - num_classes (None or int): if None, will not perform classification. - Otherwise, will create a linear layer. - out_features (list[str]): name of the layers whose outputs should - be returned in forward. Can be anything in "stem", "linear", or "res2" ... - If None, will return the output of the last layer. - freeze_at (int): The number of stages at the beginning to freeze. - see :meth:`freeze` for detailed explanation. - """ - super().__init__() - self.stem = stem - self.num_classes = num_classes - - current_stride = self.stem.stride - self._out_feature_strides = {"stem": current_stride} - self._out_feature_channels = {"stem": self.stem.out_channels} - - self.stage_names, self.stages = [], [] - - if out_features is not None: - # Avoid keeping unused layers in this module. They consume extra memory - # and may cause allreduce to fail - num_stages = max( - [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features] - ) - stages = stages[:num_stages] - for i, blocks in enumerate(stages): - assert len(blocks) > 0, len(blocks) - for block in blocks: - assert isinstance(block, CNNBlockBase), block - - name = "res" + str(i + 2) - stage = nn.Sequential(*blocks) - - self.add_module(name, stage) - self.stage_names.append(name) - self.stages.append(stage) - - self._out_feature_strides[name] = current_stride = int( - current_stride * np.prod([k.stride for k in blocks]) - ) - self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels - self.stage_names = tuple(self.stage_names) # Make it static for scripting - - if num_classes is not None: - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.linear = nn.Linear(curr_channels, num_classes) - - # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": - # "The 1000-way fully-connected layer is initialized by - # drawing weights from a zero-mean Gaussian with standard deviation of 0.01." - nn.init.normal_(self.linear.weight, std=0.01) - name = "linear" - - if out_features is None: - out_features = [name] - self._out_features = out_features - assert len(self._out_features) - children = [x[0] for x in self.named_children()] - for out_feature in self._out_features: - assert out_feature in children, "Available children: {}".format(", ".join(children)) - self.freeze(freeze_at) - - def forward(self, x): - """ - Args: - x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. - Returns: - dict[str->Tensor]: names and the corresponding features - """ - assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" - outputs = {} - x = self.stem(x) - if "stem" in self._out_features: - outputs["stem"] = x - for name, stage in zip(self.stage_names, self.stages): - x = stage(x) - if name in self._out_features: - outputs[name] = x - if self.num_classes is not None: - x = self.avgpool(x) - x = torch.flatten(x, 1) - x = self.linear(x) - if "linear" in self._out_features: - outputs["linear"] = x - return outputs - - def freeze(self, freeze_at=0): - """ - Freeze the first several stages of the ResNet. Commonly used in - fine-tuning. - Layers that produce the same feature map spatial size are defined as one - "stage" by :paper:`FPN`. - Args: - freeze_at (int): number of stages to freeze. - `1` means freezing the stem. `2` means freezing the stem and - one residual stage, etc. - Returns: - nn.Module: this ResNet itself - """ - if freeze_at >= 1: - self.stem.freeze() - for idx, stage in enumerate(self.stages, start=2): - if freeze_at >= idx: - for block in stage.children(): - block.freeze() - return self - - @staticmethod - def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs): - """ - Create a list of blocks of the same type that forms one ResNet stage. - Args: - block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this - stage. A module of this type must not change spatial resolution of inputs unless its - stride != 1. - num_blocks (int): number of blocks in this stage - in_channels (int): input channels of the entire stage. - out_channels (int): output channels of **every block** in the stage. - kwargs: other arguments passed to the constructor of - `block_class`. If the argument name is "xx_per_block", the - argument is a list of values to be passed to each block in the - stage. Otherwise, the same argument is passed to every block - in the stage. - Returns: - list[CNNBlockBase]: a list of block module. - Examples: - :: - stage = ResNet.make_stage( - BottleneckBlock, 3, in_channels=16, out_channels=64, - bottleneck_channels=16, num_groups=1, - stride_per_block=[2, 1, 1], - dilations_per_block=[1, 1, 2] - ) - Usually, layers that produce the same feature map spatial size are defined as one - "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should - all be 1. - """ - blocks = [] - for i in range(num_blocks): - curr_kwargs = {} - for k, v in kwargs.items(): - if k.endswith("_per_block"): - assert len(v) == num_blocks, ( - f"Argument '{k}' of make_stage should have the " - f"same length as num_blocks={num_blocks}." - ) - newk = k[: -len("_per_block")] - assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!" - curr_kwargs[newk] = v[i] - else: - curr_kwargs[k] = v - - blocks.append( - block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs) - ) - in_channels = out_channels - return blocks - - @staticmethod - def make_default_stages(depth, block_class=None, **kwargs): - """ - Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152). - If it doesn't create the ResNet variant you need, please use :meth:`make_stage` - instead for fine-grained customization. - Args: - depth (int): depth of ResNet - block_class (type): the CNN block class. Has to accept - `bottleneck_channels` argument for depth > 50. - By default it is BasicBlock or BottleneckBlock, based on the - depth. - kwargs: - other arguments to pass to `make_stage`. Should not contain - stride and channels, as they are predefined for each depth. - Returns: - list[list[CNNBlockBase]]: modules in all stages; see arguments of - :class:`ResNet.__init__`. - """ - num_blocks_per_stage = { - 18: [2, 2, 2, 2], - 34: [3, 4, 6, 3], - 50: [3, 4, 6, 3], - 101: [3, 4, 23, 3], - 152: [3, 8, 36, 3], - }[depth] - if block_class is None: - block_class = BasicBlock if depth < 50 else BottleneckBlock - if depth < 50: - in_channels = [64, 64, 128, 256] - out_channels = [64, 128, 256, 512] - else: - in_channels = [64, 256, 512, 1024] - out_channels = [256, 512, 1024, 2048] - ret = [] - for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels): - if depth >= 50: - kwargs["bottleneck_channels"] = o // 4 - ret.append( - ResNet.make_stage( - block_class=block_class, - num_blocks=n, - stride_per_block=[s] + [1] * (n - 1), - in_channels=i, - out_channels=o, - **kwargs, - ) - ) - return ret \ No newline at end of file diff --git a/modules/speed_util.py b/modules/speed_util.py deleted file mode 100644 index 3b9507c74833bec270b00bd252a3c76fcc09fab3..0000000000000000000000000000000000000000 --- a/modules/speed_util.py +++ /dev/null @@ -1,55 +0,0 @@ -import os -import math -import torch -import torch.nn as nn -import numpy as np -from einops import repeat -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled()} - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(), \ - torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) \ No newline at end of file diff --git a/modules/stage_a.py b/modules/stage_a.py deleted file mode 100644 index 2840ef71d30e3da74954ab4a05e724fd7fef86cf..0000000000000000000000000000000000000000 --- a/modules/stage_a.py +++ /dev/null @@ -1,183 +0,0 @@ -import torch -from torch import nn -from torchtools.nn import VectorQuantize -from einops import rearrange -import torch.nn.functional as F -import math -class ResBlock(nn.Module): - def __init__(self, c, c_hidden): - super().__init__() - # depthwise/attention - self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) - self.depthwise = nn.Sequential( - nn.ReplicationPad2d(1), - nn.Conv2d(c, c, kernel_size=3, groups=c) - ) - - # channelwise - self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) - self.channelwise = nn.Sequential( - nn.Linear(c, c_hidden), - nn.GELU(), - nn.Linear(c_hidden, c), - ) - - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) - - def _norm(self, x, norm): - return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - def forward(self, x): - - mods = self.gammas - - x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] - - #x = x.to(torch.float64) - x = x + self.depthwise(x_temp) * mods[2] - - x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] - x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] - - return x - - -def extract_patches(tensor, patch_size, stride): - b, c, H, W = tensor.shape - pad_h = (patch_size - (H - patch_size) % stride) % stride - pad_w = (patch_size - (W - patch_size) % stride) % stride - tensor = F.pad(tensor, (0, pad_w, 0, pad_h), mode='reflect') - - - patches = tensor.unfold(2, patch_size, stride).unfold(3, patch_size, stride) - patches = patches.contiguous().view(b, c, -1, patch_size, patch_size) - patches = patches.permute(0, 2, 1, 3, 4) - return patches, (H, W) - -def fuse_patches(patches, patch_size, stride, H, W): - - b, num_patches, c, _, _ = patches.shape - patches = patches.permute(0, 2, 1, 3, 4) - - - - pad_h = (patch_size - (H - patch_size) % stride) % stride - pad_w = (patch_size - (W - patch_size) % stride) % stride - out_h = H + pad_h - out_w = W + pad_w - patches = patches.contiguous().view(b, c , -1, patch_size*patch_size ).permute(0, 1, 3, 2) - patches = patches.contiguous().view(b, c*patch_size*patch_size, -1) - - tensor = F.fold(patches, output_size=(out_h, out_w), kernel_size=patch_size, stride=stride) - overlap_cnt = F.fold(torch.ones_like(patches), output_size=(out_h, out_w), kernel_size=patch_size, stride=stride) - tensor = tensor / overlap_cnt - print('end fuse patch', tensor.shape, (tensor.dtype)) - return tensor[:, :, :H, :W] - - - -class StageA(nn.Module): - def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, - scale_factor=0.43): # 0.3764 - super().__init__() - self.c_latent = c_latent - self.scale_factor = scale_factor - c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] - - # Encoder blocks - self.in_block = nn.Sequential( - nn.PixelUnshuffle(2), - nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) - ) - down_blocks = [] - for i in range(levels): - if i > 0: - down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) - block = ResBlock(c_levels[i], c_levels[i] * 4) - down_blocks.append(block) - down_blocks.append(nn.Sequential( - nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), - nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 - )) - self.down_blocks = nn.Sequential(*down_blocks) - self.down_blocks[0] - - self.codebook_size = codebook_size - self.vquantizer = VectorQuantize(c_latent, k=codebook_size) - - # Decoder blocks - up_blocks = [nn.Sequential( - nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) - )] - for i in range(levels): - for j in range(bottleneck_blocks if i == 0 else 1): - block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) - up_blocks.append(block) - if i < levels - 1: - up_blocks.append( - nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, - padding=1)) - self.up_blocks = nn.Sequential(*up_blocks) - self.out_block = nn.Sequential( - nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), - nn.PixelShuffle(2), - ) - - def encode(self, x, quantize=False): - x = self.in_block(x) - x = self.down_blocks(x) - if quantize: - qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) - return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 - else: - return x / self.scale_factor, None, None, None - - - - def decode(self, x, tiled_decoding=False): - x = x * self.scale_factor - x = self.up_blocks(x) - x = self.out_block(x) - return x - - def forward(self, x, quantize=False): - qe, x, _, vq_loss = self.encode(x, quantize) - x = self.decode(qe) - return x, vq_loss - - -class Discriminator(nn.Module): - def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): - super().__init__() - d = max(depth - 3, 3) - layers = [ - nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), - nn.LeakyReLU(0.2), - ] - for i in range(depth - 1): - c_in = c_hidden // (2 ** max((d - i), 0)) - c_out = c_hidden // (2 ** max((d - 1 - i), 0)) - layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) - layers.append(nn.InstanceNorm2d(c_out)) - layers.append(nn.LeakyReLU(0.2)) - self.encoder = nn.Sequential(*layers) - self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) - self.logits = nn.Sigmoid() - - def forward(self, x, cond=None): - x = self.encoder(x) - if cond is not None: - cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) - x = torch.cat([x, cond], dim=1) - x = self.shuffle(x) - x = self.logits(x) - return x diff --git a/modules/stage_b.py b/modules/stage_b.py deleted file mode 100644 index f89b42d61327278820e164b1c093cbf8d1048ee1..0000000000000000000000000000000000000000 --- a/modules/stage_b.py +++ /dev/null @@ -1,239 +0,0 @@ -import math -import numpy as np -import torch -from torch import nn -from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock - - -class StageB(nn.Module): - def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], - nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], - block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, - c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.1, 0.1], self_attn=True, - t_conds=['sca']): - super().__init__() - self.c_r = c_r - self.t_conds = t_conds - self.c_clip_seq = c_clip_seq - if not isinstance(dropout, list): - dropout = [dropout] * len(c_hidden) - if not isinstance(self_attn, list): - self_attn = [self_attn] * len(c_hidden) - - # CONDITIONING - self.effnet_mapper = nn.Sequential( - nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1), - nn.GELU(), - nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) - ) - self.pixels_mapper = nn.Sequential( - nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1), - nn.GELU(), - nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) - ) - self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq) - self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) - - self.embedding = nn.Sequential( - nn.PixelUnshuffle(patch_size), - nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) - ) - - def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): - if block_type == 'C': - return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) - elif block_type == 'A': - return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) - elif block_type == 'F': - return FeedForwardBlock(c_hidden, dropout=dropout) - elif block_type == 'T': - return TimestepBlock(c_hidden, c_r, conds=t_conds) - else: - raise Exception(f'Block type {block_type} not supported') - - # BLOCKS - # -- down blocks - self.down_blocks = nn.ModuleList() - self.down_downscalers = nn.ModuleList() - self.down_repeat_mappers = nn.ModuleList() - for i in range(len(c_hidden)): - if i > 0: - self.down_downscalers.append(nn.Sequential( - LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), - nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), - )) - else: - self.down_downscalers.append(nn.Identity()) - down_block = nn.ModuleList() - for _ in range(blocks[0][i]): - for block_type in level_config[i]: - block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) - down_block.append(block) - self.down_blocks.append(down_block) - if block_repeat is not None: - block_repeat_mappers = nn.ModuleList() - for _ in range(block_repeat[0][i] - 1): - block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) - self.down_repeat_mappers.append(block_repeat_mappers) - - # -- up blocks - self.up_blocks = nn.ModuleList() - self.up_upscalers = nn.ModuleList() - self.up_repeat_mappers = nn.ModuleList() - for i in reversed(range(len(c_hidden))): - if i > 0: - self.up_upscalers.append(nn.Sequential( - LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), - nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), - )) - else: - self.up_upscalers.append(nn.Identity()) - up_block = nn.ModuleList() - for j in range(blocks[1][::-1][i]): - for k, block_type in enumerate(level_config[i]): - c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 - block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], - self_attn=self_attn[i]) - up_block.append(block) - self.up_blocks.append(up_block) - if block_repeat is not None: - block_repeat_mappers = nn.ModuleList() - for _ in range(block_repeat[1][::-1][i] - 1): - block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) - self.up_repeat_mappers.append(block_repeat_mappers) - - # OUTPUT - self.clf = nn.Sequential( - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), - nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), - nn.PixelShuffle(patch_size), - ) - - # --- WEIGHT INIT --- - self.apply(self._init_weights) # General init - nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings - nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings - nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings - nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings - nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings - torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs - nn.init.constant_(self.clf[1].weight, 0) # outputs - - # blocks - for level_block in self.down_blocks + self.up_blocks: - for block in level_block: - if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): - block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) - elif isinstance(block, TimestepBlock): - for layer in block.modules(): - if isinstance(layer, nn.Linear): - nn.init.constant_(layer.weight, 0) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - torch.nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def gen_r_embedding(self, r, max_positions=10000): - r = r * max_positions - half_dim = self.c_r // 2 - emb = math.log(max_positions) / (half_dim - 1) - emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() - emb = r[:, None] * emb[None, :] - emb = torch.cat([emb.sin(), emb.cos()], dim=1) - if self.c_r % 2 == 1: # zero pad - emb = nn.functional.pad(emb, (0, 1), mode='constant') - return emb - - def gen_c_embeddings(self, clip): - if len(clip.shape) == 2: - clip = clip.unsqueeze(1) - clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) - clip = self.clip_norm(clip) - return clip - - def _down_encode(self, x, r_embed, clip): - level_outputs = [] - block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) - for down_block, downscaler, repmap in block_group: - x = downscaler(x) - for i in range(len(repmap) + 1): - for block in down_block: - if isinstance(block, ResBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - ResBlock)): - x = block(x) - elif isinstance(block, AttnBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - AttnBlock)): - x = block(x, clip) - elif isinstance(block, TimestepBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - TimestepBlock)): - x = block(x, r_embed) - else: - x = block(x) - if i < len(repmap): - x = repmap[i](x) - level_outputs.insert(0, x) - return level_outputs - - def _up_decode(self, level_outputs, r_embed, clip): - x = level_outputs[0] - block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) - for i, (up_block, upscaler, repmap) in enumerate(block_group): - for j in range(len(repmap) + 1): - for k, block in enumerate(up_block): - if isinstance(block, ResBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - ResBlock)): - skip = level_outputs[i] if k == 0 and i > 0 else None - if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): - x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', - align_corners=True) - x = block(x, skip) - elif isinstance(block, AttnBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - AttnBlock)): - x = block(x, clip) - elif isinstance(block, TimestepBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - TimestepBlock)): - x = block(x, r_embed) - else: - x = block(x) - if j < len(repmap): - x = repmap[j](x) - x = upscaler(x) - return x - - def forward(self, x, r, effnet, clip, pixels=None, **kwargs): - if pixels is None: - pixels = x.new_zeros(x.size(0), 3, 8, 8) - - # Process the conditioning embeddings - r_embed = self.gen_r_embedding(r) - for c in self.t_conds: - t_cond = kwargs.get(c, torch.zeros_like(r)) - r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) - clip = self.gen_c_embeddings(clip) - - # Model Blocks - x = self.embedding(x) - x = x + self.effnet_mapper( - nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode='bilinear', align_corners=True)) - x = x + nn.functional.interpolate(self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode='bilinear', - align_corners=True) - level_outputs = self._down_encode(x, r_embed, clip) - x = self._up_decode(level_outputs, r_embed, clip) - return self.clf(x) - - def update_weights_ema(self, src_model, beta=0.999): - for self_params, src_params in zip(self.parameters(), src_model.parameters()): - self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) - for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): - self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/modules/stage_c.py b/modules/stage_c.py deleted file mode 100644 index 53b73d0197712b981ec1a154428c21af2149646a..0000000000000000000000000000000000000000 --- a/modules/stage_c.py +++ /dev/null @@ -1,252 +0,0 @@ -import torch -from torch import nn -import numpy as np -import math -from .common import AttnBlock, LayerNorm2d, ResBlock, FeedForwardBlock, TimestepBlock -#from .controlnet import ControlNetDeliverer - - -class UpDownBlock2d(nn.Module): - def __init__(self, c_in, c_out, mode, enabled=True): - super().__init__() - assert mode in ['up', 'down'] - interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', - align_corners=True) if enabled else nn.Identity() - mapping = nn.Conv2d(c_in, c_out, kernel_size=1) - self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) - - def forward(self, x): - for block in self.blocks: - x = block(x.float()) - return x - - -class StageC(nn.Module): - def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], - blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], - c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, - dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False]): - super().__init__() - self.c_r = c_r - self.t_conds = t_conds - self.c_clip_seq = c_clip_seq - if not isinstance(dropout, list): - dropout = [dropout] * len(c_hidden) - if not isinstance(self_attn, list): - self_attn = [self_attn] * len(c_hidden) - - # CONDITIONING - self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) - self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) - self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) - self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) - - self.embedding = nn.Sequential( - nn.PixelUnshuffle(patch_size), - nn.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1), - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6) - ) - - def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): - if block_type == 'C': - return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) - elif block_type == 'A': - return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) - elif block_type == 'F': - return FeedForwardBlock(c_hidden, dropout=dropout) - elif block_type == 'T': - return TimestepBlock(c_hidden, c_r, conds=t_conds) - else: - raise Exception(f'Block type {block_type} not supported') - - # BLOCKS - # -- down blocks - self.down_blocks = nn.ModuleList() - self.down_downscalers = nn.ModuleList() - self.down_repeat_mappers = nn.ModuleList() - for i in range(len(c_hidden)): - if i > 0: - self.down_downscalers.append(nn.Sequential( - LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), - UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1]) - )) - else: - self.down_downscalers.append(nn.Identity()) - down_block = nn.ModuleList() - for _ in range(blocks[0][i]): - for block_type in level_config[i]: - block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) - down_block.append(block) - self.down_blocks.append(down_block) - if block_repeat is not None: - block_repeat_mappers = nn.ModuleList() - for _ in range(block_repeat[0][i] - 1): - block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) - self.down_repeat_mappers.append(block_repeat_mappers) - - # -- up blocks - self.up_blocks = nn.ModuleList() - self.up_upscalers = nn.ModuleList() - self.up_repeat_mappers = nn.ModuleList() - for i in reversed(range(len(c_hidden))): - if i > 0: - self.up_upscalers.append(nn.Sequential( - LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), - UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1]) - )) - else: - self.up_upscalers.append(nn.Identity()) - up_block = nn.ModuleList() - for j in range(blocks[1][::-1][i]): - for k, block_type in enumerate(level_config[i]): - c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 - block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], - self_attn=self_attn[i]) - up_block.append(block) - self.up_blocks.append(up_block) - if block_repeat is not None: - block_repeat_mappers = nn.ModuleList() - for _ in range(block_repeat[1][::-1][i] - 1): - block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) - self.up_repeat_mappers.append(block_repeat_mappers) - - # OUTPUT - self.clf = nn.Sequential( - LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), - nn.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1), - nn.PixelShuffle(patch_size), - ) - - # --- WEIGHT INIT --- - self.apply(self._init_weights) # General init - nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings - nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings - nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings - torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs - nn.init.constant_(self.clf[1].weight, 0) # outputs - - # blocks - for level_block in self.down_blocks + self.up_blocks: - for block in level_block: - if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): - block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) - elif isinstance(block, TimestepBlock): - for layer in block.modules(): - if isinstance(layer, nn.Linear): - nn.init.constant_(layer.weight, 0) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - torch.nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def gen_r_embedding(self, r, max_positions=10000): - r = r * max_positions - half_dim = self.c_r // 2 - emb = math.log(max_positions) / (half_dim - 1) - emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() - emb = r[:, None] * emb[None, :] - emb = torch.cat([emb.sin(), emb.cos()], dim=1) - if self.c_r % 2 == 1: # zero pad - emb = nn.functional.pad(emb, (0, 1), mode='constant') - return emb - - def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): - clip_txt = self.clip_txt_mapper(clip_txt) - if len(clip_txt_pooled.shape) == 2: - clip_txt_pool = clip_txt_pooled.unsqueeze(1) - if len(clip_img.shape) == 2: - clip_img = clip_img.unsqueeze(1) - clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) - clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) - clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) - clip = self.clip_norm(clip) - return clip - - def _down_encode(self, x, r_embed, clip, cnet=None): - level_outputs = [] - block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) - for down_block, downscaler, repmap in block_group: - x = downscaler(x) - for i in range(len(repmap) + 1): - for block in down_block: - if isinstance(block, ResBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - ResBlock)): - if cnet is not None: - next_cnet = cnet() - if next_cnet is not None: - x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', - align_corners=True) - x = block(x) - elif isinstance(block, AttnBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - AttnBlock)): - x = block(x, clip) - elif isinstance(block, TimestepBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - TimestepBlock)): - x = block(x, r_embed) - else: - x = block(x) - if i < len(repmap): - x = repmap[i](x) - level_outputs.insert(0, x) - return level_outputs - - def _up_decode(self, level_outputs, r_embed, clip, cnet=None): - x = level_outputs[0] - block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) - for i, (up_block, upscaler, repmap) in enumerate(block_group): - for j in range(len(repmap) + 1): - for k, block in enumerate(up_block): - if isinstance(block, ResBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - ResBlock)): - skip = level_outputs[i] if k == 0 and i > 0 else None - if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): - x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode='bilinear', - align_corners=True) - if cnet is not None: - next_cnet = cnet() - if next_cnet is not None: - x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', - align_corners=True) - x = block(x, skip) - elif isinstance(block, AttnBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - AttnBlock)): - x = block(x, clip) - elif isinstance(block, TimestepBlock) or ( - hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, - TimestepBlock)): - x = block(x, r_embed) - else: - x = block(x) - if j < len(repmap): - x = repmap[j](x) - x = upscaler(x) - return x - - def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): - # Process the conditioning embeddings - r_embed = self.gen_r_embedding(r) - for c in self.t_conds: - t_cond = kwargs.get(c, torch.zeros_like(r)) - r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) - clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) - - # Model Blocks - x = self.embedding(x) - if cnet is not None: - cnet = ControlNetDeliverer(cnet) - level_outputs = self._down_encode(x, r_embed, clip, cnet) - x = self._up_decode(level_outputs, r_embed, clip, cnet) - return self.clf(x) - - def update_weights_ema(self, src_model, beta=0.999): - for self_params, src_params in zip(self.parameters(), src_model.parameters()): - self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) - for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): - self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/prompt_list.txt b/prompt_list.txt deleted file mode 100644 index 27cd31b4750d2f15fdb6f2a3f4bdd117a7377267..0000000000000000000000000000000000000000 --- a/prompt_list.txt +++ /dev/null @@ -1,32 +0,0 @@ -A close-up of a blooming peony, with layers of soft, pink petals, a delicate fragrance, and dewdrops glistening -in the early morning light. - -A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a -clear blue sky. - -A close-up portrait of a young woman with flawless skin, vibrant red lipstick, and wavy brown hair, wearing -a vintage floral dress and standing in front of a blooming garden. - -The image features a snow-covered mountain range with a large, snow-covered mountain in the background. -The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the -winter season, with snow covering the ground and the trees. - -Crocodile in a sweater. - -A vibrant anime scene of a young girl with long, flowing pink hair, big sparkling blue eyes, and a school -uniform, standing under a cherry blossom tree with petals falling around her. The background shows a -traditional Japanese school with cherry blossoms in full bloom. - -A playful Labrador retriever puppy with a shiny, golden coat, chasing a red ball in a spacious backyard, with -green grass and a wooden fence. - -A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm -lights glowing from the windows, and a path of footprints leading to the front door. - -A highly detailed, high-quality image of the Banff National Park in Canada. The turquoise waters of Lake -Louise are surrounded by snow-capped mountains and dense pine forests. A wooden canoe is docked at the -edge of the lake. The sky is a clear, bright blue, and the air is crisp and fresh. - -A highly detailed, high-quality image of a Shih Tzu receiving a bath in a home bathroom. The dog is standing -in a tub, covered in suds, with a slightly wet and adorable look. The background includes bathroom fixtures, -towels, and a clean, tiled floor. \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 431ebbcb0dc492a0b05801e3f9d0f96efdd27245..1270d9d1c13425922f21302c1724fcd0e133a8b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,3 @@ ---find-links https://download.pytorch.org/whl/torch_stable.html -accelerate>=0.25.0 -torch==2.1.2 -torchvision==0.16.2 -transformers>=4.30.0 -numpy==1.26.4 -kornia>=0.7.0 -insightface>=0.7.3 -opencv-python>=4.8.1.78 -tqdm>=4.66.1 -matplotlib>=3.7.4 -webdataset>=0.2.79 -wandb>=0.16.2 -munch>=4.0.0 -onnxruntime>=1.16.3 -einops>=0.7.0 -onnx2torch>=1.5.13 -warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git -torchtools @ git+https://github.com/pabloppp/pytorch-tools +timm +transformers +spaces \ No newline at end of file diff --git a/train/__init__.py b/train/__init__.py deleted file mode 100644 index ea1331f6b933f63c99a6bdf074201fdb4b8f78c2..0000000000000000000000000000000000000000 --- a/train/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .train_b import WurstCore as WurstCoreB -from .train_c import WurstCore as WurstCoreC -from .train_t2i import WurstCore as WurstCore_t2i -from .train_ultrapixel_control import WurstCore as WurstCore_control_lrguide -from .train_personalized import WurstCore as WurstCore_personalized \ No newline at end of file diff --git a/train/base.py b/train/base.py deleted file mode 100644 index 4e8a6ef306e40da8c9d8db33ceba2f8b2982a9a9..0000000000000000000000000000000000000000 --- a/train/base.py +++ /dev/null @@ -1,402 +0,0 @@ -import yaml -import json -import torch -import wandb -import torchvision -import numpy as np -from torch import nn -from tqdm import tqdm -from abc import abstractmethod -from fractions import Fraction -import matplotlib.pyplot as plt -from dataclasses import dataclass -from torch.distributed import barrier -from torch.utils.data import DataLoader - -from gdf import GDF -from gdf import AdaptiveLossWeight - -from core import WarpCore -from core.data import setup_webdataset_path, MultiGetter, MultiFilter, Bucketeer -from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary - -import webdataset as wds -from webdataset.handlers import warn_and_continue - -import transformers -transformers.utils.logging.set_verbosity_error() - - -class DataCore(WarpCore): - @dataclass(frozen=True) - class Config(WarpCore.Config): - image_size: int = EXPECTED_TRAIN - webdataset_path: str = EXPECTED_TRAIN - grad_accum_steps: int = EXPECTED_TRAIN - batch_size: int = EXPECTED_TRAIN - multi_aspect_ratio: list = None - - captions_getter: list = None - dataset_filters: list = None - - bucketeer_random_ratio: float = 0.05 - - @dataclass(frozen=True) - class Extras(WarpCore.Extras): - transforms: torchvision.transforms.Compose = EXPECTED - clip_preprocess: torchvision.transforms.Compose = EXPECTED - - @dataclass(frozen=True) - class Models(WarpCore.Models): - tokenizer: nn.Module = EXPECTED - text_model: nn.Module = EXPECTED - image_model: nn.Module = None - - config: Config - - def webdataset_path(self): - if isinstance(self.config.webdataset_path, str) and (self.config.webdataset_path.strip().startswith( - 'pipe:') or self.config.webdataset_path.strip().startswith('file:')): - return self.config.webdataset_path - else: - dataset_path = self.config.webdataset_path - if isinstance(self.config.webdataset_path, str) and self.config.webdataset_path.strip().endswith('.yml'): - with open(self.config.webdataset_path, 'r', encoding='utf-8') as file: - dataset_path = yaml.safe_load(file) - return setup_webdataset_path(dataset_path, cache_path=f"{self.config.experiment_id}_webdataset_cache.yml") - - def webdataset_preprocessors(self, extras: Extras): - def identity(x): - if isinstance(x, bytes): - x = x.decode('utf-8') - return x - - # CUSTOM CAPTIONS GETTER ----- - def get_caption(oc, c, p_og=0.05): # cog_contexual, cog_caption - if p_og > 0 and np.random.rand() < p_og and len(oc) > 0: - return identity(oc) - else: - return identity(c) - - captions_getter = MultiGetter(rules={ - ('old_caption', 'caption'): lambda oc, c: get_caption(json.loads(oc)['og_caption'], c, p_og=0.05) - }) - - return [ - ('jpg;png', - torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None else extras.transforms, - 'images'), - ('txt', identity, 'captions') if self.config.captions_getter is None else ( - self.config.captions_getter[0], eval(self.config.captions_getter[1]), 'captions'), - ] - - def setup_data(self, extras: Extras) -> WarpCore.Data: - # SETUP DATASET - dataset_path = self.webdataset_path() - preprocessors = self.webdataset_preprocessors(extras) - - handler = warn_and_continue - dataset = wds.WebDataset( - dataset_path, resampled=True, handler=handler - ).select( - MultiFilter(rules={ - f[0]: eval(f[1]) for f in self.config.dataset_filters - }) if self.config.dataset_filters is not None else lambda _: True - ).shuffle(690, handler=handler).decode( - "pilrgb", handler=handler - ).to_tuple( - *[p[0] for p in preprocessors], handler=handler - ).map_tuple( - *[p[1] for p in preprocessors], handler=handler - ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)}) - - def identity(x): - return x - - # SETUP DATALOADER - real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) - dataloader = DataLoader( - dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True, - collate_fn=identity if self.config.multi_aspect_ratio is not None else None - ) - if self.is_main_node: - print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") - - if self.config.multi_aspect_ratio is not None: - aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] - dataloader_iterator = Bucketeer(dataloader, density=self.config.image_size ** 2, factor=32, - ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, - interpolate_nearest=False) # , use_smartcrop=True) - else: - dataloader_iterator = iter(dataloader) - - return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator) - - def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, - eval_image_embeds=False, return_fields=None): - if return_fields is None: - return_fields = ['clip_text', 'clip_text_pooled', 'clip_img'] - - captions = batch.get('captions', None) - images = batch.get('images', None) - batch_size = len(captions) - - text_embeddings = None - text_pooled_embeddings = None - if 'clip_text' in return_fields or 'clip_text_pooled' in return_fields: - if is_eval: - if is_unconditional: - captions_unpooled = ["" for _ in range(batch_size)] - else: - captions_unpooled = captions - else: - rand_idx = np.random.rand(batch_size) > 0.05 - captions_unpooled = [str(c) if keep else "" for c, keep in zip(captions, rand_idx)] - clip_tokens_unpooled = models.tokenizer(captions_unpooled, truncation=True, padding="max_length", - max_length=models.tokenizer.model_max_length, - return_tensors="pt").to(self.device) - text_encoder_output = models.text_model(**clip_tokens_unpooled, output_hidden_states=True) - if 'clip_text' in return_fields: - text_embeddings = text_encoder_output.hidden_states[-1] - if 'clip_text_pooled' in return_fields: - text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1) - - image_embeddings = None - if 'clip_img' in return_fields: - image_embeddings = torch.zeros(batch_size, 768, device=self.device) - if images is not None: - images = images.to(self.device) - if is_eval: - if not is_unconditional and eval_image_embeds: - image_embeddings = models.image_model(extras.clip_preprocess(images)).image_embeds - else: - rand_idx = np.random.rand(batch_size) > 0.9 - if any(rand_idx): - image_embeddings[rand_idx] = models.image_model(extras.clip_preprocess(images[rand_idx])).image_embeds - image_embeddings = image_embeddings.unsqueeze(1) - return { - 'clip_text': text_embeddings, - 'clip_text_pooled': text_pooled_embeddings, - 'clip_img': image_embeddings - } - - -class TrainingCore(DataCore, WarpCore): - @dataclass(frozen=True) - class Config(DataCore.Config, WarpCore.Config): - updates: int = EXPECTED_TRAIN - backup_every: int = EXPECTED_TRAIN - save_every: int = EXPECTED_TRAIN - - # EMA UPDATE - ema_start_iters: int = None - ema_iters: int = None - ema_beta: float = None - - use_fsdp: bool = None - - @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED - class Info(WarpCore.Info): - ema_loss: float = None - adaptive_loss: dict = None - - @dataclass(frozen=True) - class Models(WarpCore.Models): - generator: nn.Module = EXPECTED - generator_ema: nn.Module = None # optional - - @dataclass(frozen=True) - class Optimizers(WarpCore.Optimizers): - generator: any = EXPECTED - - @dataclass(frozen=True) - class Extras(WarpCore.Extras): - gdf: GDF = EXPECTED - sampling_configs: dict = EXPECTED - - info: Info - config: Config - - @abstractmethod - def forward_pass(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models): - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: Optimizers, - schedulers: WarpCore.Schedulers): - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def models_to_save(self) -> list: - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - raise NotImplementedError("This method needs to be overriden") - - @abstractmethod - def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - raise NotImplementedError("This method needs to be overriden") - - def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers, - schedulers: WarpCore.Schedulers): - start_iter = self.info.iter + 1 - max_iters = self.config.updates * self.config.grad_accum_steps - if self.is_main_node: - print(f"STARTING AT STEP: {start_iter}/{max_iters}") - - pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter, - max_iters + 1) # <--- DDP - if 'generator' in self.models_to_save(): - models.generator.train() - for i in pbar: - # FORWARD PASS - loss, loss_adjusted = self.forward_pass(data, extras, models) - - # # BACKWARD PASS - grad_norm = self.backward_pass( - i % self.config.grad_accum_steps == 0 or i == max_iters, loss, loss_adjusted, - models, optimizers, schedulers - ) - self.info.iter = i - - # UPDATE EMA - if models.generator_ema is not None and i % self.config.ema_iters == 0: - update_weights_ema( - models.generator_ema, models.generator, - beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0) - ) - - # UPDATE LOSS METRICS - self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 - - if self.is_main_node and self.config.wandb_project is not None and np.isnan(loss.mean().item()) or np.isnan( - grad_norm.item()): - wandb.alert( - title=f"NaN value encountered in training run {self.info.wandb_run_id}", - text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}", - wait_duration=60 * 30 - ) - - if self.is_main_node: - logs = { - 'loss': self.info.ema_loss, - 'raw_loss': loss.mean().item(), - 'grad_norm': grad_norm.item(), - 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, - 'total_steps': self.info.total_steps, - } - - pbar.set_postfix(logs) - if self.config.wandb_project is not None: - wandb.log(logs) - - if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters: - # SAVE AND CHECKPOINT STUFF - if np.isnan(loss.mean().item()): - if self.is_main_node and self.config.wandb_project is not None: - tqdm.write("Skipping sampling & checkpoint because the loss is NaN") - wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.wandb_run_id}", - text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN") - else: - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - self.info.adaptive_loss = { - 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), - 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), - } - self.save_checkpoints(models, optimizers) - if self.is_main_node: - create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') - self.sample(models, data, extras) - - def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None): - barrier() - suffix = '' if suffix is None else suffix - self.save_info(self.info, suffix=suffix) - models_dict = models.to_dict() - optimizers_dict = optimizers.to_dict() - for key in self.models_to_save(): - model = models_dict[key] - if model is not None: - self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp) - for key in optimizers_dict: - optimizer = optimizers_dict[key] - if optimizer is not None: - self.save_optimizer(optimizer, f'{key}_optim{suffix}', - fsdp_model=models_dict[key] if self.config.use_fsdp else None) - if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0: - self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k") - torch.cuda.empty_cache() - - def sample(self, models: Models, data: WarpCore.Data, extras: Extras): - if 'generator' in self.models_to_save(): - models.generator.eval() - with torch.no_grad(): - batch = next(data.iterator) - - conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - latents = self.encode_latents(batch, models, extras) - noised, _, _, logSNR, noise_cond, _ = extras.gdf.diffuse(latents, shift=1, loss_shift=1) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - pred = models.generator(noised, noise_cond, **conditions) - pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - *_, (sampled, _, _) = extras.gdf.sample( - models.generator, conditions, - latents.shape, unconditions, device=self.device, **extras.sampling_configs - ) - - if models.generator_ema is not None: - *_, (sampled_ema, _, _) = extras.gdf.sample( - models.generator_ema, conditions, - latents.shape, unconditions, device=self.device, **extras.sampling_configs - ) - else: - sampled_ema = sampled - - if self.is_main_node: - noised_images = torch.cat( - [self.decode_latents(noised[i:i + 1], batch, models, extras) for i in range(len(noised))], dim=0) - pred_images = torch.cat( - [self.decode_latents(pred[i:i + 1], batch, models, extras) for i in range(len(pred))], dim=0) - sampled_images = torch.cat( - [self.decode_latents(sampled[i:i + 1], batch, models, extras) for i in range(len(sampled))], dim=0) - sampled_images_ema = torch.cat( - [self.decode_latents(sampled_ema[i:i + 1], batch, models, extras) for i in range(len(sampled_ema))], - dim=0) - - images = batch['images'] - if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): - images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') - - collage_img = torch.cat([ - torch.cat([i for i in images.cpu()], dim=-1), - torch.cat([i for i in noised_images.cpu()], dim=-1), - torch.cat([i for i in pred_images.cpu()], dim=-1), - torch.cat([i for i in sampled_images.cpu()], dim=-1), - torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), - ], dim=-2) - - torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') - torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg') - - captions = batch['captions'] - if self.config.wandb_project is not None: - log_data = [ - [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ - wandb.Image(images[i])] for i in range(len(images))] - log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) - wandb.log({"Log": log_table}) - - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) - plt.ylabel('Raw Loss') - plt.ylabel('LogSNR') - wandb.log({"Loss/LogSRN": plt}) - - if 'generator' in self.models_to_save(): - models.generator.train() diff --git a/train/dist_core.py b/train/dist_core.py deleted file mode 100644 index 4e4e9e670a3b853fac345618d3557d648d813902..0000000000000000000000000000000000000000 --- a/train/dist_core.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import torch - - -def get_world_size(): - """Find OMPI world size without calling mpi functions - :rtype: int - """ - if os.environ.get('PMI_SIZE') is not None: - return int(os.environ.get('PMI_SIZE') or 1) - elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: - return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) - else: - return torch.cuda.device_count() - - -def get_global_rank(): - """Find OMPI world rank without calling mpi functions - :rtype: int - """ - if os.environ.get('PMI_RANK') is not None: - return int(os.environ.get('PMI_RANK') or 0) - elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: - return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) - else: - return 0 - - -def get_local_rank(): - """Find OMPI local rank without calling mpi functions - :rtype: int - """ - if os.environ.get('MPI_LOCALRANKID') is not None: - return int(os.environ.get('MPI_LOCALRANKID') or 0) - elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: - return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) - else: - return 0 - - -def get_master_ip(): - if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: - return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] - elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: - return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') - else: - return "127.0.0.1" diff --git a/train/train_b.py b/train/train_b.py deleted file mode 100644 index c3441a5841750a7c33b49756d2d60064a68d82d8..0000000000000000000000000000000000000000 --- a/train/train_b.py +++ /dev/null @@ -1,305 +0,0 @@ -import torch -import torchvision -from torch import nn, optim -from transformers import AutoTokenizer, CLIPTextModelWithProjection -from warmup_scheduler import GradualWarmupScheduler -import numpy as np - -import sys -import os -from dataclasses import dataclass - -from gdf import GDF, EpsilonTarget, CosineSchedule -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from torchtools.transforms import SmartCrop - -from modules.effnet import EfficientNetEncoder -from modules.stage_a import StageA - -from modules.stage_b import StageB -from modules.stage_b import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock - -from train.base import DataCore, TrainingCore - -from core import WarpCore -from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail - -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from contextlib import contextmanager - -class WurstCore(TrainingCore, DataCore, WarpCore): - @dataclass(frozen=True) - class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): - # TRAINING PARAMS - lr: float = EXPECTED_TRAIN - warmup_updates: int = EXPECTED_TRAIN - shift: float = EXPECTED_TRAIN - dtype: str = None - - # MODEL VERSION - model_version: str = EXPECTED # 3BB or 700M - clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' - - # CHECKPOINT PATHS - stage_a_checkpoint_path: str = EXPECTED - effnet_checkpoint_path: str = EXPECTED - generator_checkpoint_path: str = None - - # gdf customization - adaptive_loss_weight: str = None - - @dataclass(frozen=True) - class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): - effnet: nn.Module = EXPECTED - stage_a: nn.Module = EXPECTED - - @dataclass(frozen=True) - class Schedulers(WarpCore.Schedulers): - generator: any = None - - @dataclass(frozen=True) - class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): - gdf: GDF = EXPECTED - sampling_configs: dict = EXPECTED - effnet_preprocess: torchvision.transforms.Compose = EXPECTED - - info: TrainingCore.Info - config: Config - - def setup_extras_pre(self) -> Extras: - gdf = GDF( - schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), - input_scaler=VPScaler(), target=EpsilonTarget(), - noise_cond=CosineTNoiseCond(), - loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), - ) - sampling_configs = {"cfg": 1.5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 10} - - if self.info.adaptive_loss is not None: - gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) - gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) - - effnet_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) - ]) - - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(self.config.image_size, - interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - antialias=True), - SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) if self.config.training else torchvision.transforms.CenterCrop(self.config.image_size) - ]) - - return self.Extras( - gdf=gdf, - sampling_configs=sampling_configs, - transforms=transforms, - effnet_preprocess=effnet_preprocess, - clip_preprocess=None - ) - - def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None): - images = batch.get('images', None) - - if images is not None: - images = images.to(self.device) - if is_eval and not is_unconditional: - effnet_embeddings = models.effnet(extras.effnet_preprocess(images)) - else: - if is_eval: - effnet_factor = 1 - else: - effnet_factor = np.random.uniform(0.5, 1) # f64 to f32 - effnet_height, effnet_width = int(((images.size(-2)*effnet_factor)//32)*32), int(((images.size(-1)*effnet_factor)//32)*32) - - effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height//32, effnet_width//32, device=self.device) - if not is_eval: - effnet_images = torchvision.transforms.functional.resize(images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST) - rand_idx = np.random.rand(len(images)) <= 0.9 - if any(rand_idx): - effnet_embeddings[rand_idx] = models.effnet(extras.effnet_preprocess(effnet_images[rand_idx])) - else: - effnet_embeddings = None - - conditions = super().get_conditions( - batch, models, extras, is_eval, is_unconditional, - eval_image_embeds, return_fields=return_fields or ['clip_text_pooled'] - ) - - return {'effnet': effnet_embeddings, 'clip': conditions['clip_text_pooled']} - - def setup_models(self, extras: Extras, skip_clip: bool = False) -> Models: - dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 - - # EfficientNet encoder - effnet = EfficientNetEncoder().to(self.device) - effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) - - effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) - effnet.eval().requires_grad_(False) - del effnet_checkpoint - - # vqGAN - stage_a = StageA().to(self.device) - stage_a_checkpoint = load_or_fail(self.config.stage_a_checkpoint_path) - stage_a.load_state_dict(stage_a_checkpoint if 'state_dict' not in stage_a_checkpoint else stage_a_checkpoint['state_dict']) - stage_a.eval().requires_grad_(False) - del stage_a_checkpoint - - @contextmanager - def dummy_context(): - yield None - - loading_context = dummy_context if self.config.training else init_empty_weights - - # Diffusion models - with loading_context(): - generator_ema = None - if self.config.model_version == '3B': - generator = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) - if self.config.ema_start_iters is not None: - generator_ema = StageB(c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]]) - elif self.config.model_version == '700M': - generator = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) - if self.config.ema_start_iters is not None: - generator_ema = StageB(c_hidden=[320, 576, 1152, 1152], nhead=[-1, 9, 18, 18], blocks=[[2, 4, 14, 4], [4, 14, 4, 2]], block_repeat=[[1, 1, 1, 1], [2, 2, 2, 2]]) - else: - raise ValueError(f"Unknown model version {self.config.model_version}") - - if self.config.generator_checkpoint_path is not None: - if loading_context is dummy_context: - generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) - else: - for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): - set_module_tensor_to_device(generator, param_name, "cpu", value=param) - generator = generator.to(dtype).to(self.device) - generator = self.load_model(generator, 'generator') - - if generator_ema is not None: - if loading_context is dummy_context: - generator_ema.load_state_dict(generator.state_dict()) - else: - for param_name, param in generator.state_dict().items(): - set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) - generator_ema = self.load_model(generator_ema, 'generator_ema') - generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) - - if self.config.use_fsdp: - fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) - generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) - if generator_ema is not None: - generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) - - if skip_clip: - tokenizer = None - text_model = None - else: - tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) - text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) - - return self.Models( - effnet=effnet, stage_a=stage_a, - generator=generator, generator_ema=generator_ema, - tokenizer=tokenizer, text_model=text_model - ) - - def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) - optimizer = self.load_optimizer(optimizer, 'generator_optim', - fsdp_model=models.generator if self.config.use_fsdp else None) - return self.Optimizers(generator=optimizer) - - def setup_schedulers(self, extras: Extras, models: Models, - optimizers: TrainingCore.Optimizers) -> Schedulers: - scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) - scheduler.last_epoch = self.info.total_steps - return self.Schedulers(generator=scheduler) - - def _pyramid_noise(self, epsilon, size_range=None, levels=10, scale_mode='nearest'): - epsilon = epsilon.clone() - multipliers = [1] - for i in range(1, levels): - m = 0.75 ** i - h, w = epsilon.size(-2) // (2 ** i), epsilon.size(-2) // (2 ** i) - if size_range is None or (size_range[0] <= h <= size_range[1] or size_range[0] <= w <= size_range[1]): - offset = torch.randn(epsilon.size(0), epsilon.size(1), h, w, device=self.device) - epsilon = epsilon + torch.nn.functional.interpolate(offset, size=epsilon.shape[-2:], - mode=scale_mode) * m - multipliers.append(m) - if h <= 1 or w <= 1: - break - epsilon = epsilon / sum([m ** 2 for m in multipliers]) ** 0.5 - # epsilon = epsilon / epsilon.std() - return epsilon - - def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): - batch = next(data.iterator) - - with torch.no_grad(): - conditions = self.get_conditions(batch, models, extras) - latents = self.encode_latents(batch, models, extras) - epsilon = torch.randn_like(latents) - epsilon = self._pyramid_noise(epsilon, size_range=[1, 16]) - noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1, - epsilon=epsilon) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - pred = models.generator(noised, noise_cond, **conditions) - loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) - loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps - - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - extras.gdf.loss_weight.update_buckets(logSNR, loss) - - return loss, loss_adjusted - - def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, - schedulers: Schedulers): - if update: - loss_adjusted.backward() - grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) - optimizers_dict = optimizers.to_dict() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].step() - schedulers_dict = schedulers.to_dict() - for k in schedulers_dict: - if k != 'training': - schedulers_dict[k].step() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].zero_grad(set_to_none=True) - self.info.total_steps += 1 - else: - loss_adjusted.backward() - grad_norm = torch.tensor(0.0).to(self.device) - - return grad_norm - - def models_to_save(self): - return ['generator', 'generator_ema'] - - def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - images = batch['images'].to(self.device) - return models.stage_a.encode(images)[0] - - def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - return models.stage_a.decode(latents.float()).clamp(0, 1) - - -if __name__ == '__main__': - print("Launching Script") - warpcore = WurstCore( - config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, - device=torch.device(int(os.environ.get("SLURM_LOCALID"))) - ) - # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD - - # RUN TRAINING - warpcore() diff --git a/train/train_c.py b/train/train_c.py deleted file mode 100644 index c4490c6eebc3e1c5126dd13c53603872f1459a3e..0000000000000000000000000000000000000000 --- a/train/train_c.py +++ /dev/null @@ -1,266 +0,0 @@ -import torch -import torchvision -from torch import nn, optim -from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection -from warmup_scheduler import GradualWarmupScheduler - -import sys -import os -from dataclasses import dataclass - -from gdf import GDF, EpsilonTarget, CosineSchedule -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from torchtools.transforms import SmartCrop - -from modules.effnet import EfficientNetEncoder -from modules.stage_c import StageC -from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock -from modules.previewer import Previewer - -from train.base import DataCore, TrainingCore - -from core import WarpCore -from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail - -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from contextlib import contextmanager - -class WurstCore(TrainingCore, DataCore, WarpCore): - @dataclass(frozen=True) - class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): - # TRAINING PARAMS - lr: float = EXPECTED_TRAIN - warmup_updates: int = EXPECTED_TRAIN - dtype: str = None - - # MODEL VERSION - model_version: str = EXPECTED # 3.6B or 1B - clip_image_model_name: str = 'openai/clip-vit-large-patch14' - clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' - - # CHECKPOINT PATHS - effnet_checkpoint_path: str = EXPECTED - previewer_checkpoint_path: str = EXPECTED - generator_checkpoint_path: str = None - - # gdf customization - adaptive_loss_weight: str = None - - @dataclass(frozen=True) - class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): - effnet: nn.Module = EXPECTED - previewer: nn.Module = EXPECTED - - @dataclass(frozen=True) - class Schedulers(WarpCore.Schedulers): - generator: any = None - - @dataclass(frozen=True) - class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): - gdf: GDF = EXPECTED - sampling_configs: dict = EXPECTED - effnet_preprocess: torchvision.transforms.Compose = EXPECTED - - info: TrainingCore.Info - config: Config - - def setup_extras_pre(self) -> Extras: - gdf = GDF( - schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), - input_scaler=VPScaler(), target=EpsilonTarget(), - noise_cond=CosineTNoiseCond(), - loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), - ) - sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} - - if self.info.adaptive_loss is not None: - gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) - gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) - - effnet_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) - ]) - - clip_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) - ) - ]) - - if self.config.training: - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), - SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) - ]) - else: - transforms = None - - return self.Extras( - gdf=gdf, - sampling_configs=sampling_configs, - transforms=transforms, - effnet_preprocess=effnet_preprocess, - clip_preprocess=clip_preprocess - ) - - def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, - eval_image_embeds=False, return_fields=None): - conditions = super().get_conditions( - batch, models, extras, is_eval, is_unconditional, - eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] - ) - return conditions - - def setup_models(self, extras: Extras) -> Models: - dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 - - # EfficientNet encoder - effnet = EfficientNetEncoder() - effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) - effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) - effnet.eval().requires_grad_(False).to(self.device) - del effnet_checkpoint - - # Previewer - previewer = Previewer() - previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) - previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) - previewer.eval().requires_grad_(False).to(self.device) - del previewer_checkpoint - - @contextmanager - def dummy_context(): - yield None - - loading_context = dummy_context if self.config.training else init_empty_weights - - # Diffusion models - with loading_context(): - generator_ema = None - if self.config.model_version == '3.6B': - generator = StageC() - if self.config.ema_start_iters is not None: - generator_ema = StageC() - elif self.config.model_version == '1B': - generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - if self.config.ema_start_iters is not None: - generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - else: - raise ValueError(f"Unknown model version {self.config.model_version}") - - if self.config.generator_checkpoint_path is not None: - if loading_context is dummy_context: - generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) - else: - - for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): - set_module_tensor_to_device(generator, param_name, "cpu", value=param) - generator = generator.to(dtype).to(self.device) - generator = self.load_model(generator, 'generator') - - if generator_ema is not None: - if loading_context is dummy_context: - generator_ema.load_state_dict(generator.state_dict()) - else: - for param_name, param in generator.state_dict().items(): - set_module_tensor_to_device(generator_ema, param_name, "cpu", value=param) - generator_ema = self.load_model(generator_ema, 'generator_ema') - generator_ema.to(dtype).to(self.device).eval().requires_grad_(False) - - if self.config.use_fsdp: - fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock]) - generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) - if generator_ema is not None: - generator_ema = FSDP(generator_ema, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) - - tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) - text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) - image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) - - return self.Models( - effnet=effnet, previewer=previewer, - generator=generator, generator_ema=generator_ema, - tokenizer=tokenizer, text_model=text_model, image_model=image_model - ) - - def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) - optimizer = self.load_optimizer(optimizer, 'generator_optim', - fsdp_model=models.generator if self.config.use_fsdp else None) - return self.Optimizers(generator=optimizer) - - def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: - scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) - scheduler.last_epoch = self.info.total_steps - return self.Schedulers(generator=scheduler) - - # Training loop -------------------------------- - def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): - batch = next(data.iterator) - - with torch.no_grad(): - conditions = self.get_conditions(batch, models, extras) - latents = self.encode_latents(batch, models, extras) - noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - pred = models.generator(noised, noise_cond, **conditions) - loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) - loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps - - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - extras.gdf.loss_weight.update_buckets(logSNR, loss) - - return loss, loss_adjusted - - def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): - if update: - loss_adjusted.backward() - grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0) - optimizers_dict = optimizers.to_dict() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].step() - schedulers_dict = schedulers.to_dict() - for k in schedulers_dict: - if k != 'training': - schedulers_dict[k].step() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].zero_grad(set_to_none=True) - self.info.total_steps += 1 - else: - loss_adjusted.backward() - grad_norm = torch.tensor(0.0).to(self.device) - - return grad_norm - - def models_to_save(self): - return ['generator', 'generator_ema'] - - def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - images = batch['images'].to(self.device) - return models.effnet(extras.effnet_preprocess(images)) - - def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - return models.previewer(latents) - - -if __name__ == '__main__': - print("Launching Script") - warpcore = WurstCore( - config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, - device=torch.device(int(os.environ.get("SLURM_LOCALID"))) - ) - # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD - - # RUN TRAINING - warpcore() diff --git a/train/train_c_lora.py b/train/train_c_lora.py deleted file mode 100644 index 8b83eee0f250e5359901d39b8d4052254cfff4fa..0000000000000000000000000000000000000000 --- a/train/train_c_lora.py +++ /dev/null @@ -1,330 +0,0 @@ -import torch -import torchvision -from torch import nn, optim -from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection -from warmup_scheduler import GradualWarmupScheduler - -import sys -import os -import re -from dataclasses import dataclass - -from gdf import GDF, EpsilonTarget, CosineSchedule -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from torchtools.transforms import SmartCrop - -from modules.effnet import EfficientNetEncoder -from modules.stage_c import StageC -from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock -from modules.previewer import Previewer -from modules.lora import apply_lora, apply_retoken, LoRA, ReToken - -from train.base import DataCore, TrainingCore - -from core import WarpCore -from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail - -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy -from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -import functools -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from contextlib import contextmanager - - -class WurstCore(TrainingCore, DataCore, WarpCore): - @dataclass(frozen=True) - class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): - # TRAINING PARAMS - lr: float = EXPECTED_TRAIN - warmup_updates: int = EXPECTED_TRAIN - dtype: str = None - - # MODEL VERSION - model_version: str = EXPECTED # 3.6B or 1B - clip_image_model_name: str = 'openai/clip-vit-large-patch14' - clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' - - # CHECKPOINT PATHS - effnet_checkpoint_path: str = EXPECTED - previewer_checkpoint_path: str = EXPECTED - generator_checkpoint_path: str = None - lora_checkpoint_path: str = None - - # LoRA STUFF - module_filters: list = EXPECTED - rank: int = EXPECTED - train_tokens: list = EXPECTED - - # gdf customization - adaptive_loss_weight: str = None - - @dataclass(frozen=True) - class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): - effnet: nn.Module = EXPECTED - previewer: nn.Module = EXPECTED - lora: nn.Module = EXPECTED - - @dataclass(frozen=True) - class Schedulers(WarpCore.Schedulers): - lora: any = None - - @dataclass(frozen=True) - class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): - gdf: GDF = EXPECTED - sampling_configs: dict = EXPECTED - effnet_preprocess: torchvision.transforms.Compose = EXPECTED - - @dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED - class Info(TrainingCore.Info): - train_tokens: list = None - - @dataclass(frozen=True) - class Optimizers(TrainingCore.Optimizers, WarpCore.Optimizers): - generator: any = None - lora: any = EXPECTED - - # -------------------------------------------- - info: Info - config: Config - - # Extras: gdf, transforms and preprocessors -------------------------------- - def setup_extras_pre(self) -> Extras: - gdf = GDF( - schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), - input_scaler=VPScaler(), target=EpsilonTarget(), - noise_cond=CosineTNoiseCond(), - loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), - ) - sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} - - if self.info.adaptive_loss is not None: - gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) - gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) - - effnet_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) - ]) - - clip_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) - ) - ]) - - if self.config.training: - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(self.config.image_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), - SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) - ]) - else: - transforms = None - - return self.Extras( - gdf=gdf, - sampling_configs=sampling_configs, - transforms=transforms, - effnet_preprocess=effnet_preprocess, - clip_preprocess=clip_preprocess - ) - - # Data -------------------------------- - def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, - eval_image_embeds=False, return_fields=None): - conditions = super().get_conditions( - batch, models, extras, is_eval, is_unconditional, - eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] - ) - return conditions - - # Models, Optimizers & Schedulers setup -------------------------------- - def setup_models(self, extras: Extras) -> Models: - dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.float32 - - # EfficientNet encoder - effnet = EfficientNetEncoder().to(self.device) - effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) - effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) - effnet.eval().requires_grad_(False) - del effnet_checkpoint - - # Previewer - previewer = Previewer().to(self.device) - previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) - previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) - previewer.eval().requires_grad_(False) - del previewer_checkpoint - - @contextmanager - def dummy_context(): - yield None - - loading_context = dummy_context if self.config.training else init_empty_weights - - with loading_context(): - # Diffusion models - if self.config.model_version == '3.6B': - generator = StageC() - elif self.config.model_version == '1B': - generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - else: - raise ValueError(f"Unknown model version {self.config.model_version}") - - if self.config.generator_checkpoint_path is not None: - if loading_context is dummy_context: - generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) - else: - for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): - set_module_tensor_to_device(generator, param_name, "cpu", value=param) - generator = generator.to(dtype).to(self.device) - generator = self.load_model(generator, 'generator') - - # if self.config.use_fsdp: - # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) - # generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) - - # CLIP encoders - tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) - text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) - image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) - - # PREPARE LORA - update_tokens = [] - for tkn_regex, aggr_regex in self.config.train_tokens: - if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')): - # Insert new token - tokenizer.add_tokens([tkn_regex]) - # add new zeros embedding - new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1] - if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline - aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None] - if len(aggr_tokens) > 0: - new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True) - elif self.is_main_node: - print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.") - text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([ - text_model.text_model.embeddings.token_embedding.weight.data, new_embedding - ], dim=0) - selected_tokens = [len(tokenizer.vocab) - 1] - else: - selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None] - update_tokens += selected_tokens - update_tokens = list(set(update_tokens)) # remove duplicates - - apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens) - apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank) - text_model.text_model.to(self.device) - generator.to(self.device) - lora = nn.ModuleDict() - lora['embeddings'] = text_model.text_model.embeddings.token_embedding.parametrizations.weight[0] - lora['weights'] = nn.ModuleList() - for module in generator.modules(): - if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)): - lora['weights'].append(module) - - self.info.train_tokens = [(i, tokenizer.decode(i)) for i in update_tokens] - if self.is_main_node: - print("Updating tokens:", self.info.train_tokens) - print(f"LoRA training {len(lora['weights'])} layers") - - if self.config.lora_checkpoint_path is not None: - lora_checkpoint = load_or_fail(self.config.lora_checkpoint_path) - lora.load_state_dict(lora_checkpoint if 'state_dict' not in lora_checkpoint else lora_checkpoint['state_dict']) - - lora = self.load_model(lora, 'lora') - lora.to(self.device).train().requires_grad_(True) - if self.config.use_fsdp: - # fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000) - fsdp_auto_wrap_policy = ModuleWrapPolicy([LoRA, ReToken]) - lora = FSDP(lora, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device) - - return self.Models( - effnet=effnet, previewer=previewer, - generator=generator, generator_ema=None, - lora=lora, - tokenizer=tokenizer, text_model=text_model, image_model=image_model - ) - - def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers: - optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95)) - optimizer = self.load_optimizer(optimizer, 'lora_optim', - fsdp_model=models.lora if self.config.use_fsdp else None) - return self.Optimizers(generator=None, lora=optimizer) - - def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers: - scheduler = GradualWarmupScheduler(optimizers.lora, multiplier=1, total_epoch=self.config.warmup_updates) - scheduler.last_epoch = self.info.total_steps - return self.Schedulers(lora=scheduler) - - def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): - batch = next(data.iterator) - - conditions = self.get_conditions(batch, models, extras) - with torch.no_grad(): - latents = self.encode_latents(batch, models, extras) - noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - pred = models.generator(noised, noise_cond, **conditions) - loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) - loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps - - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - extras.gdf.loss_weight.update_buckets(logSNR, loss) - - return loss, loss_adjusted - - def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): - if update: - loss_adjusted.backward() - grad_norm = nn.utils.clip_grad_norm_(models.lora.parameters(), 1.0) - optimizers_dict = optimizers.to_dict() - for k in optimizers_dict: - if optimizers_dict[k] is not None and k != 'training': - optimizers_dict[k].step() - schedulers_dict = schedulers.to_dict() - for k in schedulers_dict: - if k != 'training': - schedulers_dict[k].step() - for k in optimizers_dict: - if optimizers_dict[k] is not None and k != 'training': - optimizers_dict[k].zero_grad(set_to_none=True) - self.info.total_steps += 1 - else: - loss_adjusted.backward() - grad_norm = torch.tensor(0.0).to(self.device) - - return grad_norm - - def models_to_save(self): - return ['lora'] - - def sample(self, models: Models, data: WarpCore.Data, extras: Extras): - models.lora.eval() - super().sample(models, data, extras) - models.lora.train(), models.generator.eval() - - def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - images = batch['images'].to(self.device) - return models.effnet(extras.effnet_preprocess(images)) - - def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - return models.previewer(latents) - - -if __name__ == '__main__': - print("Launching Script") - warpcore = WurstCore( - config_file_path=sys.argv[1] if len(sys.argv) > 1 else None, - device=torch.device(int(os.environ.get("SLURM_LOCALID"))) - ) - warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD - - # RUN TRAINING - warpcore() diff --git a/train/train_personalized.py b/train/train_personalized.py deleted file mode 100644 index 5161b7c621a0eb9daf9d0f0566322bbeed646284..0000000000000000000000000000000000000000 --- a/train/train_personalized.py +++ /dev/null @@ -1,899 +0,0 @@ -import torch -import json -import yaml -import torchvision -from torch import nn, optim -from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection -from warmup_scheduler import GradualWarmupScheduler -import torch.multiprocessing as mp -import os -import numpy as np -import re -import sys -sys.path.append(os.path.abspath('./')) - -from dataclasses import dataclass -from torch.distributed import init_process_group, destroy_process_group, barrier -from gdf import GDF_dual_fixlrt as GDF -from gdf import EpsilonTarget, CosineSchedule -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from torchtools.transforms import SmartCrop -from fractions import Fraction -from modules.effnet import EfficientNetEncoder -from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock -from modules.common_ckpt import GlobalResponseNorm -from modules.previewer import Previewer -from core.data import Bucketeer -from train.base import DataCore, TrainingCore -from tqdm import tqdm -from core import WarpCore -from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail - -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from contextlib import contextmanager -from train.dist_core import * -import glob -from torch.utils.data import DataLoader, Dataset -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data.distributed import DistributedSampler -from PIL import Image -from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary -from core.utils import Base -import torch.nn.functional as F -import functools -import math -import copy -import random -from modules.lora import apply_lora, apply_retoken, LoRA, ReToken - -Image.MAX_IMAGE_PIXELS = None -torch.manual_seed(23) -random.seed(23) -np.random.seed(23) -#7978026 - -class Null_Model(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): - pass - - - - -def identity(x): - if isinstance(x, bytes): - x = x.decode('utf-8') - return x -def check_nan_inmodel(model, meta=''): - for name, param in model.named_parameters(): - if torch.isnan(param).any(): - print(f"nan detected in {name}", meta) - return True - print('no nan', meta) - return False -class mydist_dataset(Dataset): - def __init__(self, rootpath, tmp_prompt, img_processor=None): - - self.img_pathlist = glob.glob(os.path.join(rootpath, '*.jpg')) - self.img_pathlist = self.img_pathlist * 100000 - self.img_processor = img_processor - self.length = len( self.img_pathlist) - self.caption = tmp_prompt - - - def __getitem__(self, idx): - - imgpath = self.img_pathlist[idx] - txt = self.caption - - - - - try: - img = Image.open(imgpath).convert('RGB') - w, h = img.size - if self.img_processor is not None: - img = self.img_processor(img) - - except: - print('exception', imgpath) - return self.__getitem__(random.randint(0, self.length -1 ) ) - return dict(captions=txt, images=img) - def __len__(self): - return self.length -class WurstCore(TrainingCore, DataCore, WarpCore): - @dataclass(frozen=True) - class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): - # TRAINING PARAMS - lr: float = EXPECTED_TRAIN - warmup_updates: int = EXPECTED_TRAIN - dtype: str = None - - # MODEL VERSION - model_version: str = EXPECTED # 3.6B or 1B - clip_image_model_name: str = 'openai/clip-vit-large-patch14' - clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' - - # CHECKPOINT PATHS - effnet_checkpoint_path: str = EXPECTED - previewer_checkpoint_path: str = EXPECTED - generator_checkpoint_path: str = None - ultrapixel_path: str = EXPECTED - - # gdf customization - adaptive_loss_weight: str = None - - # LoRA STUFF - module_filters: list = EXPECTED - rank: int = EXPECTED - train_tokens: list = EXPECTED - use_ddp: bool=EXPECTED - tmp_prompt: str=EXPECTED - @dataclass(frozen=True) - class Data(Base): - dataset: Dataset = EXPECTED - dataloader: DataLoader = EXPECTED - iterator: any = EXPECTED - sampler: DistributedSampler = EXPECTED - - @dataclass(frozen=True) - class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): - effnet: nn.Module = EXPECTED - previewer: nn.Module = EXPECTED - train_norm: nn.Module = EXPECTED - train_lora: nn.Module = EXPECTED - - @dataclass(frozen=True) - class Schedulers(WarpCore.Schedulers): - generator: any = None - - @dataclass(frozen=True) - class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): - gdf: GDF = EXPECTED - sampling_configs: dict = EXPECTED - effnet_preprocess: torchvision.transforms.Compose = EXPECTED - - info: TrainingCore.Info - config: Config - - def setup_extras_pre(self) -> Extras: - gdf = GDF( - schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), - input_scaler=VPScaler(), target=EpsilonTarget(), - noise_cond=CosineTNoiseCond(), - loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), - ) - sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} - - if self.info.adaptive_loss is not None: - gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) - gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) - - effnet_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) - ]) - - clip_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) - ) - ]) - - if self.config.training: - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), - SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) - ]) - else: - transforms = None - - return self.Extras( - gdf=gdf, - sampling_configs=sampling_configs, - transforms=transforms, - effnet_preprocess=effnet_preprocess, - clip_preprocess=clip_preprocess - ) - - def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, - eval_image_embeds=False, return_fields=None): - conditions = super().get_conditions( - batch, models, extras, is_eval, is_unconditional, - eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] - ) - return conditions - - def setup_models(self, extras: Extras) -> Models: # configure model - - - dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 - - # EfficientNet encoderin - effnet = EfficientNetEncoder() - effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) - effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) - effnet.eval().requires_grad_(False).to(self.device) - del effnet_checkpoint - - # Previewer - previewer = Previewer() - previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) - previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) - previewer.eval().requires_grad_(False).to(self.device) - del previewer_checkpoint - - @contextmanager - def dummy_context(): - yield None - - loading_context = dummy_context if self.config.training else init_empty_weights - - # Diffusion models - with loading_context(): - generator_ema = None - if self.config.model_version == '3.6B': - generator = StageC() - if self.config.ema_start_iters is not None: # default setting - generator_ema = StageC() - elif self.config.model_version == '1B': - print('in line 155 1b light model', self.config.model_version ) - generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - - if self.config.ema_start_iters is not None and self.config.training: - generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - else: - raise ValueError(f"Unknown model version {self.config.model_version}") - - - - if loading_context is dummy_context: - generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) - else: - for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): - set_module_tensor_to_device(generator, param_name, "cpu", value=param) - - generator._init_extra_parameter() - generator = generator.to(torch.bfloat16).to(self.device) - - train_norm = nn.ModuleList() - - - cnt_norm = 0 - for mm in generator.modules(): - if isinstance(mm, GlobalResponseNorm): - - train_norm.append(Null_Model()) - cnt_norm += 1 - - - - - train_norm.append(generator.agg_net) - train_norm.append(generator.agg_net_up) - sdd = torch.load(self.config.ultrapixel_path, map_location='cpu') - collect_sd = {} - for k, v in sdd.items(): - collect_sd[k[7:]] = v - train_norm.load_state_dict(collect_sd) - - - - # CLIP encoders - tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) - text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) - image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) - - # PREPARE LORA - train_lora = nn.ModuleList() - update_tokens = [] - for tkn_regex, aggr_regex in self.config.train_tokens: - if (tkn_regex.startswith('[') and tkn_regex.endswith(']')) or (tkn_regex.startswith('<') and tkn_regex.endswith('>')): - # Insert new token - tokenizer.add_tokens([tkn_regex]) - # add new zeros embedding - new_embedding = torch.zeros_like(text_model.text_model.embeddings.token_embedding.weight.data)[:1] - if aggr_regex is not None: # aggregate embeddings to provide an interesting baseline - aggr_tokens = [v for k, v in tokenizer.vocab.items() if re.search(aggr_regex, k) is not None] - if len(aggr_tokens) > 0: - new_embedding = text_model.text_model.embeddings.token_embedding.weight.data[aggr_tokens].mean(dim=0, keepdim=True) - elif self.is_main_node: - print(f"WARNING: No tokens found for aggregation regex {aggr_regex}. It will be initialized as zeros.") - text_model.text_model.embeddings.token_embedding.weight.data = torch.cat([ - text_model.text_model.embeddings.token_embedding.weight.data, new_embedding - ], dim=0) - selected_tokens = [len(tokenizer.vocab) - 1] - else: - selected_tokens = [v for k, v in tokenizer.vocab.items() if re.search(tkn_regex, k) is not None] - update_tokens += selected_tokens - update_tokens = list(set(update_tokens)) # remove duplicates - - apply_retoken(text_model.text_model.embeddings.token_embedding, update_tokens) - - apply_lora(generator, filters=self.config.module_filters, rank=self.config.rank) - for module in generator.modules(): - if isinstance(module, LoRA) or (hasattr(module, '_fsdp_wrapped_module') and isinstance(module._fsdp_wrapped_module, LoRA)): - train_lora.append(module) - - - train_lora.append(text_model.text_model.embeddings.token_embedding.parametrizations.weight[0]) - - if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors')): - sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_lora.safetensors'), map_location='cpu') - collect_sd = {} - for k, v in sdd.items(): - collect_sd[k[7:]] = v - train_lora.load_state_dict(collect_sd, strict=True) - - - train_norm.to(self.device).train().requires_grad_(True) - - if generator_ema is not None: - - generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) - generator_ema._init_extra_parameter() - pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') - if os.path.exists(pretrained_pth): - generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) - - generator_ema.eval().requires_grad_(False) - - check_nan_inmodel(generator, 'generator') - - - - if self.config.use_ddp and self.config.training: - - train_lora = DDP(train_lora, device_ids=[self.device], find_unused_parameters=True) - - - - return self.Models( - effnet=effnet, previewer=previewer, train_norm = train_norm, - generator=generator, generator_ema=generator_ema, - tokenizer=tokenizer, text_model=text_model, image_model=image_model, - train_lora=train_lora - ) - - def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - - - params = [] - params += list(models.train_lora.module.parameters()) - optimizer = optim.AdamW(params, lr=self.config.lr) - - return self.Optimizers(generator=optimizer) - - def ema_update(self, ema_model, source_model, beta): - for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): - param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) - - def sync_ema(self, ema_model): - print('sync ema', torch.distributed.get_world_size()) - for param in ema_model.parameters(): - torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) - param.data /= torch.distributed.get_world_size() - def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - - - optimizer = optim.AdamW( - models.generator.up_blocks.parameters() , - lr=self.config.lr) - optimizer = self.load_optimizer(optimizer, 'generator_optim', - fsdp_model=models.generator if self.config.use_fsdp else None) - return self.Optimizers(generator=optimizer) - - def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: - scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) - scheduler.last_epoch = self.info.total_steps - return self.Schedulers(generator=scheduler) - - def setup_data(self, extras: Extras) -> WarpCore.Data: - # SETUP DATASET - dataset_path = self.config.webdataset_path - - - dataset = mydist_dataset(dataset_path, self.config.tmp_prompt, \ - torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ - else extras.transforms) - - # SETUP DATALOADER - real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) - - sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True) - dataloader = DataLoader( - dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True, - collate_fn=identity if self.config.multi_aspect_ratio is not None else None, - sampler = sampler - ) - if self.is_main_node: - print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") - - if self.config.multi_aspect_ratio is not None: - aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] - dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, - ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, - interpolate_nearest=False) # , use_smartcrop=True) - else: - - dataloader_iterator = iter(dataloader) - - return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) - - - - - - def setup_ddp(self, experiment_id, single_gpu=False, rank=0): - - if not single_gpu: - local_rank = rank - process_id = rank - world_size = get_world_size() - - self.process_id = process_id - self.is_main_node = process_id == 0 - self.device = torch.device(local_rank) - self.world_size = world_size - - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '14443' - torch.cuda.set_device(local_rank) - init_process_group( - backend="nccl", - rank=local_rank, - world_size=world_size, - # init_method=init_method, - ) - print(f"[GPU {process_id}] READY") - else: - self.is_main_node = rank == 0 - self.process_id = rank - self.device = torch.device('cuda:0') - self.world_size = 1 - print("Running in single thread, DDP not enabled.") - # Training loop -------------------------------- - def get_target_lr_size(self, ratio, std_size=24): - w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) - return (h * 32 , w * 32) - def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): - - batch = data - ratio = batch['images'].shape[-2] / batch['images'].shape[-1] - shape_lr = self.get_target_lr_size(ratio) - with torch.no_grad(): - conditions = self.get_conditions(batch, models, extras) - - latents = self.encode_latents(batch, models, extras) - latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) - - - - flag_lr = random.random() < 0.5 or self.info.iter <5000 - - if flag_lr: - noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1) - else: - noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) - if not flag_lr: - noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = \ - extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - - - if not flag_lr: - with torch.no_grad(): - _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) - - - pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if not flag_lr else None , **conditions) - loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) - - loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps - - - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - extras.gdf.loss_weight.update_buckets(logSNR, loss) - return loss, loss_adjusted - - def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): - - if update: - - torch.distributed.barrier() - loss_adjusted.backward() - - grad_norm = nn.utils.clip_grad_norm_(models.train_lora.module.parameters(), 1.0) - optimizers_dict = optimizers.to_dict() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].step() - schedulers_dict = schedulers.to_dict() - for k in schedulers_dict: - if k != 'training': - schedulers_dict[k].step() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].zero_grad(set_to_none=True) - self.info.total_steps += 1 - else: - - loss_adjusted.backward() - grad_norm = torch.tensor(0.0).to(self.device) - - return grad_norm - - def models_to_save(self): - return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema'] - - def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: - - images = batch['images'].to(self.device) - if target_size is not None: - images = F.interpolate(images, target_size) - - return models.effnet(extras.effnet_preprocess(images)) - - def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - return models.previewer(latents) - - def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): - - self.is_main_node = (rank == 0) - self.config: self.Config = self.setup_config(config_file_path, config_dict, training) - self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) - self.info: self.Info = self.setup_info() - print('in line 292', self.config.experiment_id, rank, world_size <= 1) - p = [i for i in range( 2 * 768 // 32)] - p = [num / sum(p) for num in p] - self.rand_pro = p - self.res_list = [o for o in range(800, 2336, 32)] - - - - def __call__(self, single_gpu=False): - - if self.config.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - if self.is_main_node: - print() - print("**STARTIG JOB WITH CONFIG:**") - print(yaml.dump(self.config.to_dict(), default_flow_style=False)) - print("------------------------------------") - print() - print("**INFO:**") - print(yaml.dump(vars(self.info), default_flow_style=False)) - print("------------------------------------") - print() - print('in line 308', self.is_main_node, self.is_main_node, self.process_id, self.device ) - # SETUP STUFF - extras = self.setup_extras_pre() - assert extras is not None, "setup_extras_pre() must return a DTO" - - - - data = self.setup_data(extras) - assert data is not None, "setup_data() must return a DTO" - if self.is_main_node: - print("**DATA:**") - print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - models = self.setup_models(extras) - assert models is not None, "setup_models() must return a DTO" - if self.is_main_node: - print("**MODELS:**") - print(yaml.dump({ - k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() - }, default_flow_style=False)) - print("------------------------------------") - print() - - - - optimizers = self.setup_optimizers(extras, models) - assert optimizers is not None, "setup_optimizers() must return a DTO" - if self.is_main_node: - print("**OPTIMIZERS:**") - print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - schedulers = self.setup_schedulers(extras, models, optimizers) - assert schedulers is not None, "setup_schedulers() must return a DTO" - if self.is_main_node: - print("**SCHEDULERS:**") - print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) - assert post_extras is not None, "setup_extras_post() must return a DTO" - extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) - if self.is_main_node: - print("**EXTRAS:**") - print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - # ------- - - # TRAIN - if self.is_main_node: - print("**TRAINING STARTING...**") - self.train(data, extras, models, optimizers, schedulers) - - if single_gpu is False: - barrier() - destroy_process_group() - if self.is_main_node: - print() - print("------------------------------------") - print() - print("**TRAINING COMPLETE**") - if self.config.wandb_project is not None: - wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") - - - def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers, - schedulers: WarpCore.Schedulers): - start_iter = self.info.iter + 1 - max_iters = self.config.updates * self.config.grad_accum_steps - if self.is_main_node: - print(f"STARTING AT STEP: {start_iter}/{max_iters}") - - - if self.is_main_node: - create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') - if 'generator' in self.models_to_save(): - models.generator.train() - - iter_cnt = 0 - epoch_cnt = 0 - models.train_norm.train() - while True: - epoch_cnt += 1 - if self.world_size > 1: - - data.sampler.set_epoch(epoch_cnt) - for ggg in range(len(data.dataloader)): - iter_cnt += 1 - # FORWARD PASS - - loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models) - - - # # BACKWARD PASS - - grad_norm = self.backward_pass( - iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted, - models, optimizers, schedulers - ) - - - - self.info.iter = iter_cnt - - - self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 - - - if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): - print(f"gggg NaN value encountered in training run {self.info.wandb_run_id}", \ - f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") - - if self.is_main_node: - logs = { - 'loss': self.info.ema_loss, - 'backward_loss': loss_adjusted.mean().item(), - - 'ema_loss': self.info.ema_loss, - 'raw_ori_loss': loss.mean().item(), - - 'grad_norm': grad_norm.item(), - 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, - 'total_steps': self.info.total_steps, - } - - - print(iter_cnt, max_iters, logs, epoch_cnt, ) - - - - - - - if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters: - - if np.isnan(loss.mean().item()): - if self.is_main_node and self.config.wandb_project is not None: - print(f"NaN value encountered in training run {self.info.wandb_run_id}", \ - f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") - - else: - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - self.info.adaptive_loss = { - 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), - 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), - } - - - if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0: - print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps ) - torch.save(models.train_lora.state_dict(), \ - f'{self.config.output_path}/{self.config.experiment_id}/train_lora.safetensors') - - - torch.save(models.train_lora.state_dict(), \ - f'{self.config.output_path}/{self.config.experiment_id}/train_lora_{iter_cnt}.safetensors') - - - if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters: - - if self.is_main_node: - - self.sample(models, data, extras) - if False: - param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()} - threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10% - important_params = [name for name, change in param_changes.items() if change > threshold] - print(important_params, threshold, len(param_changes), self.process_id) - json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4) - - - if self.info.iter >= max_iters: - break - - def sample(self, models: Models, data: WarpCore.Data, extras: Extras): - - - models.generator.eval() - models.train_norm.eval() - with torch.no_grad(): - batch = next(data.iterator) - ratio = batch['images'].shape[-2] / batch['images'].shape[-1] - - shape_lr = self.get_target_lr_size(ratio) - conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - latents = self.encode_latents(batch, models, extras) - latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr) - - if self.is_main_node: - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - - *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( - models.generator, conditions, - latents.shape, latents_lr.shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - - sampled_ema = sampled - sampled_ema_lr = sampled_lr - - - if self.is_main_node: - print('sampling results hr latent shape ', latents.shape, 'lr latent shape', latents_lr.shape, ) - noised_images = torch.cat( - [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0) - - sampled_images = torch.cat( - [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0) - sampled_images_ema = torch.cat( - [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))], - dim=0) - - noised_images_lr = torch.cat( - [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0) - - sampled_images_lr = torch.cat( - [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0) - sampled_images_ema_lr = torch.cat( - [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))], - dim=0) - - images = batch['images'] - if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): - images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') - images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic') - - collage_img = torch.cat([ - torch.cat([i for i in images.cpu()], dim=-1), - torch.cat([i for i in noised_images.cpu()], dim=-1), - torch.cat([i for i in sampled_images.cpu()], dim=-1), - torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), - ], dim=-2) - - collage_img_lr = torch.cat([ - torch.cat([i for i in images_lr.cpu()], dim=-1), - torch.cat([i for i in noised_images_lr.cpu()], dim=-1), - torch.cat([i for i in sampled_images_lr.cpu()], dim=-1), - torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1), - ], dim=-2) - - torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') - torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg') - - captions = batch['captions'] - if self.config.wandb_project is not None: - log_data = [ - [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ - wandb.Image(images[i])] for i in range(len(images))] - log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) - wandb.log({"Log": log_table}) - - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) - plt.ylabel('Raw Loss') - plt.ylabel('LogSNR') - wandb.log({"Loss/LogSRN": plt}) - - - models.generator.train() - models.train_norm.train() - print('finish sampling') - - - - def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False): - - - models.generator.eval() - models.trans_inr.eval() - with torch.no_grad(): - - if self.is_main_node: - conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds) - unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - - *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( - models.generator, conditions, - hr_shape, lr_shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - if models.generator_ema is not None: - - *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample( - models.generator_ema, conditions, - latents.shape, latents_lr.shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - else: - sampled_ema = sampled - sampled_ema_lr = sampled_lr - - - return sampled, sampled_lr -def main_worker(rank, cfg): - print("Launching Script in main worker") - warpcore = WurstCore( - config_file_path=cfg, rank=rank, world_size = get_world_size() - ) - # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD - - # RUN TRAINING - warpcore(get_world_size()==1) - -if __name__ == '__main__': - - if get_master_ip() == "127.0.0.1": - - mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, )) - else: - main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, ) diff --git a/train/train_t2i.py b/train/train_t2i.py deleted file mode 100644 index 456ca4b0dd1fe8e1fc18e3e5c940797439071d1f..0000000000000000000000000000000000000000 --- a/train/train_t2i.py +++ /dev/null @@ -1,807 +0,0 @@ -import torch -import json -import yaml -import torchvision -from torch import nn, optim -from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection -from warmup_scheduler import GradualWarmupScheduler -import torch.multiprocessing as mp -import numpy as np -import os -import sys -sys.path.append(os.path.abspath('./')) -from dataclasses import dataclass -from torch.distributed import init_process_group, destroy_process_group, barrier -from gdf import GDF_dual_fixlrt as GDF -from gdf import EpsilonTarget, CosineSchedule -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from torchtools.transforms import SmartCrop -from fractions import Fraction -from modules.effnet import EfficientNetEncoder - -from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock -from modules.previewer import Previewer -from core.data import Bucketeer -from train.base import DataCore, TrainingCore -from tqdm import tqdm -from core import WarpCore -from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail - -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from contextlib import contextmanager -from train.dist_core import * -import glob -from torch.utils.data import DataLoader, Dataset -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data.distributed import DistributedSampler -from PIL import Image -from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary -from core.utils import Base -from modules.common_ckpt import LayerNorm2d, GlobalResponseNorm -import torch.nn.functional as F -import functools -import math -import copy -import random -from modules.lora import apply_lora, apply_retoken, LoRA, ReToken -Image.MAX_IMAGE_PIXELS = None -torch.manual_seed(23) -random.seed(23) -np.random.seed(23) -#7978026 - -class Null_Model(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): - pass - - - - -def identity(x): - if isinstance(x, bytes): - x = x.decode('utf-8') - return x -def check_nan_inmodel(model, meta=''): - for name, param in model.named_parameters(): - if torch.isnan(param).any(): - print(f"nan detected in {name}", meta) - return True - print('no nan', meta) - return False -class mydist_dataset(Dataset): - def __init__(self, rootpath, img_processor=None): - - self.img_pathlist = glob.glob(os.path.join(rootpath, '*', '*.jpg')) - self.img_processor = img_processor - self.length = len( self.img_pathlist) - - - - def __getitem__(self, idx): - - imgpath = self.img_pathlist[idx] - json_file = imgpath.replace('.jpg', '.json') - - with open(json_file, 'r') as file: - info = json.load(file) - txt = info['caption'] - if txt is None: - txt = ' ' - try: - img = Image.open(imgpath).convert('RGB') - w, h = img.size - if self.img_processor is not None: - img = self.img_processor(img) - - except: - print('exception', imgpath) - return self.__getitem__(random.randint(0, self.length -1 ) ) - return dict(captions=txt, images=img) - def __len__(self): - return self.length - -class WurstCore(TrainingCore, DataCore, WarpCore): - @dataclass(frozen=True) - class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): - # TRAINING PARAMS - lr: float = EXPECTED_TRAIN - warmup_updates: int = EXPECTED_TRAIN - dtype: str = None - - # MODEL VERSION - model_version: str = EXPECTED # 3.6B or 1B - clip_image_model_name: str = 'openai/clip-vit-large-patch14' - clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' - - # CHECKPOINT PATHS - effnet_checkpoint_path: str = EXPECTED - previewer_checkpoint_path: str = EXPECTED - - generator_checkpoint_path: str = None - - # gdf customization - adaptive_loss_weight: str = None - use_ddp: bool=EXPECTED - - - @dataclass(frozen=True) - class Data(Base): - dataset: Dataset = EXPECTED - dataloader: DataLoader = EXPECTED - iterator: any = EXPECTED - sampler: DistributedSampler = EXPECTED - - @dataclass(frozen=True) - class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): - effnet: nn.Module = EXPECTED - previewer: nn.Module = EXPECTED - train_norm: nn.Module = EXPECTED - - - @dataclass(frozen=True) - class Schedulers(WarpCore.Schedulers): - generator: any = None - - @dataclass(frozen=True) - class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): - gdf: GDF = EXPECTED - sampling_configs: dict = EXPECTED - effnet_preprocess: torchvision.transforms.Compose = EXPECTED - - info: TrainingCore.Info - config: Config - - def setup_extras_pre(self) -> Extras: - gdf = GDF( - schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), - input_scaler=VPScaler(), target=EpsilonTarget(), - noise_cond=CosineTNoiseCond(), - loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), - ) - sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} - - if self.info.adaptive_loss is not None: - gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) - gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) - - effnet_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) - ]) - - clip_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) - ) - ]) - - if self.config.training: - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), - SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) - ]) - else: - transforms = None - - return self.Extras( - gdf=gdf, - sampling_configs=sampling_configs, - transforms=transforms, - effnet_preprocess=effnet_preprocess, - clip_preprocess=clip_preprocess - ) - - def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, - eval_image_embeds=False, return_fields=None): - conditions = super().get_conditions( - batch, models, extras, is_eval, is_unconditional, - eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] - ) - return conditions - - def setup_models(self, extras: Extras) -> Models: # configure model - - dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 - - # EfficientNet encoderin - effnet = EfficientNetEncoder() - effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) - effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) - effnet.eval().requires_grad_(False).to(self.device) - del effnet_checkpoint - - # Previewer - previewer = Previewer() - previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) - previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) - previewer.eval().requires_grad_(False).to(self.device) - del previewer_checkpoint - - @contextmanager - def dummy_context(): - yield None - - loading_context = dummy_context if self.config.training else init_empty_weights - - # Diffusion models - with loading_context(): - generator_ema = None - if self.config.model_version == '3.6B': - generator = StageC() - if self.config.ema_start_iters is not None: # default setting - generator_ema = StageC() - elif self.config.model_version == '1B': - print('in line 155 1b light model', self.config.model_version ) - generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - - if self.config.ema_start_iters is not None and self.config.training: - generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - else: - raise ValueError(f"Unknown model version {self.config.model_version}") - - - - if loading_context is dummy_context: - generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) - else: - for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): - set_module_tensor_to_device(generator, param_name, "cpu", value=param) - - generator._init_extra_parameter() - generator = generator.to(torch.bfloat16).to(self.device) - - - train_norm = nn.ModuleList() - cnt_norm = 0 - for mm in generator.modules(): - if isinstance(mm, GlobalResponseNorm): - - train_norm.append(Null_Model()) - cnt_norm += 1 - - train_norm.append(generator.agg_net) - train_norm.append(generator.agg_net_up) - total = sum([ param.nelement() for param in train_norm.parameters()]) - print('Trainable parameter', total / 1048576) - - if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')): - sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu') - collect_sd = {} - for k, v in sdd.items(): - collect_sd[k[7:]] = v - train_norm.load_state_dict(collect_sd, strict=True) - - - train_norm.to(self.device).train().requires_grad_(True) - train_norm_ema = copy.deepcopy(train_norm) - train_norm_ema.to(self.device).eval().requires_grad_(False) - if generator_ema is not None: - - generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) - generator_ema._init_extra_parameter() - - - pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') - if os.path.exists(pretrained_pth): - print(pretrained_pth, 'exists') - generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) - - - generator_ema.eval().requires_grad_(False) - - - - - check_nan_inmodel(generator, 'generator') - - - - if self.config.use_ddp and self.config.training: - - train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True) - - # CLIP encoders - tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) - text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) - image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) - - return self.Models( - effnet=effnet, previewer=previewer, train_norm = train_norm, - generator=generator, tokenizer=tokenizer, text_model=text_model, image_model=image_model, - ) - - def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - - - params = [] - params += list(models.train_norm.module.parameters()) - - optimizer = optim.AdamW(params, lr=self.config.lr) - - return self.Optimizers(generator=optimizer) - - def ema_update(self, ema_model, source_model, beta): - for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): - param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) - - def sync_ema(self, ema_model): - for param in ema_model.parameters(): - torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) - param.data /= torch.distributed.get_world_size() - def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - - - optimizer = optim.AdamW( - models.generator.up_blocks.parameters() , - lr=self.config.lr) - optimizer = self.load_optimizer(optimizer, 'generator_optim', - fsdp_model=models.generator if self.config.use_fsdp else None) - return self.Optimizers(generator=optimizer) - - def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: - scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) - scheduler.last_epoch = self.info.total_steps - return self.Schedulers(generator=scheduler) - - def setup_data(self, extras: Extras) -> WarpCore.Data: - # SETUP DATASET - dataset_path = self.config.webdataset_path - dataset = mydist_dataset(dataset_path, \ - torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ - else extras.transforms) - - # SETUP DATALOADER - real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) - - sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True) - dataloader = DataLoader( - dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True, - collate_fn=identity if self.config.multi_aspect_ratio is not None else None, - sampler = sampler - ) - if self.is_main_node: - print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") - - if self.config.multi_aspect_ratio is not None: - aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] - dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, - ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, - interpolate_nearest=False) # , use_smartcrop=True) - else: - - dataloader_iterator = iter(dataloader) - - return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) - - - def models_to_save(self): - pass - def setup_ddp(self, experiment_id, single_gpu=False, rank=0): - - if not single_gpu: - local_rank = rank - process_id = rank - world_size = get_world_size() - - self.process_id = process_id - self.is_main_node = process_id == 0 - self.device = torch.device(local_rank) - self.world_size = world_size - - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '41443' - torch.cuda.set_device(local_rank) - init_process_group( - backend="nccl", - rank=local_rank, - world_size=world_size, - ) - print(f"[GPU {process_id}] READY") - else: - self.is_main_node = rank == 0 - self.process_id = rank - self.device = torch.device('cuda:0') - self.world_size = 1 - print("Running in single thread, DDP not enabled.") - # Training loop -------------------------------- - def get_target_lr_size(self, ratio, std_size=24): - w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) - return (h * 32 , w * 32) - def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): - #batch = next(data.iterator) - batch = data - ratio = batch['images'].shape[-2] / batch['images'].shape[-1] - shape_lr = self.get_target_lr_size(ratio) - #print('in line 485', shape_lr, ratio, batch['images'].shape) - with torch.no_grad(): - conditions = self.get_conditions(batch, models, extras) - - latents = self.encode_latents(batch, models, extras) - latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) - - noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) - noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - # 768 1536 - require_cond = True - - with torch.no_grad(): - _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) - - - pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions) - loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) - - loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps - - - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - extras.gdf.loss_weight.update_buckets(logSNR, loss) - - return loss, loss_adjusted - - def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): - - - if update: - - torch.distributed.barrier() - loss_adjusted.backward() - - grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0) - - optimizers_dict = optimizers.to_dict() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].step() - schedulers_dict = schedulers.to_dict() - for k in schedulers_dict: - if k != 'training': - schedulers_dict[k].step() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].zero_grad(set_to_none=True) - self.info.total_steps += 1 - else: - - loss_adjusted.backward() - - grad_norm = torch.tensor(0.0).to(self.device) - - return grad_norm - - - def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: - - images = batch['images'].to(self.device) - if target_size is not None: - images = F.interpolate(images, target_size) - - return models.effnet(extras.effnet_preprocess(images)) - - def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - return models.previewer(latents) - - def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): - - self.is_main_node = (rank == 0) - self.config: self.Config = self.setup_config(config_file_path, config_dict, training) - self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) - self.info: self.Info = self.setup_info() - - - - def __call__(self, single_gpu=False): - - if self.config.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - if self.is_main_node: - print() - print("**STARTIG JOB WITH CONFIG:**") - print(yaml.dump(self.config.to_dict(), default_flow_style=False)) - print("------------------------------------") - print() - print("**INFO:**") - print(yaml.dump(vars(self.info), default_flow_style=False)) - print("------------------------------------") - print() - - # SETUP STUFF - extras = self.setup_extras_pre() - assert extras is not None, "setup_extras_pre() must return a DTO" - - - - data = self.setup_data(extras) - assert data is not None, "setup_data() must return a DTO" - if self.is_main_node: - print("**DATA:**") - print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - models = self.setup_models(extras) - assert models is not None, "setup_models() must return a DTO" - if self.is_main_node: - print("**MODELS:**") - print(yaml.dump({ - k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() - }, default_flow_style=False)) - print("------------------------------------") - print() - - - - optimizers = self.setup_optimizers(extras, models) - assert optimizers is not None, "setup_optimizers() must return a DTO" - if self.is_main_node: - print("**OPTIMIZERS:**") - print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - schedulers = self.setup_schedulers(extras, models, optimizers) - assert schedulers is not None, "setup_schedulers() must return a DTO" - if self.is_main_node: - print("**SCHEDULERS:**") - print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) - assert post_extras is not None, "setup_extras_post() must return a DTO" - extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) - if self.is_main_node: - print("**EXTRAS:**") - print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - # ------- - - # TRAIN - if self.is_main_node: - print("**TRAINING STARTING...**") - self.train(data, extras, models, optimizers, schedulers) - - if single_gpu is False: - barrier() - destroy_process_group() - if self.is_main_node: - print() - print("------------------------------------") - print() - print("**TRAINING COMPLETE**") - - - - def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers, - schedulers: WarpCore.Schedulers): - start_iter = self.info.iter + 1 - max_iters = self.config.updates * self.config.grad_accum_steps - if self.is_main_node: - print(f"STARTING AT STEP: {start_iter}/{max_iters}") - - - if self.is_main_node: - create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') - - models.generator.train() - - iter_cnt = 0 - epoch_cnt = 0 - models.train_norm.train() - while True: - epoch_cnt += 1 - if self.world_size > 1: - - data.sampler.set_epoch(epoch_cnt) - for ggg in range(len(data.dataloader)): - iter_cnt += 1 - loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models) - grad_norm = self.backward_pass( - iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted, - models, optimizers, schedulers - ) - - self.info.iter = iter_cnt - - - # UPDATE LOSS METRICS - self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 - - #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss) - if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): - print(f" NaN value encountered in training run {self.info.wandb_run_id}", \ - f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") - - if self.is_main_node: - logs = { - 'loss': self.info.ema_loss, - 'backward_loss': loss_adjusted.mean().item(), - 'ema_loss': self.info.ema_loss, - 'raw_ori_loss': loss.mean().item(), - 'grad_norm': grad_norm.item(), - 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, - 'total_steps': self.info.total_steps, - } - if iter_cnt % (self.config.save_every) == 0: - - print(iter_cnt, max_iters, logs, epoch_cnt, ) - - - - if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters: - - # SAVE AND CHECKPOINT STUFF - if np.isnan(loss.mean().item()): - if self.is_main_node and self.config.wandb_project is not None: - print(f"NaN value encountered in training run {self.info.wandb_run_id}", \ - f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") - - else: - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - self.info.adaptive_loss = { - 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), - 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), - } - - - - if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0: - print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps ) - torch.save(models.train_norm.state_dict(), \ - f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors') - - torch.save(models.train_norm.state_dict(), \ - f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors') - - - if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters: - - if self.is_main_node: - - self.sample(models, data, extras) - - - if self.info.iter >= max_iters: - break - - def sample(self, models: Models, data: WarpCore.Data, extras: Extras): - - - models.generator.eval() - models.train_norm.eval() - with torch.no_grad(): - batch = next(data.iterator) - ratio = batch['images'].shape[-2] / batch['images'].shape[-1] - - shape_lr = self.get_target_lr_size(ratio) - conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - latents = self.encode_latents(batch, models, extras) - latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr) - - - if self.is_main_node: - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - - *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( - models.generator, conditions, - latents.shape, latents_lr.shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - - - - if self.is_main_node: - print('sampling results hr latent shape', latents.shape, 'lr latent shape', latents_lr.shape, ) - noised_images = torch.cat( - [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0) - - sampled_images = torch.cat( - [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0) - - - noised_images_lr = torch.cat( - [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0) - - sampled_images_lr = torch.cat( - [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0) - - images = batch['images'] - if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): - images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') - images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic') - - collage_img = torch.cat([ - torch.cat([i for i in images.cpu()], dim=-1), - torch.cat([i for i in noised_images.cpu()], dim=-1), - torch.cat([i for i in sampled_images.cpu()], dim=-1), - ], dim=-2) - - collage_img_lr = torch.cat([ - torch.cat([i for i in images_lr.cpu()], dim=-1), - torch.cat([i for i in noised_images_lr.cpu()], dim=-1), - torch.cat([i for i in sampled_images_lr.cpu()], dim=-1), - ], dim=-2) - - torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') - torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg') - - - models.generator.train() - models.train_norm.train() - print('finish sampling') - - - - def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False): - - - models.generator.eval() - - with torch.no_grad(): - - if self.is_main_node: - conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds) - unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - - *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( - models.generator, conditions, - hr_shape, lr_shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - if models.generator_ema is not None: - - *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample( - models.generator_ema, conditions, - latents.shape, latents_lr.shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - else: - sampled_ema = sampled - sampled_ema_lr = sampled_lr - - return sampled, sampled_lr -def main_worker(rank, cfg): - print("Launching Script in main worker") - - warpcore = WurstCore( - config_file_path=cfg, rank=rank, world_size = get_world_size() - ) - # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD - - # RUN TRAINING - warpcore(get_world_size()==1) - -if __name__ == '__main__': - print('launch multi process') - # os.environ["OMP_NUM_THREADS"] = "1" - # os.environ["MKL_NUM_THREADS"] = "1" - #dist.init_process_group(backend="nccl") - #torch.backends.cudnn.benchmark = True -#train/train_c_my.py - #mp.set_sharing_strategy('file_system') - - if get_master_ip() == "127.0.0.1": - # manually launch distributed processes - mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, )) - else: - main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, ) diff --git a/train/train_ultrapixel_control.py b/train/train_ultrapixel_control.py deleted file mode 100644 index cd67965973a85ed1d72c164dd0e8970f8b5ce277..0000000000000000000000000000000000000000 --- a/train/train_ultrapixel_control.py +++ /dev/null @@ -1,928 +0,0 @@ -import torch -import json -import yaml -import torchvision -from torch import nn, optim -from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection -from warmup_scheduler import GradualWarmupScheduler -import torch.multiprocessing as mp -import numpy as np -import sys - -import os -from dataclasses import dataclass -from torch.distributed import init_process_group, destroy_process_group, barrier -from gdf import GDF_dual_fixlrt as GDF -from gdf import EpsilonTarget, CosineSchedule -from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight -from torchtools.transforms import SmartCrop -from fractions import Fraction -from modules.effnet import EfficientNetEncoder - -from modules.model_4stage_lite import StageC - -from modules.model_4stage_lite import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock -from modules.common_ckpt import GlobalResponseNorm -from modules.previewer import Previewer -from core.data import Bucketeer -from train.base import DataCore, TrainingCore -from tqdm import tqdm -from core import WarpCore -from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail -from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy -from accelerate import init_empty_weights -from accelerate.utils import set_module_tensor_to_device -from contextlib import contextmanager -from train.dist_core import * -import glob -from torch.utils.data import DataLoader, Dataset -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data.distributed import DistributedSampler -from PIL import Image -from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary -from core.utils import Base -from modules.common import LayerNorm2d -import torch.nn.functional as F -import functools -import math -import copy -import random -from modules.lora import apply_lora, apply_retoken, LoRA, ReToken -from modules import ControlNet, ControlNetDeliverer -from modules import controlnet_filters - -Image.MAX_IMAGE_PIXELS = None -torch.manual_seed(8432) -random.seed(8432) -np.random.seed(8432) -#7978026 - -class Null_Model(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x): - pass - - -def identity(x): - if isinstance(x, bytes): - x = x.decode('utf-8') - return x -def check_nan_inmodel(model, meta=''): - for name, param in model.named_parameters(): - if torch.isnan(param).any(): - print(f"nan detected in {name}", meta) - return True - print('no nan', meta) - return False - - -class WurstCore(TrainingCore, DataCore, WarpCore): - @dataclass(frozen=True) - class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config): - # TRAINING PARAMS - lr: float = EXPECTED_TRAIN - warmup_updates: int = EXPECTED_TRAIN - dtype: str = None - - # MODEL VERSION - model_version: str = EXPECTED # 3.6B or 1B - clip_image_model_name: str = 'openai/clip-vit-large-patch14' - clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' - - # CHECKPOINT PATHS - effnet_checkpoint_path: str = EXPECTED - previewer_checkpoint_path: str = EXPECTED - #trans_inr_ckpt: str = EXPECTED - generator_checkpoint_path: str = None - controlnet_checkpoint_path: str = EXPECTED - - # controlnet settings - controlnet_blocks: list = EXPECTED - controlnet_filter: str = EXPECTED - controlnet_filter_params: dict = None - controlnet_bottleneck_mode: str = None - - - # gdf customization - adaptive_loss_weight: str = None - - #module_filters: list = EXPECTED - #rank: int = EXPECTED - @dataclass(frozen=True) - class Data(Base): - dataset: Dataset = EXPECTED - dataloader: DataLoader = EXPECTED - iterator: any = EXPECTED - sampler: DistributedSampler = EXPECTED - - @dataclass(frozen=True) - class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models): - effnet: nn.Module = EXPECTED - previewer: nn.Module = EXPECTED - train_norm: nn.Module = EXPECTED - train_norm_ema: nn.Module = EXPECTED - controlnet: nn.Module = EXPECTED - - @dataclass(frozen=True) - class Schedulers(WarpCore.Schedulers): - generator: any = None - - @dataclass(frozen=True) - class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras): - gdf: GDF = EXPECTED - sampling_configs: dict = EXPECTED - effnet_preprocess: torchvision.transforms.Compose = EXPECTED - controlnet_filter: controlnet_filters.BaseFilter = EXPECTED - - info: TrainingCore.Info - config: Config - - def setup_extras_pre(self) -> Extras: - gdf = GDF( - schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]), - input_scaler=VPScaler(), target=EpsilonTarget(), - noise_cond=CosineTNoiseCond(), - loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(), - ) - sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20} - - if self.info.adaptive_loss is not None: - gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) - gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) - - effnet_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) - ]) - - clip_preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711) - ) - ]) - - if self.config.training: - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True), - SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2) - ]) - else: - transforms = None - controlnet_filter = getattr(controlnet_filters, self.config.controlnet_filter)( - self.device, - **(self.config.controlnet_filter_params if self.config.controlnet_filter_params is not None else {}) - ) - - return self.Extras( - gdf=gdf, - sampling_configs=sampling_configs, - transforms=transforms, - effnet_preprocess=effnet_preprocess, - clip_preprocess=clip_preprocess, - controlnet_filter=controlnet_filter - ) - def get_cnet(self, batch: dict, models: Models, extras: Extras, cnet_input=None, target_size=None, **kwargs): - images = batch['images'] - if target_size is not None: - images = Image.resize(images, target_size) - with torch.no_grad(): - if cnet_input is None: - cnet_input = extras.controlnet_filter(images, **kwargs) - if isinstance(cnet_input, tuple): - cnet_input, cnet_input_preview = cnet_input - else: - cnet_input_preview = cnet_input - cnet_input, cnet_input_preview = cnet_input.to(self.device), cnet_input_preview.to(self.device) - cnet = models.controlnet(cnet_input) - return cnet, cnet_input_preview - - def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False, - eval_image_embeds=False, return_fields=None): - conditions = super().get_conditions( - batch, models, extras, is_eval, is_unconditional, - eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img'] - ) - return conditions - - def setup_models(self, extras: Extras) -> Models: # configure model - - - dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16 - - # EfficientNet encoderin - effnet = EfficientNetEncoder() - effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path) - effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict']) - effnet.eval().requires_grad_(False).to(self.device) - del effnet_checkpoint - - # Previewer - previewer = Previewer() - previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path) - previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict']) - previewer.eval().requires_grad_(False).to(self.device) - del previewer_checkpoint - - @contextmanager - def dummy_context(): - yield None - - loading_context = dummy_context if self.config.training else init_empty_weights - - # Diffusion models - with loading_context(): - generator_ema = None - if self.config.model_version == '3.6B': - generator = StageC() - if self.config.ema_start_iters is not None: # default setting - generator_ema = StageC() - elif self.config.model_version == '1B': - - generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - - if self.config.ema_start_iters is not None and self.config.training: - generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]]) - else: - raise ValueError(f"Unknown model version {self.config.model_version}") - - - - if loading_context is dummy_context: - generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path)) - else: - for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items(): - set_module_tensor_to_device(generator, param_name, "cpu", value=param) - - generator._init_extra_parameter() - - - - - generator = generator.to(torch.bfloat16).to(self.device) - - train_norm = nn.ModuleList() - - - cnt_norm = 0 - for mm in generator.modules(): - if isinstance(mm, GlobalResponseNorm): - - train_norm.append(Null_Model()) - cnt_norm += 1 - - - - - train_norm.append(generator.agg_net) - train_norm.append(generator.agg_net_up) - - - - - if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')): - sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu') - collect_sd = {} - for k, v in sdd.items(): - collect_sd[k[7:]] = v - train_norm.load_state_dict(collect_sd, strict=True) - - - train_norm.to(self.device).train().requires_grad_(True) - train_norm_ema = copy.deepcopy(train_norm) - train_norm_ema.to(self.device).eval().requires_grad_(False) - if generator_ema is not None: - - generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path)) - generator_ema._init_extra_parameter() - - pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors') - if os.path.exists(pretrained_pth): - print(pretrained_pth, 'exists') - generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu')) - - generator_ema.eval().requires_grad_(False) - - check_nan_inmodel(generator, 'generator') - - - - if self.config.use_fsdp and self.config.training: - train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True) - - - # CLIP encoders - tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name) - text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device) - image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device) - - controlnet = ControlNet( - c_in=extras.controlnet_filter.num_channels(), - proj_blocks=self.config.controlnet_blocks, - bottleneck_mode=self.config.controlnet_bottleneck_mode - ) - controlnet = controlnet.to(dtype).to(self.device) - controlnet = self.load_model(controlnet, 'controlnet') - controlnet.backbone.eval().requires_grad_(True) - - - return self.Models( - effnet=effnet, previewer=previewer, train_norm = train_norm, - generator=generator, generator_ema=generator_ema, - tokenizer=tokenizer, text_model=text_model, image_model=image_model, - train_norm_ema=train_norm_ema, controlnet =controlnet - ) - - def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - -# - - params = [] - params += list(models.train_norm.module.parameters()) - - optimizer = optim.AdamW(params, lr=self.config.lr) - - return self.Optimizers(generator=optimizer) - - def ema_update(self, ema_model, source_model, beta): - for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()): - param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta) - - def sync_ema(self, ema_model): - print('sync ema', torch.distributed.get_world_size()) - for param in ema_model.parameters(): - torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM) - param.data /= torch.distributed.get_world_size() - def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers: - - - optimizer = optim.AdamW( - models.generator.up_blocks.parameters() , - lr=self.config.lr) - optimizer = self.load_optimizer(optimizer, 'generator_optim', - fsdp_model=models.generator if self.config.use_fsdp else None) - return self.Optimizers(generator=optimizer) - - def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers: - scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates) - scheduler.last_epoch = self.info.total_steps - return self.Schedulers(generator=scheduler) - - def setup_data(self, extras: Extras) -> WarpCore.Data: - # SETUP DATASET - dataset_path = self.config.webdataset_path - print('in line 96', dataset_path, type(dataset_path)) - - dataset = mydist_dataset(dataset_path, \ - torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \ - else extras.transforms) - - # SETUP DATALOADER - real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps) - print('in line 119', self.process_id, real_batch_size) - sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True) - dataloader = DataLoader( - dataset, batch_size=real_batch_size, num_workers=4, pin_memory=True, - collate_fn=identity if self.config.multi_aspect_ratio is not None else None, - sampler = sampler - ) - if self.is_main_node: - print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)") - - if self.config.multi_aspect_ratio is not None: - aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio] - dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32, - ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio, - interpolate_nearest=False) # , use_smartcrop=True) - else: - - dataloader_iterator = iter(dataloader) - - return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler) - - - - - - def setup_ddp(self, experiment_id, single_gpu=False, rank=0): - - if not single_gpu: - local_rank = rank - process_id = rank - world_size = get_world_size() - - self.process_id = process_id - self.is_main_node = process_id == 0 - self.device = torch.device(local_rank) - self.world_size = world_size - - - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '41443' - torch.cuda.set_device(local_rank) - init_process_group( - backend="nccl", - rank=local_rank, - world_size=world_size, - # init_method=init_method, - ) - print(f"[GPU {process_id}] READY") - else: - self.is_main_node = rank == 0 - self.process_id = rank - self.device = torch.device('cuda:0') - self.world_size = 1 - print("Running in single thread, DDP not enabled.") - # Training loop -------------------------------- - def get_target_lr_size(self, ratio, std_size=24): - w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) - return (h * 32 , w * 32) - def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models): - #batch = next(data.iterator) - batch = data - ratio = batch['images'].shape[-2] / batch['images'].shape[-1] - shape_lr = self.get_target_lr_size(ratio) - - with torch.no_grad(): - conditions = self.get_conditions(batch, models, extras) - - latents = self.encode_latents(batch, models, extras) - latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr) - - noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1) - noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, ) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - - require_cond = True - - with torch.no_grad(): - _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions) - - - pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions) - loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) - - loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps - # - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - extras.gdf.loss_weight.update_buckets(logSNR, loss) - - return loss, loss_adjusted - - def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers): - - if update: - - torch.distributed.barrier() - loss_adjusted.backward() - - - grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0) - - optimizers_dict = optimizers.to_dict() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].step() - schedulers_dict = schedulers.to_dict() - for k in schedulers_dict: - if k != 'training': - schedulers_dict[k].step() - for k in optimizers_dict: - if k != 'training': - optimizers_dict[k].zero_grad(set_to_none=True) - self.info.total_steps += 1 - else: - #print('in line 457', loss_adjusted) - loss_adjusted.backward() - #torch.distributed.barrier() - grad_norm = torch.tensor(0.0).to(self.device) - - return grad_norm - - def models_to_save(self): - return ['generator', 'generator_ema', 'trans_inr', 'trans_inr_ema'] - - def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor: - - images = batch['images'].to(self.device) - if target_size is not None: - images = F.interpolate(images, target_size) - #images = apply_degradations(images) - return models.effnet(extras.effnet_preprocess(images)) - - def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor: - return models.previewer(latents) - - def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ): - # Temporary setup, will be overriden by setup_ddp if required - # self.device = device - # self.process_id = 0 - # self.is_main_node = True - # self.world_size = 1 - # ---- - # self.world_size = world_size - # self.process_id = rank - # self.device=device - self.is_main_node = (rank == 0) - self.config: self.Config = self.setup_config(config_file_path, config_dict, training) - self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank) - self.info: self.Info = self.setup_info() - print('in line 292', self.config.experiment_id, rank, world_size <= 1) - p = [i for i in range( 2 * 768 // 32)] - p = [num / sum(p) for num in p] - self.rand_pro = p - self.res_list = [o for o in range(800, 2336, 32)] - - #[32, 40, 48] - #in line 292 stage_c_3b_finetuning False - - def __call__(self, single_gpu=False): - # this will change the device to the CUDA rank - #self.setup_wandb() - if self.config.allow_tf32: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - if self.is_main_node: - print() - print("**STARTIG JOB WITH CONFIG:**") - print(yaml.dump(self.config.to_dict(), default_flow_style=False)) - print("------------------------------------") - print() - print("**INFO:**") - print(yaml.dump(vars(self.info), default_flow_style=False)) - print("------------------------------------") - print() - print('in line 308', self.is_main_node, self.is_main_node, self.process_id, self.device ) - # SETUP STUFF - extras = self.setup_extras_pre() - assert extras is not None, "setup_extras_pre() must return a DTO" - - - - data = self.setup_data(extras) - assert data is not None, "setup_data() must return a DTO" - if self.is_main_node: - print("**DATA:**") - print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - models = self.setup_models(extras) - assert models is not None, "setup_models() must return a DTO" - if self.is_main_node: - print("**MODELS:**") - print(yaml.dump({ - k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items() - }, default_flow_style=False)) - print("------------------------------------") - print() - - - - optimizers = self.setup_optimizers(extras, models) - assert optimizers is not None, "setup_optimizers() must return a DTO" - if self.is_main_node: - print("**OPTIMIZERS:**") - print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - schedulers = self.setup_schedulers(extras, models, optimizers) - assert schedulers is not None, "setup_schedulers() must return a DTO" - if self.is_main_node: - print("**SCHEDULERS:**") - print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - - post_extras =self.setup_extras_post(extras, models, optimizers, schedulers) - assert post_extras is not None, "setup_extras_post() must return a DTO" - extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() }) - if self.is_main_node: - print("**EXTRAS:**") - print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False)) - print("------------------------------------") - print() - # ------- - - # TRAIN - if self.is_main_node: - print("**TRAINING STARTING...**") - self.train(data, extras, models, optimizers, schedulers) - - if single_gpu is False: - barrier() - destroy_process_group() - if self.is_main_node: - print() - print("------------------------------------") - print() - print("**TRAINING COMPLETE**") - if self.config.wandb_project is not None: - wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished") - - - def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers, - schedulers: WarpCore.Schedulers): - start_iter = self.info.iter + 1 - max_iters = self.config.updates * self.config.grad_accum_steps - if self.is_main_node: - print(f"STARTING AT STEP: {start_iter}/{max_iters}") - - - if self.is_main_node: - create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/') - if 'generator' in self.models_to_save(): - models.generator.train() - #initial_params = {name: param.clone() for name, param in models.train_norm.named_parameters()} - iter_cnt = 0 - epoch_cnt = 0 - models.train_norm.train() - while True: - epoch_cnt += 1 - if self.world_size > 1: - print('sampler set epoch', epoch_cnt) - data.sampler.set_epoch(epoch_cnt) - for ggg in range(len(data.dataloader)): - iter_cnt += 1 - # FORWARD PASS - #print('in line 414 before forward', iter_cnt, batch['captions'][0], self.process_id) - #loss, loss_adjusted, loss_extra = self.forward_pass(batch, extras, models) - loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models) - - #print('in line 416', loss, iter_cnt) - # # BACKWARD PASS - - grad_norm = self.backward_pass( - iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted, - models, optimizers, schedulers - ) - - - - self.info.iter = iter_cnt - - # UPDATE EMA - if iter_cnt % self.config.ema_iters == 0: - - with torch.no_grad(): - print('in line 890 ema update', self.config.ema_iters, iter_cnt) - self.ema_update(models.train_norm_ema, models.train_norm, self.config.ema_beta) - #generator.module.agg_net. - #self.ema_update(models.generator_ema.agg_net, models.generator.module.agg_net, self.config.ema_beta) - #self.ema_update(models.generator_ema.agg_net_up, models.generator.module.agg_net_up, self.config.ema_beta) - - # UPDATE LOSS METRICS - self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01 - - #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss) - if self.is_main_node and np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()): - print(f"gggg NaN value encountered in training run {self.info.wandb_run_id}", \ - f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") - - if self.is_main_node: - logs = { - 'loss': self.info.ema_loss, - 'backward_loss': loss_adjusted.mean().item(), - #'raw_extra_loss': loss_extra.mean().item(), - 'ema_loss': self.info.ema_loss, - 'raw_ori_loss': loss.mean().item(), - #'raw_rec_loss': loss_rec.mean().item(), - #'raw_lr_loss': loss_lr.mean().item(), - #'reg_loss':loss_reg.item(), - 'grad_norm': grad_norm.item(), - 'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0, - 'total_steps': self.info.total_steps, - } - if iter_cnt % (self.config.save_every) == 0: - - print(iter_cnt, max_iters, logs, epoch_cnt, ) - #pbar.set_postfix(logs) - - - #if iter_cnt % 10 == 0: - - - if iter_cnt == 1 or iter_cnt % (self.config.save_every ) == 0 or iter_cnt == max_iters: - #if True: - # SAVE AND CHECKPOINT STUFF - if np.isnan(loss.mean().item()): - if self.is_main_node and self.config.wandb_project is not None: - print(f"NaN value encountered in training run {self.info.wandb_run_id}", \ - f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}") - - else: - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - self.info.adaptive_loss = { - 'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(), - 'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(), - } - #self.save_checkpoints(models, optimizers) - - #torch.save(models.trans_inr.module.state_dict(), \ - #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr.safetensors') - #torch.save(models.trans_inr_ema.state_dict(), \ - #f'{self.config.output_path}/{self.config.experiment_id}/trans_inr_ema.safetensors') - - - if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0: - print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps ) - torch.save(models.train_norm.state_dict(), \ - f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors') - - #self.sync_ema(models.train_norm_ema) - torch.save(models.train_norm_ema.state_dict(), \ - f'{self.config.output_path}/{self.config.experiment_id}/train_norm_ema.safetensors') - #if self.is_main_node and iter_cnt % (4 * self.config.save_every * self.config.grad_accum_steps) == 0: - torch.save(models.train_norm.state_dict(), \ - f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors') - - - if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters: - - if self.is_main_node: - #check_nan_inmodel(models.generator, 'generator') - #check_nan_inmodel(models.generator_ema, 'generator_ema') - self.sample(models, data, extras) - if False: - param_changes = {name: (param - initial_params[name]).norm().item() for name, param in models.train_norm.named_parameters()} - threshold = sorted(param_changes.values(), reverse=True)[int(len(param_changes) * 0.1)] # top 10% - important_params = [name for name, change in param_changes.items() if change > threshold] - print(important_params, threshold, len(param_changes), self.process_id) - json.dump(important_params, open(f'{self.config.output_path}/{self.config.experiment_id}/param.json', 'w'), indent=4) - - - if self.info.iter >= max_iters: - break - - def sample(self, models: Models, data: WarpCore.Data, extras: Extras): - - #if 'generator' in self.models_to_save(): - models.generator.eval() - models.train_norm.eval() - with torch.no_grad(): - batch = next(data.iterator) - ratio = batch['images'].shape[-2] / batch['images'].shape[-1] - #batch['images'] = batch['images'].to(torch.float16) - shape_lr = self.get_target_lr_size(ratio) - conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False) - unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - cnet, cnet_input = self.get_cnet(batch, models, extras) - conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet} - - latents = self.encode_latents(batch, models, extras) - latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr) - - if self.is_main_node: - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - #print('in line 366 on v100 switch to tf16') - *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( - models.generator, models.trans_inr, conditions, - latents.shape, latents_lr.shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - - - #else: - sampled_ema = sampled - sampled_ema_lr = sampled_lr - - - if self.is_main_node: - print('sampling results', latents.shape, latents_lr.shape, ) - noised_images = torch.cat( - [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0) - - sampled_images = torch.cat( - [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0) - sampled_images_ema = torch.cat( - [self.decode_latents(sampled_ema[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema))], - dim=0) - - noised_images_lr = torch.cat( - [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0) - - sampled_images_lr = torch.cat( - [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0) - sampled_images_ema_lr = torch.cat( - [self.decode_latents(sampled_ema_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_ema_lr))], - dim=0) - - images = batch['images'] - if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2): - images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic') - images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic') - - collage_img = torch.cat([ - torch.cat([i for i in images.cpu()], dim=-1), - torch.cat([i for i in noised_images.cpu()], dim=-1), - torch.cat([i for i in sampled_images.cpu()], dim=-1), - torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), - ], dim=-2) - - collage_img_lr = torch.cat([ - torch.cat([i for i in images_lr.cpu()], dim=-1), - torch.cat([i for i in noised_images_lr.cpu()], dim=-1), - torch.cat([i for i in sampled_images_lr.cpu()], dim=-1), - torch.cat([i for i in sampled_images_ema_lr.cpu()], dim=-1), - ], dim=-2) - - torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg') - torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg') - #torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg') - - captions = batch['captions'] - if self.config.wandb_project is not None: - log_data = [ - [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [ - wandb.Image(images[i])] for i in range(len(images))] - log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"]) - wandb.log({"Log": log_table}) - - if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight): - plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1]) - plt.ylabel('Raw Loss') - plt.ylabel('LogSNR') - wandb.log({"Loss/LogSRN": plt}) - - #if 'generator' in self.models_to_save(): - models.generator.train() - models.train_norm.train() - print('finishe sampling in line 901') - - - - def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False): - - #if 'generator' in self.models_to_save(): - models.generator.eval() - models.trans_inr.eval() - models.controlnet.eval() - with torch.no_grad(): - - if self.is_main_node: - conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds) - unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) - cnet, cnet_input = self.get_cnet(batch, models, extras, target_size = lr_shape) - conditions, unconditions = {**conditions, 'cnet': cnet}, {**unconditions, 'cnet': cnet} - - #print('in line 885', self.is_main_node) - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - #print('in line 366 on v100 switch to tf16') - *_, (sampled, _, _, sampled_lr) = extras.gdf.sample( - models.generator, models.trans_inr, conditions, - hr_shape, lr_shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - if models.generator_ema is not None: - - *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample( - models.generator_ema, models.trans_inr_ema, conditions, - latents.shape, latents_lr.shape, - unconditions, device=self.device, **extras.sampling_configs - ) - - else: - sampled_ema = sampled - sampled_ema_lr = sampled_lr - #x0, x, epsilon, x0_lr, x_lr, pred_lr) - #sampled, _ = models.trans_inr(sampled, None, sampled) - #sampled_lr, _ = models.trans_inr(sampled, None, sampled_lr) - - return sampled, sampled_lr -def main_worker(rank, cfg): - print("Launching Script in main worker") - print('in line 467', rank) - warpcore = WurstCore( - config_file_path=cfg, rank=rank, world_size = get_world_size() - ) - # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD - - # RUN TRAINING - warpcore(get_world_size()==1) - -if __name__ == '__main__': - print('launch multi process') - # os.environ["OMP_NUM_THREADS"] = "1" - # os.environ["MKL_NUM_THREADS"] = "1" - #dist.init_process_group(backend="nccl") - #torch.backends.cudnn.benchmark = True -#train/train_c_my.py - #mp.set_sharing_strategy('file_system') - print('in line 481', sys.argv[1] if len(sys.argv) > 1 else None) - print('in line 481',get_master_ip(), get_world_size() ) - print('in line 484', get_world_size()) - if get_master_ip() == "127.0.0.1": - # manually launch distributed processes - mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, )) - else: - main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )