alexnasa committed on
Commit 15fa7b9 · verified · 1 Parent(s): f622070

Upload 6 files

Files changed (6)
  1. cog.yaml +24 -0
  2. predict.py +202 -0
  3. requirements.txt +10 -0
  4. test_seesr.py +271 -0
  5. test_seesr_turbo.py +271 -0
  6. train_seesr.py +1093 -0
cog.yaml ADDED
@@ -0,0 +1,24 @@
+ # Configuration for Cog ⚙️
+ # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+ build:
+   gpu: true
+   python_version: "3.8"
+   python_packages:
+     - "accelerate==0.25.0"
+     - "diffusers==0.21.0"
+     - "torch==2.0.1"
+     - "pytorch_lightning==2.1.3"
+     - "transformers==4.25.0"
+     - "xformers"
+     - "loralib==0.1.2"
+     - "fairscale==0.4.13"
+     - "opencv-python==4.9.0.80"
+     - "chardet==5.2.0"
+     - "einops==0.7.0"
+     - "scipy==1.10.1"
+     - "timm==0.9.12"
+
+   run:
+     - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.3.1/pget" && chmod +x /usr/local/bin/pget
+ # predict.py defines how predictions are run on your model
+ predict: "predict.py:Predictor"
predict.py ADDED
@@ -0,0 +1,202 @@
+ # Prediction interface for Cog ⚙️
+ # https://github.com/replicate/cog/blob/main/docs/python.md
+ from cog import BasePredictor, Input, Path
+ import os
+ import time
+ import subprocess
+ from typing import List
+
+ import numpy as np
+ from PIL import Image
+
+ import torch
+ import torch.utils.checkpoint
+ from pytorch_lightning import seed_everything
+ from diffusers import AutoencoderKL, DDPMScheduler
+ from diffusers.utils.import_utils import is_xformers_available
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
+
+ from pipelines.pipeline_seesr import StableDiffusionControlNetPipeline
+
+ from utils.wavelet_color_fix import wavelet_color_fix
+
+ from ram.models.ram_lora import ram
+ from ram import inference_ram as inference
+ from torchvision import transforms
+ from models.controlnet import ControlNetModel
+ from models.unet_2d_condition import UNet2DConditionModel
+
+ MODEL_URL = "https://weights.replicate.delivery/default/stabilityai/sd-2-1-base.tar"
+
+ tensor_transforms = transforms.Compose([
+     transforms.ToTensor(),
+ ])
+ ram_transforms = transforms.Compose([
+     transforms.Resize((384, 384)),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ device = "cuda"
+
+ def download_weights(url, dest):
+     start = time.time()
+     print("downloading url: ", url)
+     print("downloading to: ", dest)
+     subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
+     print("downloading took: ", time.time() - start)
+
+ class Predictor(BasePredictor):
+     def setup(self) -> None:
+         """Load the model into memory to make running multiple predictions efficient"""
+         # Load scheduler, tokenizer and models.
+         pretrained_model_path = 'preset/models/stable-diffusion-2-1-base'
+         seesr_model_path = 'preset/models/seesr'
+
+         # Download SD-2-1 weights
+         if not os.path.exists(pretrained_model_path):
+             download_weights(MODEL_URL, pretrained_model_path)
+
+         scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+         text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+         tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+         vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+         feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")
+         unet = UNet2DConditionModel.from_pretrained(seesr_model_path, subfolder="unet")
+         controlnet = ControlNetModel.from_pretrained(seesr_model_path, subfolder="controlnet")
+
+         # Freeze vae and text_encoder
+         vae.requires_grad_(False)
+         text_encoder.requires_grad_(False)
+         unet.requires_grad_(False)
+         controlnet.requires_grad_(False)
+
+         if is_xformers_available():
+             unet.enable_xformers_memory_efficient_attention()
+             controlnet.enable_xformers_memory_efficient_attention()
+         else:
+             raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+         # Get the validation pipeline
+         validation_pipeline = StableDiffusionControlNetPipeline(
+             vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
+             unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
+         )
+         validation_pipeline._init_tiled_vae(encoder_tile_size=1024, decoder_tile_size=224)
+         self.validation_pipeline = validation_pipeline
+         weight_dtype = torch.float16
+
+         # Move text_encoder and vae to gpu and cast to weight_dtype
+         text_encoder.to(device, dtype=weight_dtype)
+         vae.to(device, dtype=weight_dtype)
+         unet.to(device, dtype=weight_dtype)
+         controlnet.to(device, dtype=weight_dtype)
+         tag_model = ram(pretrained='preset/models/ram_swin_large_14m.pth',
+                         pretrained_condition='preset/models/DAPE.pth',
+                         image_size=384,
+                         vit='swin_l')
+         tag_model.eval()
+         self.tag_model = tag_model.to(device, dtype=weight_dtype)
+
+
+     # @torch.no_grad()
+     def process(
+         self,
+         input_image: Image.Image,
+         user_prompt: str,
+         positive_prompt: str,
+         negative_prompt: str,
+         num_inference_steps: int,
+         scale_factor: int,
+         cfg_scale: float,
+         seed: int,
+         latent_tiled_size: int,
+         latent_tiled_overlap: int,
+         sample_times: int
+     ) -> List[np.ndarray]:
+         process_size = 512
+         resize_preproc = transforms.Compose([
+             transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
+         ])
+
+         seed_everything(seed)
+         generator = torch.Generator(device=device)
+
+         validation_prompt = ""
+         lq = tensor_transforms(input_image).unsqueeze(0).to(device).half()
+         lq = ram_transforms(lq)
+         res = inference(lq, self.tag_model)
+         ram_encoder_hidden_states = self.tag_model.generate_image_embeds(lq)
+         validation_prompt = f"{res[0]}, {positive_prompt},"
+         validation_prompt = validation_prompt if user_prompt=='' else f"{user_prompt}, {validation_prompt}"
+
+         ori_width, ori_height = input_image.size
+         resize_flag = False
+
+         rscale = scale_factor
+         input_image = input_image.resize((int(input_image.size[0] * rscale), int(input_image.size[1] * rscale)))
+
+         if min(input_image.size) < process_size:
+             input_image = resize_preproc(input_image)
+
+         input_image = input_image.resize((input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8))
+         width, height = input_image.size
+         resize_flag = True
+
+         images = []
+         for _ in range(sample_times):
+             try:
+                 with torch.autocast("cuda"):
+                     image = self.validation_pipeline(
+                         validation_prompt, input_image, negative_prompt=negative_prompt,
+                         num_inference_steps=num_inference_steps, generator=generator,
+                         height=height, width=width,
+                         guidance_scale=cfg_scale, conditioning_scale=1,
+                         start_point='lr', start_steps=999, ram_encoder_hidden_states=ram_encoder_hidden_states,
+                         latent_tiled_size=latent_tiled_size, latent_tiled_overlap=latent_tiled_overlap
+                     ).images[0]
+
+                 if True:  # alpha<1.0:
+                     image = wavelet_color_fix(image, input_image)
+
+                 if resize_flag:
+                     image = image.resize((ori_width * rscale, ori_height * rscale))
+             except Exception as e:
+                 print(e)
+                 image = Image.new(mode="RGB", size=(512, 512))
+             images.append(np.array(image))
+         return images
+
+
+     @torch.inference_mode()
+     def predict(
+         self,
+         image: Path = Input(description="Input image"),
+         user_prompt: str = Input(description="Prompt to condition on", default=""),
+         positive_prompt: str = Input(description="Prompt to add", default="clean, high-resolution, 8k"),
+         negative_prompt: str = Input(description="Prompt to remove", default="dotted, noise, blur, lowres, smooth"),
+         cfg_scale: float = Input(description="Guidance scale, set value to >1 to use", default=5.5, ge=0.1, le=10.0),
+         num_inference_steps: int = Input(description="Number of inference steps", default=50, ge=10, le=100),
+         sample_times: int = Input(description="Number of samples to generate", default=1, ge=1, le=10),
+         latent_tiled_size: int = Input(description="Size of latent tiles", default=320, ge=128, le=480),
+         latent_tiled_overlap: int = Input(description="Overlap of latent tiles", default=4, ge=4, le=16),
+         scale_factor: int = Input(description="Scale factor", default=4),
+         seed: int = Input(description="Seed", default=231, ge=0, le=2147483647),
+     ) -> List[Path]:
+         """Run a single prediction on the model"""
+         pil_image = Image.open(image).convert("RGB")
+         imgs = self.process(
+             pil_image, user_prompt, positive_prompt, negative_prompt, num_inference_steps,
+             scale_factor, cfg_scale, seed, latent_tiled_size, latent_tiled_overlap, sample_times)
+
+         # Clear output folder
+         os.system("rm -rf /tmp/output")
+         # Create output folder
+         os.system("mkdir /tmp/output")
+         # Save images to output folder
+         output_paths = []
+         for i, img in enumerate(imgs):
+             img = Image.fromarray(img)
+             output_path = f"/tmp/output/{i}.png"
+             img.save(output_path)
+             output_paths.append(Path(output_path))
+
+         return output_paths
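For a quick local sanity check outside of Cog, the Predictor above can also be exercised directly. This is only a sketch: it assumes a CUDA GPU, the weights laid out under preset/models/ as in setup(), and a hypothetical local file example.png.

    # minimal local smoke test (sketch); example.png is a placeholder input path
    from predict import Predictor

    predictor = Predictor()
    predictor.setup()  # loads SD-2.1-base, the SeeSR UNet/ControlNet and the RAM/DAPE tagger
    paths = predictor.predict(
        image="example.png", user_prompt="", positive_prompt="clean, high-resolution, 8k",
        negative_prompt="dotted, noise, blur, lowres, smooth", cfg_scale=5.5,
        num_inference_steps=50, sample_times=1, latent_tiled_size=320,
        latent_tiled_overlap=4, scale_factor=4, seed=231,
    )
    print(paths)  # PNGs written under /tmp/output/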
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ diffusers==0.21.0
+ torch==2.0.1
+ pytorch_lightning
+ accelerate
+ transformers==4.25.0
+ xformers
+ loralib
+ fairscale
+ pydantic==1.10.11
+ gradio==3.24.0
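When running the test and training scripts outside of Cog, these pinned packages would normally be installed into the project environment first:

    pip install -r requirements.txt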
test_seesr.py ADDED
@@ -0,0 +1,271 @@
+ '''
+ * SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution
+ * Modified from diffusers by Rongyuan Wu
+ * 24/12/2023
+ '''
+ import os
+ import sys
+ sys.path.append(os.getcwd())
+ import cv2
+ import glob
+ import argparse
+ import numpy as np
+ from PIL import Image
+
+ import torch
+ import torch.utils.checkpoint
+
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import set_seed
+ from diffusers import AutoencoderKL, DDPMScheduler
+ from diffusers.utils import check_min_version
+ from diffusers.utils.import_utils import is_xformers_available
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
+
+ from pipelines.pipeline_seesr import StableDiffusionControlNetPipeline
+ from utils.misc import load_dreambooth_lora
+ from utils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
+
+ from ram.models.ram_lora import ram
+ from ram import inference_ram as inference
+ from ram import get_transform
+
+ from typing import Mapping, Any
+ from torchvision import transforms
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+
+ logger = get_logger(__name__, log_level="INFO")
+
+
+ tensor_transforms = transforms.Compose([
+     transforms.ToTensor(),
+ ])
+
+ ram_transforms = transforms.Compose([
+     transforms.Resize((384, 384)),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ def load_state_dict_diffbirSwinIR(model: nn.Module, state_dict: Mapping[str, Any], strict: bool=False) -> None:
+     state_dict = state_dict.get("state_dict", state_dict)
+
+     is_model_key_starts_with_module = list(model.state_dict().keys())[0].startswith("module.")
+     is_state_dict_key_starts_with_module = list(state_dict.keys())[0].startswith("module.")
+
+     if (
+         is_model_key_starts_with_module and
+         (not is_state_dict_key_starts_with_module)
+     ):
+         state_dict = {f"module.{key}": value for key, value in state_dict.items()}
+     if (
+         (not is_model_key_starts_with_module) and
+         is_state_dict_key_starts_with_module
+     ):
+         state_dict = {key[len("module."):]: value for key, value in state_dict.items()}
+
+     model.load_state_dict(state_dict, strict=strict)
+
+
+ def load_seesr_pipeline(args, accelerator, enable_xformers_memory_efficient_attention):
+
+     from models.controlnet import ControlNetModel
+     from models.unet_2d_condition import UNet2DConditionModel
+
+     # Load scheduler, tokenizer and models.
+
+     scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path, subfolder="scheduler")
+     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
+     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+     vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
+     feature_extractor = CLIPImageProcessor.from_pretrained(f"{args.pretrained_model_path}/feature_extractor")
+     unet = UNet2DConditionModel.from_pretrained(args.seesr_model_path, subfolder="unet")
+     controlnet = ControlNetModel.from_pretrained(args.seesr_model_path, subfolder="controlnet")
+
+     # Freeze vae and text_encoder
+     vae.requires_grad_(False)
+     text_encoder.requires_grad_(False)
+     unet.requires_grad_(False)
+     controlnet.requires_grad_(False)
+
+     if enable_xformers_memory_efficient_attention:
+         if is_xformers_available():
+             unet.enable_xformers_memory_efficient_attention()
+             controlnet.enable_xformers_memory_efficient_attention()
+         else:
+             raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+     # Get the validation pipeline
+     validation_pipeline = StableDiffusionControlNetPipeline(
+         vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
+         unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
+     )
+
+     validation_pipeline._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size)
+
+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference, keeping weights in full precision is not required.
+     weight_dtype = torch.float32
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif accelerator.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # Move text_encoder and vae to gpu and cast to weight_dtype
+     text_encoder.to(accelerator.device, dtype=weight_dtype)
+     vae.to(accelerator.device, dtype=weight_dtype)
+     unet.to(accelerator.device, dtype=weight_dtype)
+     controlnet.to(accelerator.device, dtype=weight_dtype)
+
+     return validation_pipeline
+
+ def load_tag_model(args, device='cuda'):
+
+     model = ram(pretrained='preset/models/ram_swin_large_14m.pth',
+                 pretrained_condition=args.ram_ft_path,
+                 image_size=384,
+                 vit='swin_l')
+     model.eval()
+     model.to(device)
+
+     return model
+
+ def get_validation_prompt(args, image, model, device='cuda'):
+     validation_prompt = ""
+
+     lq = tensor_transforms(image).unsqueeze(0).to(device)
+     lq = ram_transforms(lq)
+     res = inference(lq, model)
+     ram_encoder_hidden_states = model.generate_image_embeds(lq)
+
+     validation_prompt = f"{res[0]}, {args.prompt},"
+
+     return validation_prompt, ram_encoder_hidden_states
+
+ def main(args, enable_xformers_memory_efficient_attention=True,):
+     txt_path = os.path.join(args.output_dir, 'txt')
+     os.makedirs(txt_path, exist_ok=True)
+
+     accelerator = Accelerator(
+         mixed_precision=args.mixed_precision,
+     )
+
+     # If passed along, set the training seed now.
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     # Handle the output folder creation
+     if accelerator.is_main_process:
+         os.makedirs(args.output_dir, exist_ok=True)
+
+     # We need to initialize the trackers we use, and also store our configuration.
+     # The trackers initializes automatically on the main process.
+     if accelerator.is_main_process:
+         accelerator.init_trackers("SeeSR")
+
+     pipeline = load_seesr_pipeline(args, accelerator, enable_xformers_memory_efficient_attention)
+     model = load_tag_model(args, accelerator.device)
+
+     if accelerator.is_main_process:
+         generator = torch.Generator(device=accelerator.device)
+         if args.seed is not None:
+             generator.manual_seed(args.seed)
+
+         if os.path.isdir(args.image_path):
+             image_names = sorted(glob.glob(f'{args.image_path}/*.*'))
+         else:
+             image_names = [args.image_path]
+
+         for image_idx, image_name in enumerate(image_names[:]):
+             print(f'================== process {image_idx} imgs... ===================')
+             validation_image = Image.open(image_name).convert("RGB")
+
+             validation_prompt, ram_encoder_hidden_states = get_validation_prompt(args, validation_image, model)
+             validation_prompt += args.added_prompt # clean, extremely detailed, best quality, sharp, clean
+             negative_prompt = args.negative_prompt #dirty, messy, low quality, frames, deformed,
+
+             if args.save_prompts:
+                 txt_save_path = f"{txt_path}/{os.path.basename(image_name).split('.')[0]}.txt"
+                 file = open(txt_save_path, "w")
+                 file.write(validation_prompt)
+                 file.close()
+             print(f'{validation_prompt}')
+
+             ori_width, ori_height = validation_image.size
+             resize_flag = False
+             rscale = args.upscale
+             if ori_width < args.process_size//rscale or ori_height < args.process_size//rscale:
+                 scale = (args.process_size//rscale)/min(ori_width, ori_height)
+                 tmp_image = validation_image.resize((int(scale*ori_width), int(scale*ori_height)))
+
+                 validation_image = tmp_image
+                 resize_flag = True
+
+             validation_image = validation_image.resize((validation_image.size[0]*rscale, validation_image.size[1]*rscale))
+             validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
+             width, height = validation_image.size
+             resize_flag = True #
+
+             print(f'input size: {height}x{width}')
+
+             for sample_idx in range(args.sample_times):
+                 os.makedirs(f'{args.output_dir}/sample{str(sample_idx).zfill(2)}/', exist_ok=True)
+
+             for sample_idx in range(args.sample_times):
+                 with torch.autocast("cuda"):
+                     image = pipeline(
+                         validation_prompt, validation_image, num_inference_steps=args.num_inference_steps, generator=generator, height=height, width=width,
+                         guidance_scale=args.guidance_scale, negative_prompt=negative_prompt, conditioning_scale=args.conditioning_scale,
+                         start_point=args.start_point, ram_encoder_hidden_states=ram_encoder_hidden_states,
+                         latent_tiled_size=args.latent_tiled_size, latent_tiled_overlap=args.latent_tiled_overlap,
+                         args=args,
+                     ).images[0]
+
+                 if args.align_method == 'nofix':
+                     image = image
+                 else:
+                     if args.align_method == 'wavelet':
+                         image = wavelet_color_fix(image, validation_image)
+                     elif args.align_method == 'adain':
+                         image = adain_color_fix(image, validation_image)
+
+                 if resize_flag:
+                     image = image.resize((ori_width*rscale, ori_height*rscale))
+
+                 name, ext = os.path.splitext(os.path.basename(image_name))
+
+                 image.save(f'{args.output_dir}/sample{str(sample_idx).zfill(2)}/{name}.png')
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--seesr_model_path", type=str, default=None)
+     parser.add_argument("--ram_ft_path", type=str, default=None)
+     parser.add_argument("--pretrained_model_path", type=str, default=None)
+     parser.add_argument("--prompt", type=str, default="") # user can add self-prompt to improve the results
+     parser.add_argument("--added_prompt", type=str, default="clean, high-resolution, 8k")
+     parser.add_argument("--negative_prompt", type=str, default="dotted, noise, blur, lowres, smooth")
+     parser.add_argument("--image_path", type=str, default=None)
+     parser.add_argument("--output_dir", type=str, default=None)
+     parser.add_argument("--mixed_precision", type=str, default="fp16") # no/fp16/bf16
+     parser.add_argument("--guidance_scale", type=float, default=5.5)
+     parser.add_argument("--conditioning_scale", type=float, default=1.0)
+     parser.add_argument("--blending_alpha", type=float, default=1.0)
+     parser.add_argument("--num_inference_steps", type=int, default=50)
+     parser.add_argument("--process_size", type=int, default=512)
+     parser.add_argument("--vae_decoder_tiled_size", type=int, default=224) # latent size, for 24G
+     parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024) # image size, for 13G
+     parser.add_argument("--latent_tiled_size", type=int, default=96)
+     parser.add_argument("--latent_tiled_overlap", type=int, default=32)
+     parser.add_argument("--upscale", type=int, default=4)
+     parser.add_argument("--seed", type=int, default=None)
+     parser.add_argument("--sample_times", type=int, default=1)
+     parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain')
+     parser.add_argument("--start_steps", type=int, default=999) # defaults set to 999.
+     parser.add_argument("--start_point", type=str, choices=['lr', 'noise'], default='lr') # LR Embedding Strategy, choose 'lr latent + 999 steps noise' as diffusion start point.
+     parser.add_argument("--save_prompts", action='store_true')
+     args = parser.parse_args()
+     main(args)
+
+
+
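A typical invocation of this script, using only flags defined in its argument parser (the concrete weight and image paths below are placeholders, following the preset/models/ layout used by predict.py):

    python test_seesr.py \
        --pretrained_model_path preset/models/stable-diffusion-2-1-base \
        --seesr_model_path preset/models/seesr \
        --ram_ft_path preset/models/DAPE.pth \
        --image_path ./input_images \
        --output_dir ./results \
        --upscale 4 --num_inference_steps 50 --guidance_scale 5.5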
test_seesr_turbo.py ADDED
@@ -0,0 +1,271 @@
+ '''
+ * SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution
+ * Modified from diffusers by Rongyuan Wu
+ * 24/12/2023
+ '''
+ import os
+ import sys
+ sys.path.append(os.getcwd())
+ import cv2
+ import glob
+ import argparse
+ import numpy as np
+ from PIL import Image
+
+ import torch
+ import torch.utils.checkpoint
+
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import set_seed
+ from diffusers import AutoencoderKL, DDPMScheduler
+ from diffusers.utils import check_min_version
+ from diffusers.utils.import_utils import is_xformers_available
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
+
+ from pipelines.pipeline_seesr import StableDiffusionControlNetPipeline
+ from utils.misc import load_dreambooth_lora
+ from utils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
+
+ from ram.models.ram_lora import ram
+ from ram import inference_ram as inference
+ from ram import get_transform
+
+ from typing import Mapping, Any
+ from torchvision import transforms
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+
+ logger = get_logger(__name__, log_level="INFO")
+
+
+ tensor_transforms = transforms.Compose([
+     transforms.ToTensor(),
+ ])
+
+ ram_transforms = transforms.Compose([
+     transforms.Resize((384, 384)),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ def load_state_dict_diffbirSwinIR(model: nn.Module, state_dict: Mapping[str, Any], strict: bool=False) -> None:
+     state_dict = state_dict.get("state_dict", state_dict)
+
+     is_model_key_starts_with_module = list(model.state_dict().keys())[0].startswith("module.")
+     is_state_dict_key_starts_with_module = list(state_dict.keys())[0].startswith("module.")
+
+     if (
+         is_model_key_starts_with_module and
+         (not is_state_dict_key_starts_with_module)
+     ):
+         state_dict = {f"module.{key}": value for key, value in state_dict.items()}
+     if (
+         (not is_model_key_starts_with_module) and
+         is_state_dict_key_starts_with_module
+     ):
+         state_dict = {key[len("module."):]: value for key, value in state_dict.items()}
+
+     model.load_state_dict(state_dict, strict=strict)
+
+
+ def load_seesr_pipeline(args, accelerator, enable_xformers_memory_efficient_attention):
+
+     from models.controlnet import ControlNetModel
+     from models.unet_2d_condition import UNet2DConditionModel
+
+     # Load scheduler, tokenizer and models.
+
+     scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path, subfolder="scheduler")
+     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
+     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+     vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
+     feature_extractor = CLIPImageProcessor.from_pretrained(f"{args.pretrained_model_path}/feature_extractor")
+     unet = UNet2DConditionModel.from_pretrained_orig(args.pretrained_model_path, args.seesr_model_path, subfolder="unet", use_image_cross_attention=True)
+     controlnet = ControlNetModel.from_pretrained(args.seesr_model_path, subfolder="controlnet")
+
+     # Freeze vae and text_encoder
+     vae.requires_grad_(False)
+     text_encoder.requires_grad_(False)
+     unet.requires_grad_(False)
+     controlnet.requires_grad_(False)
+
+     if enable_xformers_memory_efficient_attention:
+         if is_xformers_available():
+             unet.enable_xformers_memory_efficient_attention()
+             controlnet.enable_xformers_memory_efficient_attention()
+         else:
+             raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+     # Get the validation pipeline
+     validation_pipeline = StableDiffusionControlNetPipeline(
+         vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
+         unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
+     )
+
+     validation_pipeline._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size)
+
+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference, keeping weights in full precision is not required.
+     weight_dtype = torch.float32
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif accelerator.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # Move text_encoder and vae to gpu and cast to weight_dtype
+     text_encoder.to(accelerator.device, dtype=weight_dtype)
+     vae.to(accelerator.device, dtype=weight_dtype)
+     unet.to(accelerator.device, dtype=weight_dtype)
+     controlnet.to(accelerator.device, dtype=weight_dtype)
+
+     return validation_pipeline
+
+ def load_tag_model(args, device='cuda'):
+
+     model = ram(pretrained='preset/models/ram_swin_large_14m.pth',
+                 pretrained_condition=args.ram_ft_path,
+                 image_size=384,
+                 vit='swin_l')
+     model.eval()
+     model.to(device)
+
+     return model
+
+ def get_validation_prompt(args, image, model, device='cuda'):
+     validation_prompt = ""
+
+     lq = tensor_transforms(image).unsqueeze(0).to(device)
+     lq = ram_transforms(lq)
+     res = inference(lq, model)
+     ram_encoder_hidden_states = model.generate_image_embeds(lq)
+
+     validation_prompt = f"{res[0]}, {args.prompt},"
+
+     return validation_prompt, ram_encoder_hidden_states
+
+ def main(args, enable_xformers_memory_efficient_attention=True,):
+     txt_path = os.path.join(args.output_dir, 'txt')
+     os.makedirs(txt_path, exist_ok=True)
+
+     accelerator = Accelerator(
+         mixed_precision=args.mixed_precision,
+     )
+
+     # If passed along, set the training seed now.
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     # Handle the output folder creation
+     if accelerator.is_main_process:
+         os.makedirs(args.output_dir, exist_ok=True)
+
+     # We need to initialize the trackers we use, and also store our configuration.
+     # The trackers initializes automatically on the main process.
+     if accelerator.is_main_process:
+         accelerator.init_trackers("SeeSR")
+
+     pipeline = load_seesr_pipeline(args, accelerator, enable_xformers_memory_efficient_attention)
+     model = load_tag_model(args, accelerator.device)
+
+     if accelerator.is_main_process:
+         generator = torch.Generator(device=accelerator.device)
+         if args.seed is not None:
+             generator.manual_seed(args.seed)
+
+         if os.path.isdir(args.image_path):
+             image_names = sorted(glob.glob(f'{args.image_path}/*.*'))
+         else:
+             image_names = [args.image_path]
+
+         for image_idx, image_name in enumerate(image_names[:]):
+             print(f'================== process {image_idx} imgs... ===================')
+             validation_image = Image.open(image_name).convert("RGB")
+
+             validation_prompt, ram_encoder_hidden_states = get_validation_prompt(args, validation_image, model)
+             validation_prompt += args.added_prompt # clean, extremely detailed, best quality, sharp, clean
+             negative_prompt = args.negative_prompt #dirty, messy, low quality, frames, deformed,
+
+             if args.save_prompts:
+                 txt_save_path = f"{txt_path}/{os.path.basename(image_name).split('.')[0]}.txt"
+                 file = open(txt_save_path, "w")
+                 file.write(validation_prompt)
+                 file.close()
+             print(f'{validation_prompt}')
+
+             ori_width, ori_height = validation_image.size
+             resize_flag = False
+             rscale = args.upscale
+             if ori_width < args.process_size//rscale or ori_height < args.process_size//rscale:
+                 scale = (args.process_size//rscale)/min(ori_width, ori_height)
+                 tmp_image = validation_image.resize((int(scale*ori_width), int(scale*ori_height)))
+
+                 validation_image = tmp_image
+                 resize_flag = True
+
+             validation_image = validation_image.resize((validation_image.size[0]*rscale, validation_image.size[1]*rscale))
+             validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
+             width, height = validation_image.size
+             resize_flag = True #
+
+             print(f'input size: {height}x{width}')
+
+             for sample_idx in range(args.sample_times):
+                 os.makedirs(f'{args.output_dir}/sample{str(sample_idx).zfill(2)}/', exist_ok=True)
+
+             for sample_idx in range(args.sample_times):
+                 with torch.autocast("cuda"):
+                     image = pipeline(
+                         validation_prompt, validation_image, num_inference_steps=args.num_inference_steps, generator=generator, height=height, width=width,
+                         guidance_scale=args.guidance_scale, negative_prompt=negative_prompt, conditioning_scale=args.conditioning_scale,
+                         start_point=args.start_point, ram_encoder_hidden_states=ram_encoder_hidden_states,
+                         latent_tiled_size=args.latent_tiled_size, latent_tiled_overlap=args.latent_tiled_overlap,
+                         args=args,
+                     ).images[0]
+
+                 if args.align_method == 'nofix':
+                     image = image
+                 else:
+                     if args.align_method == 'wavelet':
+                         image = wavelet_color_fix(image, validation_image)
+                     elif args.align_method == 'adain':
+                         image = adain_color_fix(image, validation_image)
+
+                 if resize_flag:
+                     image = image.resize((ori_width*rscale, ori_height*rscale))
+
+                 name, ext = os.path.splitext(os.path.basename(image_name))
+
+                 image.save(f'{args.output_dir}/sample{str(sample_idx).zfill(2)}/{name}.png')
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--seesr_model_path", type=str, default=None)
+     parser.add_argument("--ram_ft_path", type=str, default=None)
+     parser.add_argument("--pretrained_model_path", type=str, default=None)
+     parser.add_argument("--prompt", type=str, default="") # user can add self-prompt to improve the results
+     parser.add_argument("--added_prompt", type=str, default="clean, high-resolution, 8k")
+     parser.add_argument("--negative_prompt", type=str, default="dotted, noise, blur, lowres, smooth")
+     parser.add_argument("--image_path", type=str, default=None)
+     parser.add_argument("--output_dir", type=str, default=None)
+     parser.add_argument("--mixed_precision", type=str, default="fp16") # no/fp16/bf16
+     parser.add_argument("--guidance_scale", type=float, default=1.0)
+     parser.add_argument("--conditioning_scale", type=float, default=1.0)
+     parser.add_argument("--blending_alpha", type=float, default=1.0)
+     parser.add_argument("--num_inference_steps", type=int, default=2)
+     parser.add_argument("--process_size", type=int, default=512)
+     parser.add_argument("--vae_decoder_tiled_size", type=int, default=224) # latent size, for 24G
+     parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024) # image size, for 13G
+     parser.add_argument("--latent_tiled_size", type=int, default=96)
+     parser.add_argument("--latent_tiled_overlap", type=int, default=32)
+     parser.add_argument("--upscale", type=int, default=4)
+     parser.add_argument("--seed", type=int, default=None)
+     parser.add_argument("--sample_times", type=int, default=1)
+     parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain')
+     parser.add_argument("--start_steps", type=int, default=999) # defaults set to 999.
+     parser.add_argument("--start_point", type=str, choices=['lr', 'noise'], default='lr') # LR Embedding Strategy, choose 'lr latent + 999 steps noise' as diffusion start point.
+     parser.add_argument("--save_prompts", action='store_true')
+     args = parser.parse_args()
+     main(args)
+
+
+
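This turbo variant differs from test_seesr.py mainly in its defaults (2 inference steps, guidance_scale 1.0) and in loading the UNet via from_pretrained_orig with image cross-attention enabled. A comparable invocation, again with placeholder paths, would be:

    python test_seesr_turbo.py \
        --pretrained_model_path preset/models/stable-diffusion-2-1-base \
        --seesr_model_path preset/models/seesr \
        --ram_ft_path preset/models/DAPE.pth \
        --image_path ./input_images --output_dir ./results_turbo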
train_seesr.py ADDED
@@ -0,0 +1,1093 @@
1
+ '''
2
+ * SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution
3
+ * Modified from diffusers by Rongyuan Wu
4
+ * 24/12/2023
5
+ '''
6
+
7
+ import argparse
8
+ import logging
9
+ import math
10
+ import os
11
+ import random
12
+ import shutil
13
+ from pathlib import Path
14
+
15
+ import accelerate
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torch.utils.checkpoint
20
+ import transformers
21
+ from accelerate import Accelerator
22
+ from accelerate.logging import get_logger
23
+ from accelerate.utils import ProjectConfiguration, set_seed
24
+ from datasets import load_dataset # ''datasets'' is a library
25
+ from huggingface_hub import create_repo, upload_folder
26
+ from packaging import version
27
+ from PIL import Image
28
+ from torchvision import transforms
29
+ from tqdm.auto import tqdm
30
+ from transformers import AutoTokenizer, PretrainedConfig
31
+
32
+ import diffusers
33
+ from diffusers import (
34
+ AutoencoderKL,
35
+ DDPMScheduler,
36
+ StableDiffusionControlNetPipeline,
37
+ UniPCMultistepScheduler,
38
+ )
39
+ from models.controlnet import ControlNetModel
40
+ from models.unet_2d_condition import UNet2DConditionModel
41
+ from diffusers.optimization import get_scheduler
42
+ from diffusers.utils import check_min_version, is_wandb_available
43
+ from diffusers.utils.import_utils import is_xformers_available
44
+
45
+ from dataloaders.paired_dataset import PairedCaptionDataset
46
+
47
+ from typing import Mapping, Any
48
+ from torchvision import transforms
49
+ import torch.nn as nn
50
+ import torch.nn.functional as F
51
+
52
+ if is_wandb_available():
53
+ import wandb
54
+
55
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
56
+ check_min_version("0.21.0.dev0")
57
+
58
+ logger = get_logger(__name__)
59
+
60
+ from torchvision import transforms
61
+ tensor_transforms = transforms.Compose([
62
+ transforms.ToTensor(),
63
+ ])
64
+ ram_transforms = transforms.Compose([
65
+ transforms.Resize((384, 384)),
66
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
67
+ ])
68
+
69
+ def image_grid(imgs, rows, cols):
70
+ assert len(imgs) == rows * cols
71
+
72
+ w, h = imgs[0].size
73
+ grid = Image.new("RGB", size=(cols * w, rows * h))
74
+
75
+ for i, img in enumerate(imgs):
76
+ grid.paste(img, box=(i % cols * w, i // cols * h))
77
+ return grid
78
+
79
+
80
+ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
81
+ logger.info("Running validation... ")
82
+
83
+ controlnet = accelerator.unwrap_model(controlnet)
84
+
85
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
86
+ args.pretrained_model_name_or_path,
87
+ vae=vae,
88
+ text_encoder=text_encoder,
89
+ tokenizer=tokenizer,
90
+ unet=unet,
91
+ controlnet=controlnet,
92
+ safety_checker=None,
93
+ revision=args.revision,
94
+ torch_dtype=weight_dtype,
95
+ )
96
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
97
+ pipeline = pipeline.to(accelerator.device)
98
+ pipeline.set_progress_bar_config(disable=True)
99
+
100
+ if args.enable_xformers_memory_efficient_attention:
101
+ pipeline.enable_xformers_memory_efficient_attention()
102
+
103
+ if args.seed is None:
104
+ generator = None
105
+ else:
106
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
107
+
108
+ if len(args.validation_image) == len(args.validation_prompt):
109
+ validation_images = args.validation_image
110
+ validation_prompts = args.validation_prompt
111
+ elif len(args.validation_image) == 1:
112
+ validation_images = args.validation_image * len(args.validation_prompt)
113
+ validation_prompts = args.validation_prompt
114
+ elif len(args.validation_prompt) == 1:
115
+ validation_images = args.validation_image
116
+ validation_prompts = args.validation_prompt * len(args.validation_image)
117
+ else:
118
+ raise ValueError(
119
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
120
+ )
121
+
122
+ image_logs = []
123
+
124
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
125
+ validation_image = Image.open(validation_image).convert("RGB")
126
+
127
+ images = []
128
+
129
+ for _ in range(args.num_validation_images):
130
+ with torch.autocast("cuda"):
131
+ image = pipeline(
132
+ validation_prompt, validation_image, num_inference_steps=20, generator=generator
133
+ ).images[0]
134
+
135
+ images.append(image)
136
+
137
+ image_logs.append(
138
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
139
+ )
140
+
141
+ for tracker in accelerator.trackers:
142
+ if tracker.name == "tensorboard":
143
+ for log in image_logs:
144
+ images = log["images"]
145
+ validation_prompt = log["validation_prompt"]
146
+ validation_image = log["validation_image"]
147
+
148
+ formatted_images = []
149
+
150
+ formatted_images.append(np.asarray(validation_image))
151
+
152
+ for image in images:
153
+ formatted_images.append(np.asarray(image))
154
+
155
+ formatted_images = np.stack(formatted_images)
156
+
157
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
158
+ elif tracker.name == "wandb":
159
+ formatted_images = []
160
+
161
+ for log in image_logs:
162
+ images = log["images"]
163
+ validation_prompt = log["validation_prompt"]
164
+ validation_image = log["validation_image"]
165
+
166
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
167
+
168
+ for image in images:
169
+ image = wandb.Image(image, caption=validation_prompt)
170
+ formatted_images.append(image)
171
+
172
+ tracker.log({"validation": formatted_images})
173
+ else:
174
+ logger.warn(f"image logging not implemented for {tracker.name}")
175
+
176
+ return image_logs
177
+
178
+
179
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
180
+ text_encoder_config = PretrainedConfig.from_pretrained(
181
+ pretrained_model_name_or_path,
182
+ subfolder="text_encoder",
183
+ revision=revision,
184
+ )
185
+ model_class = text_encoder_config.architectures[0]
186
+
187
+ if model_class == "CLIPTextModel":
188
+ from transformers import CLIPTextModel
189
+
190
+ return CLIPTextModel
191
+ elif model_class == "RobertaSeriesModelWithTransformation":
192
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
193
+
194
+ return RobertaSeriesModelWithTransformation
195
+ else:
196
+ raise ValueError(f"{model_class} is not supported.")
197
+
198
+
199
+ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
200
+ img_str = ""
201
+ if image_logs is not None:
202
+ img_str = "You can find some example images below.\n"
203
+ for i, log in enumerate(image_logs):
204
+ images = log["images"]
205
+ validation_prompt = log["validation_prompt"]
206
+ validation_image = log["validation_image"]
207
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
208
+ img_str += f"prompt: {validation_prompt}\n"
209
+ images = [validation_image] + images
210
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
211
+ img_str += f"![images_{i})](./images_{i}.png)\n"
212
+
213
+ yaml = f"""
214
+ ---
215
+ license: creativeml-openrail-m
216
+ base_model: {base_model}
217
+ tags:
218
+ - stable-diffusion
219
+ - stable-diffusion-diffusers
220
+ - text-to-image
221
+ - diffusers
222
+ - controlnet
223
+ inference: true
224
+ ---
225
+ """
226
+ model_card = f"""
227
+ # controlnet-{repo_id}
228
+
229
+ These are controlnet weights trained on {base_model} with new type of conditioning.
230
+ {img_str}
231
+ """
232
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
233
+ f.write(yaml + model_card)
234
+
235
+
236
+ def parse_args(input_args=None):
237
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
238
+ parser.add_argument(
239
+ "--pretrained_model_name_or_path",
240
+ type=str,
241
+ default="/home/notebook/data/group/LowLevelLLM/models/diffusion_models/stable-diffusion-2-base",
242
+ # required=True,
243
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
244
+ )
245
+ parser.add_argument(
246
+ "--controlnet_model_name_or_path",
247
+ type=str,
248
+ default=None,
249
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
250
+ " If not specified controlnet weights are initialized from unet.",
251
+ )
252
+ parser.add_argument(
253
+ "--unet_model_name_or_path",
254
+ type=str,
255
+ default=None,
256
+ help="Path to pretrained unet model or model identifier from huggingface.co/models."
257
+ " If not specified controlnet weights are initialized from unet.",
258
+ )
259
+ parser.add_argument(
260
+ "--revision",
261
+ type=str,
262
+ default=None,
263
+ required=False,
264
+ help=(
265
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
266
+ " float32 precision."
267
+ ),
268
+ )
269
+ parser.add_argument(
270
+ "--tokenizer_name",
271
+ type=str,
272
+ default=None,
273
+ help="Pretrained tokenizer name or path if not the same as model_name",
274
+ )
275
+ parser.add_argument(
276
+ "--output_dir",
277
+ type=str,
278
+ default="./experience/test",
279
+ help="The output directory where the model predictions and checkpoints will be written.",
280
+ )
281
+ parser.add_argument(
282
+ "--cache_dir",
283
+ type=str,
284
+ default=None,
285
+ help="The directory where the downloaded models and datasets will be stored.",
286
+ )
287
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
288
+ parser.add_argument(
289
+ "--resolution",
290
+ type=int,
291
+ default=512,
292
+ help=(
293
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
294
+ " resolution"
295
+ ),
296
+ )
297
+ parser.add_argument(
298
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
299
+ )
300
+ parser.add_argument("--num_train_epochs", type=int, default=1000)
301
+ parser.add_argument(
302
+ "--max_train_steps",
303
+ type=int,
304
+ default=None,
305
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
306
+ )
307
+ parser.add_argument(
308
+ "--checkpointing_steps",
309
+ type=int,
310
+ default=500,
311
+ help=(
312
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
313
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
314
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
315
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
316
+ "instructions."
317
+ ),
318
+ )
319
+ parser.add_argument(
320
+ "--checkpoints_total_limit",
321
+ type=int,
322
+ default=None,
323
+ help=("Max number of checkpoints to store."),
324
+ )
325
+ parser.add_argument(
326
+ "--resume_from_checkpoint",
327
+ type=str,
328
+ default=None,
329
+ help=(
330
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
331
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
332
+ ),
333
+ )
334
+ parser.add_argument(
335
+ "--gradient_accumulation_steps",
336
+ type=int,
337
+ default=1,
338
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
339
+ )
340
+ parser.add_argument(
341
+ "--gradient_checkpointing",
342
+ action="store_true",
343
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
344
+ )
345
+ parser.add_argument(
346
+ "--learning_rate",
347
+ type=float,
348
+ default=5e-5,
349
+ help="Initial learning rate (after the potential warmup period) to use.",
350
+ )
351
+ parser.add_argument(
352
+ "--scale_lr",
353
+ action="store_true",
354
+ default=False,
355
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
356
+ )
357
+ parser.add_argument(
358
+ "--lr_scheduler",
359
+ type=str,
360
+ default="constant",
361
+ help=(
362
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
363
+ ' "constant", "constant_with_warmup"]'
364
+ ),
365
+ )
366
+ parser.add_argument(
367
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
368
+ )
369
+ parser.add_argument(
370
+ "--lr_num_cycles",
371
+ type=int,
372
+ default=1,
373
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
374
+ )
375
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
376
+ parser.add_argument(
377
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
378
+ )
379
+ parser.add_argument(
380
+ "--dataloader_num_workers",
381
+ type=int,
382
+ default=0,
383
+ help=(
384
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
385
+ ),
386
+ )
387
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
388
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
389
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
390
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
391
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
392
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
393
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
394
+ parser.add_argument(
395
+ "--hub_model_id",
396
+ type=str,
397
+ default=None,
398
+ help="The name of the repository to keep in sync with the local `output_dir`.",
399
+ )
400
+ parser.add_argument(
401
+ "--logging_dir",
402
+ type=str,
403
+ default="logs",
404
+ help=(
405
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
406
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
407
+ ),
408
+ )
409
+ parser.add_argument(
410
+ "--allow_tf32",
411
+ action="store_true",
412
+ help=(
413
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
414
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
415
+ ),
416
+ )
417
+ parser.add_argument(
418
+ "--report_to",
419
+ type=str,
420
+ default="tensorboard",
421
+ help=(
422
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
423
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
424
+ ),
425
+ )
426
+ parser.add_argument(
427
+ "--mixed_precision",
428
+ type=str,
429
+ default="fp16",
430
+ choices=["no", "fp16", "bf16"],
431
+ help=(
432
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
433
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
434
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
435
+ ),
436
+ )
437
+ parser.add_argument(
438
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
439
+ )
440
+ parser.add_argument(
441
+ "--set_grads_to_none",
442
+ action="store_true",
443
+ help=(
444
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
445
+ " behaviors, so disable this argument if it causes any problems. More info:"
446
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
447
+ ),
448
+ )
449
+ parser.add_argument(
450
+ "--dataset_name",
451
+ type=str,
452
+ default=None,
453
+ help=(
454
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
455
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
456
+ " or to a folder containing files that 🤗 Datasets can understand."
457
+ ),
458
+ )
459
+ parser.add_argument(
460
+ "--dataset_config_name",
461
+ type=str,
462
+ default=None,
463
+ help="The config of the Dataset, leave as None if there's only one config.",
464
+ )
465
+ parser.add_argument(
466
+ "--train_data_dir",
467
+ type=str,
468
+ default='NOTHING',
469
+ help=(
470
+ "A folder containing the training data. Folder contents must follow the structure described in"
471
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
472
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
473
+ ),
474
+ )
475
+ parser.add_argument(
476
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
477
+ )
478
+ parser.add_argument(
479
+ "--conditioning_image_column",
480
+ type=str,
481
+ default="conditioning_image",
482
+ help="The column of the dataset containing the controlnet conditioning image.",
483
+ )
484
+ parser.add_argument(
485
+ "--caption_column",
486
+ type=str,
487
+ default="text",
488
+ help="The column of the dataset containing a caption or a list of captions.",
489
+ )
490
+ parser.add_argument(
491
+ "--max_train_samples",
492
+ type=int,
493
+ default=None,
494
+ help=(
495
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
496
+ "value if set."
497
+ ),
498
+ )
499
+ parser.add_argument(
500
+ "--proportion_empty_prompts",
501
+ type=float,
502
+ default=0,
503
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
504
+ )
505
+ parser.add_argument(
506
+ "--validation_prompt",
507
+ type=str,
508
+ default=[""],
509
+ nargs="+",
510
+ help=(
511
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
512
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
513
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
514
+ ),
515
+ )
516
+ parser.add_argument(
517
+ "--validation_image",
518
+ type=str,
519
+ default=[""],
520
+ nargs="+",
521
+ help=(
522
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
523
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
524
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
525
+ " `--validation_image` that will be used with all `--validation_prompt`s."
526
+ ),
527
+ )
528
+ parser.add_argument(
529
+ "--num_validation_images",
530
+ type=int,
531
+ default=4,
532
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
533
+ )
534
+ parser.add_argument(
535
+ "--validation_steps",
536
+ type=int,
537
+ default=1,
538
+ help=(
539
+ "Run validation every X steps. Validation consists of running the prompt"
540
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
541
+ " and logging the images."
542
+ ),
543
+ )
544
+ parser.add_argument(
545
+ "--tracker_project_name",
546
+ type=str,
547
+ default="SeeSR",
548
+ help=(
549
+ "The `project_name` argument passed to Accelerator.init_trackers for"
550
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
551
+ ),
552
+ )
553
+
554
+ parser.add_argument("--root_folders", type=str , default='' )
555
+ parser.add_argument("--null_text_ratio", type=float, default=0.5)
556
+ parser.add_argument("--ram_ft_path", type=str, default=None)
557
+ parser.add_argument('--trainable_modules', nargs='*', type=str, default=["image_attentions"])
558
+
559
+
560
+
561
+
562
+ if input_args is not None:
563
+ args = parser.parse_args(input_args)
564
+ else:
565
+ args = parser.parse_args()
566
+
567
+ if args.dataset_name is None and args.train_data_dir is None:
568
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
569
+
570
+ if args.dataset_name is not None and args.train_data_dir is not None:
571
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
572
+
573
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
574
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
575
+
576
+ if args.validation_prompt is not None and args.validation_image is None:
577
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
578
+
579
+ if args.validation_prompt is None and args.validation_image is not None:
580
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
581
+
582
+ if (
583
+ args.validation_image is not None
584
+ and args.validation_prompt is not None
585
+ and len(args.validation_image) != 1
586
+ and len(args.validation_prompt) != 1
587
+ and len(args.validation_image) != len(args.validation_prompt)
588
+ ):
589
+ raise ValueError(
590
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
591
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
592
+ )
593
+
594
+ if args.resolution % 8 != 0:
595
+ raise ValueError(
596
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
597
+ )
598
+
599
+ return args
600
+
601
+
602
+ # def main(args):
603
+ args = parse_args()
604
+ logging_dir = Path(args.output_dir, args.logging_dir)
605
+
606
+
607
+ from accelerate import DistributedDataParallelKwargs
608
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
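+ # find_unused_parameters=True: most UNet weights stay frozen (only the modules named in --trainable_modules
+ # are released below), so DDP must tolerate parameters that never receive a gradient.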
609
+
610
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
611
+
612
+ accelerator = Accelerator(
613
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
614
+ mixed_precision=args.mixed_precision,
615
+ log_with=args.report_to,
616
+ project_config=accelerator_project_config,
617
+ kwargs_handlers=[ddp_kwargs]
618
+ )
619
+
620
+ # Make one log on every process with the configuration for debugging.
621
+ logging.basicConfig(
622
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
623
+ datefmt="%m/%d/%Y %H:%M:%S",
624
+ level=logging.INFO,
625
+ )
626
+ logger.info(accelerator.state, main_process_only=False)
627
+ if accelerator.is_local_main_process:
628
+ transformers.utils.logging.set_verbosity_warning()
629
+ diffusers.utils.logging.set_verbosity_info()
630
+ else:
631
+ transformers.utils.logging.set_verbosity_error()
632
+ diffusers.utils.logging.set_verbosity_error()
633
+
634
+ # If passed along, set the training seed now.
635
+ if args.seed is not None:
636
+ set_seed(args.seed)
637
+
638
+ # Handle the repository creation
639
+ if accelerator.is_main_process:
640
+ if args.output_dir is not None:
641
+ os.makedirs(args.output_dir, exist_ok=True)
642
+
643
+ if args.push_to_hub:
644
+ repo_id = create_repo(
645
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
646
+ ).repo_id
647
+
648
+ # Load the tokenizer
649
+ if args.tokenizer_name:
650
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
651
+ elif args.pretrained_model_name_or_path:
652
+ tokenizer = AutoTokenizer.from_pretrained(
653
+ args.pretrained_model_name_or_path,
654
+ subfolder="tokenizer",
655
+ revision=args.revision,
656
+ use_fast=False,
657
+ )
658
+
659
+ # import correct text encoder class
660
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
661
+
662
+ # Load scheduler and models
663
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
664
+ text_encoder = text_encoder_cls.from_pretrained(
665
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
666
+ )
667
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
668
+ # unet = UNet2DConditionModel.from_pretrained(
669
+ # args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
670
+ # )
671
+ if args.unet_model_name_or_path:
672
+ # resume from self-train
673
+ logger.info("Loading unet weights from self-train")
674
+ unet = UNet2DConditionModel.from_pretrained_orig(
675
+ args.pretrained_model_name_or_path, args.unet_model_name_or_path, subfolder="unet", revision=args.revision, use_image_cross_attention=True
676
+ )
677
+ else:
678
+ # resume from pretrained SD
679
+ logger.info("Loading unet weights from SD")
680
+ unet = UNet2DConditionModel.from_pretrained(
681
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_image_cross_attention=True
682
+ )
683
+ print(f'===== use ram encoder? {unet.config.use_image_cross_attention} =====')
684
+
685
+ if args.controlnet_model_name_or_path:
686
+ # resume from self-train
687
+ logger.info("Loading existing controlnet weights")
688
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, subfolder="controlnet")
689
+
690
+ else:
691
+ logger.info("Initializing controlnet weights from unet")
692
+ controlnet = ControlNetModel.from_unet(unet, use_image_cross_attention=True)
693
+
694
+
695
+ # `accelerate` 0.16.0 will have better support for customized saving
696
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
697
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
698
+ def save_model_hook(models, weights, output_dir):
699
+ i = len(weights) - 1
700
+
701
+ # while len(weights) > 0:
702
+ # weights.pop()
703
+ # model = models[i]
704
+
705
+ # sub_dir = "controlnet"
706
+ # model.save_pretrained(os.path.join(output_dir, sub_dir))
707
+
708
+ # i -= 1
709
+ assert len(models) == 2 and len(weights) == 2
710
+ for i, model in enumerate(models):
711
+ sub_dir = "unet" if isinstance(model, UNet2DConditionModel) else "controlnet"
712
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
713
+ # make sure to pop weight so that corresponding model is not saved again
714
+ weights.pop()
715
+
716
+ def load_model_hook(models, input_dir):
717
+ # while len(models) > 0:
718
+ # # pop models so that they are not loaded again
719
+ # model = models.pop()
720
+
721
+ # # load diffusers style into model
722
+ # load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
723
+ # model.register_to_config(**load_model.config)
724
+
725
+ # model.load_state_dict(load_model.state_dict())
726
+ # del load_model
727
+ assert len(models) == 2
728
+ for i in range(len(models)):
729
+ # pop models so that they are not loaded again
730
+ model = models.pop()
731
+
732
+ # load diffusers style into model
733
+ if not isinstance(model, UNet2DConditionModel):
734
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True
735
+ else:
736
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True
737
+
738
+ model.register_to_config(**load_model.config)
739
+
740
+ model.load_state_dict(load_model.state_dict())
741
+ del load_model
742
+
743
+ accelerator.register_save_state_pre_hook(save_model_hook)
744
+ accelerator.register_load_state_pre_hook(load_model_hook)
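+ # Both the ControlNet and the partially trainable UNet are checkpointed, which is why the hooks above
+ # expect exactly two models per save/load.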
745
+
746
+ vae.requires_grad_(False)
747
+ unet.requires_grad_(False)
748
+ text_encoder.requires_grad_(False)
749
+ controlnet.train()
750
+
751
+ ## release the cross-attention part in the unet.
752
+ for name, module in unet.named_modules():
753
+ if name.endswith(tuple(args.trainable_modules)):
754
+ print(f'{name} in <unet> will be optimized.')
755
+ for params in module.parameters():
756
+ params.requires_grad = True
757
+
758
+ ## init the RAM or DAPE model
759
+ from ram.models.ram_lora import ram
760
+ from ram import get_transform
761
+ if args.ram_ft_path is None:
762
+ print("======== USE Original RAM ========")
763
+ else:
764
+ print("==============")
765
+ print(f"USE FT RAM FROM: {args.ram_ft_path}")
766
+ print("==============")
767
+
768
+ RAM = ram(pretrained='preset/models/ram_swin_large_14m.pth',
769
+ pretrained_condition=args.ram_ft_path,
770
+ image_size=384,
771
+ vit='swin_l')
772
+ RAM.eval()
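+ # RAM (or its DAPE fine-tune when --ram_ft_path is set) stays frozen and is only used to produce
+ # image embeddings that condition the extra image cross-attention layers.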
773
+
774
+ if args.enable_xformers_memory_efficient_attention:
775
+ if is_xformers_available():
776
+ import xformers
777
+
778
+ xformers_version = version.parse(xformers.__version__)
779
+ if xformers_version == version.parse("0.0.16"):
780
+ logger.warning(
781
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
782
+ )
783
+ unet.enable_xformers_memory_efficient_attention()
784
+ controlnet.enable_xformers_memory_efficient_attention()
785
+ else:
786
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
787
+
788
+ if args.gradient_checkpointing:
789
+ unet.enable_gradient_checkpointing()
790
+ controlnet.enable_gradient_checkpointing()
791
+
792
+ # Check that all trainable models are in full precision
793
+ low_precision_error_string = (
794
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
795
+ " doing mixed precision training, copy of the weights should still be float32."
796
+ )
797
+
798
+ if accelerator.unwrap_model(controlnet).dtype != torch.float32:
799
+ raise ValueError(
800
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
801
+ )
802
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
803
+ raise ValueError(
804
+ f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
805
+ )
806
+
807
+
808
+ # Enable TF32 for faster training on Ampere GPUs,
809
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
810
+ if args.allow_tf32:
811
+ torch.backends.cuda.matmul.allow_tf32 = True
812
+
813
+ if args.scale_lr:
814
+ args.learning_rate = (
815
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
816
+ )
817
+
818
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
819
+ if args.use_8bit_adam:
820
+ try:
821
+ import bitsandbytes as bnb
822
+ except ImportError:
823
+ raise ImportError(
824
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
825
+ )
826
+
827
+ optimizer_class = bnb.optim.AdamW8bit
828
+ else:
829
+ optimizer_class = torch.optim.AdamW
830
+
831
+ # Optimizer creation
832
+ print(f'=================Optimize ControlNet and Unet ======================')
833
+ params_to_optimize = list(controlnet.parameters()) + list(unet.parameters())
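+ # The full UNet parameter list is handed to the optimizer, but only the modules released above
+ # (requires_grad=True) ever get gradients; the frozen parameters are simply skipped at step time.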
834
+
835
+
836
+ print(f'start to load optimizer...')
837
+
838
+ optimizer = optimizer_class(
839
+ params_to_optimize,
840
+ lr=args.learning_rate,
841
+ betas=(args.adam_beta1, args.adam_beta2),
842
+ weight_decay=args.adam_weight_decay,
843
+ eps=args.adam_epsilon,
844
+ )
845
+
846
+ train_dataset = PairedCaptionDataset(root_folders=args.root_folders,
847
+ tokenizer=tokenizer,
848
+ null_text_ratio=args.null_text_ratio,
849
+ )
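+ # PairedCaptionDataset is assumed to yield the batch keys used below ("pixel_values", "conditioning_pixel_values",
+ # "ram_values", "input_ids"), with --null_text_ratio presumably controlling how often the caption is dropped to an empty prompt.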
850
+
851
+ train_dataloader = torch.utils.data.DataLoader(
852
+ train_dataset,
853
+ num_workers=args.dataloader_num_workers,
854
+ batch_size=args.train_batch_size,
855
+ shuffle=False
856
+ )
857
+
858
+
859
+ # Scheduler and math around the number of training steps.
860
+ overrode_max_train_steps = False
861
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
862
+ if args.max_train_steps is None:
863
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
864
+ overrode_max_train_steps = True
865
+
866
+ lr_scheduler = get_scheduler(
867
+ args.lr_scheduler,
868
+ optimizer=optimizer,
869
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
870
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
871
+ num_cycles=args.lr_num_cycles,
872
+ power=args.lr_power,
873
+ )
874
+
875
+ # Prepare everything with our `accelerator`.
876
+ controlnet, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
877
+ controlnet, unet, optimizer, train_dataloader, lr_scheduler
878
+ )
879
+
880
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
881
+ # as these models are only used for inference, keeping weights in full precision is not required.
882
+ weight_dtype = torch.float32
883
+ if accelerator.mixed_precision == "fp16":
884
+ weight_dtype = torch.float16
885
+ elif accelerator.mixed_precision == "bf16":
886
+ weight_dtype = torch.bfloat16
887
+
888
+ # Move the frozen vae, text_encoder and RAM to device and cast to weight_dtype
889
+ vae.to(accelerator.device, dtype=weight_dtype)
890
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
891
+ RAM.to(accelerator.device, dtype=weight_dtype)
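+ # Only the frozen helpers are cast to the mixed-precision dtype; the trainable ControlNet/UNet stay in
+ # fp32 (as asserted above) and are handled by accelerator.prepare.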
892
+
893
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
894
+ if overrode_max_train_steps:
895
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
896
+ # Afterwards we recalculate our number of training epochs
897
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
898
+
899
+
900
+ # We need to initialize the trackers we use, and also store our configuration.
901
+ # The trackers initializes automatically on the main process.
902
+ if accelerator.is_main_process:
903
+ tracker_config = dict(vars(args))
904
+
905
+ # tensorboard cannot handle list types for config
906
+ tracker_config.pop("validation_prompt")
907
+ tracker_config.pop("validation_image")
908
+ tracker_config.pop("trainable_modules")
909
+
910
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
911
+
912
+ # Train!
913
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
914
+
915
+ logger.info("***** Running training *****")
916
+ # if not isinstance(train_dataset, WebImageDataset):
917
+ # logger.info(f" Num examples = {len(train_dataset)}")
918
+ # logger.info(f" Num batches each epoch = {len(train_dataloader)}")
919
+
920
+
921
+ logger.info(f" Num examples = {len(train_dataset)}")
922
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
923
+
924
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
925
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
926
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
927
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
928
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
929
+ global_step = 0
930
+ first_epoch = 0
931
+
932
+ # Potentially load in the weights and states from a previous save
933
+ if args.resume_from_checkpoint:
934
+ if args.resume_from_checkpoint != "latest":
935
+ path = os.path.basename(args.resume_from_checkpoint)
936
+ else:
937
+ # Get the most recent checkpoint
938
+ dirs = os.listdir(args.output_dir)
939
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
940
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
941
+ path = dirs[-1] if len(dirs) > 0 else None
942
+
943
+ if path is None:
944
+ accelerator.print(
945
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
946
+ )
947
+ args.resume_from_checkpoint = None
948
+ initial_global_step = 0
949
+ else:
950
+ accelerator.print(f"Resuming from checkpoint {path}")
951
+ accelerator.load_state(os.path.join(args.output_dir, path))
952
+ global_step = int(path.split("-")[1])
953
+
954
+ initial_global_step = global_step
955
+ first_epoch = global_step // num_update_steps_per_epoch
956
+ else:
957
+ initial_global_step = 0
958
+
959
+ progress_bar = tqdm(
960
+ range(0, args.max_train_steps),
961
+ initial=initial_global_step,
962
+ desc="Steps",
963
+ # Only show the progress bar once on each machine.
964
+ disable=not accelerator.is_local_main_process,
965
+ )
966
+
967
+
968
+ for epoch in range(first_epoch, args.num_train_epochs):
969
+ for step, batch in enumerate(train_dataloader):
970
+ # with accelerator.accumulate(controlnet):
971
+ with accelerator.accumulate(controlnet), accelerator.accumulate(unet):
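+ # Nesting accumulate() for both models keeps gradient accumulation in sync for the ControlNet
+ # and the partially trainable UNet.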
972
+ pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)
973
+ # Convert images to latent space
974
+ latents = vae.encode(pixel_values).latent_dist.sample()
975
+ latents = latents * vae.config.scaling_factor
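+ # vae.config.scaling_factor (0.18215 for the SD 2.x VAE) rescales the latents to the variance
+ # the noise scheduler expects.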
976
+
977
+ # Sample noise that we'll add to the latents
978
+ noise = torch.randn_like(latents)
979
+ bsz = latents.shape[0]
980
+ # Sample a random timestep for each image
981
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
982
+ timesteps = timesteps.long()
983
+
984
+ # Add noise to the latents according to the noise magnitude at each timestep
985
+ # (this is the forward diffusion process)
986
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
987
+
988
+ # Get the text embedding for conditioning
989
+ encoder_hidden_states = text_encoder(batch["input_ids"].to(accelerator.device))[0]
990
+
991
+ controlnet_image = batch["conditioning_pixel_values"].to(accelerator.device, dtype=weight_dtype)
992
+
993
+ # extract soft semantic label
994
+ with torch.no_grad():
995
+ ram_image = batch["ram_values"].to(accelerator.device, dtype=weight_dtype)
996
+ ram_encoder_hidden_states = RAM.generate_image_embeds(ram_image)
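+ # These soft-tag image embeddings are passed as image_encoder_hidden_states to both the ControlNet
+ # and the UNet below.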
997
+
998
+ down_block_res_samples, mid_block_res_sample = controlnet(
999
+ noisy_latents,
1000
+ timesteps,
1001
+ encoder_hidden_states=encoder_hidden_states,
1002
+ controlnet_cond=controlnet_image,
1003
+ return_dict=False,
1004
+ image_encoder_hidden_states=ram_encoder_hidden_states,
1005
+ )
1006
+
1007
+ # Predict the noise residual
1008
+ model_pred = unet(
1009
+ noisy_latents,
1010
+ timesteps,
1011
+ encoder_hidden_states=encoder_hidden_states,
1012
+ down_block_additional_residuals=[
1013
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1014
+ ],
1015
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1016
+ image_encoder_hidden_states=ram_encoder_hidden_states,
1017
+ ).sample
1018
+
1019
+ # Get the target for loss depending on the prediction type
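+ # For stable-diffusion-2-1-base the scheduler is configured for epsilon prediction, so the target is
+ # normally the sampled noise itself.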
1020
+ if noise_scheduler.config.prediction_type == "epsilon":
1021
+ target = noise
1022
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1023
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
1024
+ else:
1025
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1026
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1027
+
1028
+ accelerator.backward(loss)
1029
+ if accelerator.sync_gradients:
1030
+ # params_to_clip = controlnet.parameters()
1031
+ params_to_clip = list(controlnet.parameters()) + list(unet.parameters())
1032
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1033
+ optimizer.step()
1034
+ lr_scheduler.step()
1035
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1036
+
1037
+ # Checks if the accelerator has performed an optimization step behind the scenes
1038
+ if accelerator.sync_gradients:
1039
+ progress_bar.update(1)
1040
+ global_step += 1
1041
+
1042
+ if accelerator.is_main_process:
1043
+ if global_step % args.checkpointing_steps == 0:
1044
+
1045
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1046
+ accelerator.save_state(save_path)
1047
+ logger.info(f"Saved state to {save_path}")
1048
+
1049
+ # if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1050
+ if False:
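+ # Validation is hard-disabled here; restore the commented condition above to run log_validation
+ # every --validation_steps.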
1051
+ image_logs = log_validation(
1052
+ vae,
1053
+ text_encoder,
1054
+ tokenizer,
1055
+ unet,
1056
+ controlnet,
1057
+ args,
1058
+ accelerator,
1059
+ weight_dtype,
1060
+ global_step,
1061
+ )
1062
+
1063
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1064
+ progress_bar.set_postfix(**logs)
1065
+ accelerator.log(logs, step=global_step)
1066
+
1067
+ if global_step >= args.max_train_steps:
1068
+ break
1069
+
1070
+ # Create the pipeline using the trained modules and save it.
1071
+ accelerator.wait_for_everyone()
1072
+ if accelerator.is_main_process:
1073
+ controlnet = accelerator.unwrap_model(controlnet)
1074
+ controlnet.save_pretrained(args.output_dir)
1075
+
1076
+ unet = accelerator.unwrap_model(unet)
1077
+ unet.save_pretrained(args.output_dir)
1078
+
1079
+ if args.push_to_hub:
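+ # Note: image_logs is only assigned inside the (disabled) validation branch above, so it would need
+ # to be defined before save_model_card can run.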
1080
+ save_model_card(
1081
+ repo_id,
1082
+ image_logs=image_logs,
1083
+ base_model=args.pretrained_model_name_or_path,
1084
+ repo_folder=args.output_dir,
1085
+ )
1086
+ upload_folder(
1087
+ repo_id=repo_id,
1088
+ folder_path=args.output_dir,
1089
+ commit_message="End of training",
1090
+ ignore_patterns=["step_*", "epoch_*"],
1091
+ )
1092
+
1093
+ accelerator.end_training()