diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..d72f32a6d3d3e31c67cad1aa11988cade1862349 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +ckpt/* filter=lfs diff=lfs merge=lfs -text +hf_demo/examples/* filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/.idea/HarmonyView.iml b/.idea/HarmonyView.iml new file mode 100644 index 0000000000000000000000000000000000000000..d0876a78d06ac03b5d78c8dcdb95570281c6f1d6 --- /dev/null +++ b/.idea/HarmonyView.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000000000000000000000000000000000000..aa6c95a57c42b422a129738cb6366038ca9fcf61 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,103 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000000000000000000000000000000000000..1ad2a344d00dd6cf32075d0144fddf17fbbe4065 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..d24c28f2b182c329ef5e8e128bba0ae991a8539a --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..35eb1ddfbbc029bcab630581847471d7f238ec53 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml new file mode 100644 index 0000000000000000000000000000000000000000..c33801371a93378cd31fe4d0a18a35636e5e4838 --- /dev/null +++ b/.idea/workspace.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + { + "associatedIndex": 6 +} + + + + { + "keyToString": { + "RunOnceActivity.OpenProjectViewOnStart": "true", + "RunOnceActivity.ShowReadmeOnStart": "true", + "git-widget-placeholder": "main", + "last_opened_file_path": "/home/byeongjun/PycharmProjects/HarmonyView" + } +} + + + + + + + + + + + + + + + + 1703058146297 + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index 0768ed3b5d774be4d1bd96bcfbdaf47029a2caaa..9e81291a999adbc1615590421820ba1379ed02df 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ --- title: HarmonyView -emoji: 🏃 -colorFrom: gray -colorTo: red +emoji: 🚀 +colorFrom: indigo +colorTo: pink sdk: gradio -sdk_version: 4.11.0 +sdk_version: 3.43.2 app_file: app.py pinned: false -license: mit +license: cc-by-sa-3.0 --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 
0000000000000000000000000000000000000000..716135fb2072b3fe2a3fba6111d4683e1ec30761 --- /dev/null +++ b/app.py @@ -0,0 +1,261 @@ +from functools import partial + +from PIL import Image +import numpy as np +import gradio as gr +import torch +import os +import fire +from omegaconf import OmegaConf + +from ldm.models.diffusion.sync_dreamer import SyncDDIMSampler, SyncMultiviewDiffusion +from ldm.util import add_margin, instantiate_from_config +from sam_utils import sam_init, sam_out_nosave + +import torch +_TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image''' +_DESCRIPTION = ''' +
+ + + +
+Given a single-view image, SyncDreamer generates multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss.
+ +Procedure:
+**Step 1**. Upload an image or select an example. ==> The foreground is masked out by SAM and cropped as the input.
+**Step 2**. Select an "Elevation angle" and click "Run generation". ==> Generate multiview images. The **Elevation angle** is the elevation of the input image. (This takes about 30s.)
+You may adjust the **Crop size** and **Elevation angle** to get a better result!
+To reconstruct a NeRF or a 3D mesh from the generated images, please refer to our [github repository](https://github.com/liuyuan-pal/SyncDreamer).
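+To run the same generation offline, you can use `generate.py` from the repository; a minimal sketch (here `input.png` and `output/mv` are placeholder paths, and the flags mirror the argument parser in `generate.py`):
+```bash
+python generate.py --ckpt ckpt/syncdreamer-pretrain.ckpt \
+  --input input.png --elevation 30 --crop_size 200 --output output/mv
+```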
+We have heavily borrowed code from [One-2-3-45](https://huggingface.co/spaces/One-2-3-45/One-2-3-45), which is also an amazing single-view reconstruction method. +''' +_USER_GUIDE0 = "Step 1: Please upload an image in the block above (or choose an example shown on the left)." +# _USER_GUIDE1 = "Step 1: Please select a **Crop size** and click **Crop it**." +_USER_GUIDE2 = "Step 2: Please choose an **Elevation angle** and click **Run generation**. The **Elevation angle** is the elevation of the input image. This takes about 30s." +_USER_GUIDE3 = "Generated multiview images are shown below! (You may adjust the **Crop size** and **Elevation angle** to get a better result!)" + +others = '''**Step 1**. Select "Crop size" and click "Crop it". ==> The foreground object is centered and resized.
''' + +deployed = True + +if deployed: + print(f"Is CUDA available: {torch.cuda.is_available()}") + print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") + + +class BackgroundRemoval: + def __init__(self, device='cuda'): + from carvekit.api.high import HiInterface + self.interface = HiInterface( + object_type="object", # Can be "object" or "hairs-like". + batch_size_seg=5, + batch_size_matting=1, + device=device, + seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net + matting_mask_size=2048, + trimap_prob_threshold=231, + trimap_dilation=30, + trimap_erosion_iters=5, + fp16=True, + ) + + @torch.no_grad() + def __call__(self, image): + # image: [H, W, 3] array in [0, 255]. + image = self.interface([image])[0] + return image + +def resize_inputs(image_input, crop_size): + if image_input is None: return None + alpha_np = np.asarray(image_input)[:, :, 3] + coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)] + min_x, min_y = np.min(coords, 0) + max_x, max_y = np.max(coords, 0) + ref_img_ = image_input.crop((min_x, min_y, max_x, max_y)) + h, w = ref_img_.height, ref_img_.width + scale = crop_size / max(h, w) + h_, w_ = int(scale * h), int(scale * w) + ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC) + results = add_margin(ref_img_, size=256) + return results + +def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input): + if deployed: + assert isinstance(model, SyncMultiviewDiffusion) + seed=int(seed) + torch.random.manual_seed(seed) + np.random.seed(seed) + + # prepare data + image_input = np.asarray(image_input) + image_input = image_input.astype(np.float32) / 255.0 + alpha_values = image_input[:,:, 3:] + image_input[:, :, :3] = alpha_values * image_input[:,:, :3] + 1 - alpha_values # white background + image_input = image_input[:, :, :3] * 2.0 - 1.0 + image_input = torch.from_numpy(image_input.astype(np.float32)) + elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32)) + data = {"input_image": image_input, "input_elevation": elevation_input} + for k, v in data.items(): + if deployed: + data[k] = v.unsqueeze(0).cuda() + else: + data[k] = v.unsqueeze(0) + data[k] = torch.repeat_interleave(data[k], sample_num, dim=0) + + if deployed: + sampler = SyncDDIMSampler(model, sample_steps) + x_sample = model.sample(sampler, data, cfg_scale, batch_view_num) + else: + x_sample = torch.zeros(sample_num, 16, 3, 256, 256) + + B, N, _, H, W = x_sample.shape + x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5 + x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255 + x_sample = x_sample.astype(np.uint8) + + results = [] + for bi in range(B): + results.append(np.concatenate([x_sample[bi,ni] for ni in range(N)], 1)) + results = np.concatenate(results, 0) + return Image.fromarray(results) + else: + return Image.fromarray(np.zeros([sample_num*256,16*256,3],np.uint8)) + + +def sam_predict(predictor, removal, raw_im): + if raw_im is None: return None + if deployed: + raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS) + image_nobg = removal(raw_im.convert('RGB')) + arr = np.asarray(image_nobg)[:, :, -1] + x_nonzero = np.nonzero(arr.sum(axis=0)) + y_nonzero = np.nonzero(arr.sum(axis=1)) + x_min = int(x_nonzero[0].min()) + y_min = int(y_nonzero[0].min()) + x_max = int(x_nonzero[0].max()) + y_max = int(y_nonzero[0].max()) + # image_nobg.save('./nobg.png') + + image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS) + image_sam = sam_out_nosave(predictor, 
image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max)) + + # imsave('./mask.png', np.asarray(image_sam)[:,:,3]*255) + image_sam = np.asarray(image_sam, np.float32) / 255 + out_mask = image_sam[:, :, 3:] + out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask + out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8) + + image_sam = Image.fromarray(out_img, mode='RGBA') + # image_sam.save('./output.png') + torch.cuda.empty_cache() + return image_sam + else: + return raw_im + +def run_demo(): + # device = f"cuda:0" if torch.cuda.is_available() else "cpu" + # models = None # init_model(device, os.path.join(code_dir, ckpt)) + cfg = 'configs/syncdreamer.yaml' + ckpt = 'ckpt/syncdreamer-pretrain.ckpt' + config = OmegaConf.load(cfg) + # model = None + if deployed: + model = instantiate_from_config(config.model) + print(f'loading model from {ckpt} ...') + ckpt = torch.load(ckpt,map_location='cpu') + model.load_state_dict(ckpt['state_dict'], strict=True) + model = model.cuda().eval() + del ckpt + mask_predictor = sam_init() + removal = BackgroundRemoval() + else: + model = None + mask_predictor = None + removal = None + + # NOTE: Examples must match inputs + examples_full = [ + ['hf_demo/examples/monkey.png',30,200], + ['hf_demo/examples/cat.png',30,200], + ['hf_demo/examples/crab.png',30,200], + ['hf_demo/examples/elephant.png',30,200], + ['hf_demo/examples/flower.png',0,200], + ['hf_demo/examples/forest.png',30,200], + ['hf_demo/examples/teapot.png',20,200], + ['hf_demo/examples/basket.png',30,200], + ] + + image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True) + elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle of the input image', interactive=True) + crop_size = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True) + + # Compose demo layout & data flow. + with gr.Blocks(title=_TITLE, css="hf_demo/style.css") as demo: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + # with gr.Column(scale=0): + # gr.DuplicateButton(value='Duplicate Space for private use', elem_id='duplicate-button') + gr.Markdown(_DESCRIPTION) + + with gr.Row(variant='panel'): + with gr.Column(scale=1.2): + gr.Examples( + examples=examples_full, # NOTE: elements must match inputs list! 
+ inputs=[image_block, elevation, crop_size], + outputs=[image_block, elevation, crop_size], + cache_examples=False, + label='Examples (click one of the images below to start)', + examples_per_page=5, + ) + + with gr.Column(scale=0.8): + image_block.render() + guide_text = gr.Markdown(_USER_GUIDE0, visible=True) + fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False) + + + with gr.Column(scale=0.8): + sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False) + crop_size.render() + # crop_btn = gr.Button('Crop it', variant='primary', interactive=True) + fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False) + + with gr.Column(scale=0.8): + input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False) + elevation.render() + with gr.Accordion('Advanced options', open=False): + cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True) + sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=False, info='How many instance (16 images per instance)') + sample_steps = gr.Slider(10, 300, 50, step=10, label='Sample steps', interactive=False) + batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True) + seed = gr.Number(6033, label='Random seed', interactive=True) + run_btn = gr.Button('Run generation', variant='primary', interactive=True) + + + output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False) + + def update_guide2(text, im): + if im is None: + return _USER_GUIDE0 + else: + return text + update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT) + + image_block.clear(fn=partial(update_guide, _USER_GUIDE0), outputs=[guide_text], queue=False) + image_block.change(fn=partial(sam_predict, mask_predictor, removal), inputs=[image_block], outputs=[sam_block], queue=True) \ + .success(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=True)\ + .success(fn=partial(update_guide2, _USER_GUIDE2), inputs=[image_block], outputs=[guide_text], queue=False)\ + + crop_size.change(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=True)\ + .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False) + # crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\ + # .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False) + + run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=True)\ + .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False) + + demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD']) + +if __name__=="__main__": + fire.Fire(run_demo) \ No newline at end of file diff --git a/assets/crop_size.jpg b/assets/crop_size.jpg new file mode 100644 index 0000000000000000000000000000000000000000..26f8db97cf755c0a6d03c3d47bbb60ec539d44e8 Binary files /dev/null and b/assets/crop_size.jpg differ diff --git a/assets/elevation.jpg b/assets/elevation.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0f5118802648434faa7adda4562587c95175be37 Binary files 
/dev/null and b/assets/elevation.jpg differ diff --git a/assets/teaser.jpg b/assets/teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..33effe4df3549204ea787b93382a428888aaf910 Binary files /dev/null and b/assets/teaser.jpg differ diff --git a/ckpt/ViT-L-14.pt b/ckpt/ViT-L-14.pt new file mode 100644 index 0000000000000000000000000000000000000000..a68f290e87e7dbf598480a76939629f6b1f08fc3 --- /dev/null +++ b/ckpt/ViT-L-14.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836 +size 932768134 diff --git a/ckpt/sam_vit_h_4b8939.pth b/ckpt/sam_vit_h_4b8939.pth new file mode 100644 index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72 --- /dev/null +++ b/ckpt/sam_vit_h_4b8939.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e +size 2564550879 diff --git a/ckpt/syncdreamer-pretrain.ckpt b/ckpt/syncdreamer-pretrain.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..e5c2fd3240c15254a5b5e50cf9930e6fe5d1cae6 --- /dev/null +++ b/ckpt/syncdreamer-pretrain.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ebb31334d9e4002b2590dd805e25238beaf95fa082f6e39a132344624448dcb +size 5570034171 diff --git a/configs/nerf.yaml b/configs/nerf.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75c11dca399f806040c2a1d42bda9d42b4533be3 --- /dev/null +++ b/configs/nerf.yaml @@ -0,0 +1,25 @@ +model: + base_lr: 1.0e-2 + target: renderer.renderer.RendererTrainer + params: + total_steps: 2000 + warm_up_steps: 100 + train_batch_num: 40960 + test_batch_num: 40960 + renderer: ngp + cube_bound: 0.6 + use_mask: true + lambda_rgb_loss: 0.5 + lambda_mask_loss: 10.0 + +data: + target: renderer.dummy_dataset.DummyDataset + params: {} + +callbacks: + save_interval: 5000 + +trainer: + val_check_interval: 500 + max_steps: 2000 + diff --git a/configs/neus.yaml b/configs/neus.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72541f22ceb665e8834843f4b38d128f91a264fd --- /dev/null +++ b/configs/neus.yaml @@ -0,0 +1,26 @@ +model: + base_lr: 5.0e-4 + target: renderer.renderer.RendererTrainer + params: + total_steps: 2000 + warm_up_steps: 100 + train_batch_num: 3584 + train_batch_fg_num: 512 + test_batch_num: 4096 + use_mask: true + lambda_rgb_loss: 0.5 + lambda_mask_loss: 1.0 + lambda_eikonal_loss: 0.1 + use_warm_up: true + +data: + target: renderer.dummy_dataset.DummyDataset + params: {} + +callbacks: + save_interval: 500 + +trainer: + val_check_interval: 500 + max_steps: 2000 + diff --git a/configs/syncdreamer-train.yaml b/configs/syncdreamer-train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd2ecc12d19a318a3f8bd4547a8d8ab452643b9d --- /dev/null +++ b/configs/syncdreamer-train.yaml @@ -0,0 +1,63 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.sync_dreamer.SyncMultiviewDiffusion + params: + view_num: 16 + image_size: 256 + cfg_scale: 2.0 + output_num: 8 + batch_view_num: 4 + finetune_unet: false + finetune_projection: false + drop_conditions: false + clip_image_encoder_path: ckpt/ViT-L-14.pt + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 100 ] + cycle_lengths: [ 100000 ] + f_start: [ 0.02 ] + f_max: [ 1.0 ] + f_min: [ 1.0 ] + + unet_config: + target: 
ldm.models.diffusion.sync_dreamer_attention.DepthWiseAttention + params: + volume_dims: [64, 128, 256, 512] + image_size: 32 + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + +data: + target: ldm.data.sync_dreamer.SyncDreamerDataset + params: + target_dir: training_examples/target # renderings of target views + input_dir: training_examples/input # renderings of input views + uid_set_pkl: training_examples/uid_set.pkl # a list of uids + validation_dir: validation_set # directory of validation data + batch_size: 24 # batch size for a single gpu + num_workers: 8 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 1000 # we will save models every 1k steps + callbacks: + {} + + trainer: + benchmark: True + val_check_interval: 1000 # we will run validation every 1k steps, the validation will output images to //val + num_sanity_val_steps: 0 + check_val_every_n_epoch: null diff --git a/configs/syncdreamer.yaml b/configs/syncdreamer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0538633d85d822549294593a940895e778336d40 --- /dev/null +++ b/configs/syncdreamer.yaml @@ -0,0 +1,45 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.sync_dreamer.SyncMultiviewDiffusion + params: + view_num: 16 + image_size: 256 + cfg_scale: 2.0 + output_num: 8 + batch_view_num: 4 + finetune_unet: false + finetune_projection: false + drop_conditions: false + clip_image_encoder_path: ckpt/ViT-L-14.pt + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 100 ] + cycle_lengths: [ 100000 ] + f_start: [ 0.02 ] + f_max: [ 1.0 ] + f_min: [ 1.0 ] + + unet_config: + target: ldm.models.diffusion.sync_dreamer_attention.DepthWiseAttention + params: + volume_dims: [64, 128, 256, 512] + image_size: 32 + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + +data: {} + +lightning: + trainer: {} diff --git a/examples/monkey.png b/examples/monkey.png new file mode 100644 index 0000000000000000000000000000000000000000..8436295a6209bc5a12be57a2f9987fe24d88ead2 Binary files /dev/null and b/examples/monkey.png differ diff --git a/generate.py b/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..4a1a30c9909e4f3f977a6cb020239ca143ee599d --- /dev/null +++ b/generate.py @@ -0,0 +1,62 @@ +import argparse +from pathlib import Path + +import numpy as np +import torch +from omegaconf import OmegaConf +from skimage.io import imsave + +from ldm.models.diffusion.sync_dreamer import SyncMultiviewDiffusion +from ldm.util import instantiate_from_config, prepare_inputs + + +def load_model(cfg,ckpt,strict=True): + config = OmegaConf.load(cfg) + model = instantiate_from_config(config.model) + print(f'loading model from {ckpt} ...') + ckpt = torch.load(ckpt,map_location='cpu') + model.load_state_dict(ckpt['state_dict'],strict=strict) + model = model.cuda().eval() + return model + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--cfg',type=str, default='configs/syncdreamer.yaml') + parser.add_argument('--ckpt',type=str, 
default='ckpt/syncdreamer-step80k.ckpt') + parser.add_argument('--output', type=str, required=True) + parser.add_argument('--input', type=str, required=True) + parser.add_argument('--elevation', type=float, required=True) + + parser.add_argument('--sample_num', type=int, default=4) + parser.add_argument('--crop_size', type=int, default=-1) + parser.add_argument('--cfg_scale', type=float, default=2.0) + parser.add_argument('--batch_view_num', type=int, default=8) + parser.add_argument('--seed', type=int, default=6033) + flags = parser.parse_args() + + torch.random.manual_seed(flags.seed) + np.random.seed(flags.seed) + + model = load_model(flags.cfg, flags.ckpt, strict=True) + assert isinstance(model, SyncMultiviewDiffusion) + Path(f'{flags.output}').mkdir(exist_ok=True, parents=True) + + # prepare data + data = prepare_inputs(flags.input, flags.elevation, flags.crop_size) + for k, v in data.items(): + data[k] = v.unsqueeze(0).cuda() + data[k] = torch.repeat_interleave(data[k], flags.sample_num, dim=0) + x_sample = model.sample(data, flags.cfg_scale, flags.batch_view_num) + + B, N, _, H, W = x_sample.shape + x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5 + x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255 + x_sample = x_sample.astype(np.uint8) + + for bi in range(B): + output_fn = Path(flags.output)/ f'{bi}.png' + imsave(output_fn, np.concatenate([x_sample[bi,ni] for ni in range(N)], 1)) + +if __name__=="__main__": + main() + diff --git a/hf_demo/examples/basket.png b/hf_demo/examples/basket.png new file mode 100644 index 0000000000000000000000000000000000000000..206e8b813984790366dc245728742bda6654e4ee --- /dev/null +++ b/hf_demo/examples/basket.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b7d07f44e1b223b5f3f6e97bf1e64198dbc63a020e860d1ffc177a5a42e7bd9 +size 46071 diff --git a/hf_demo/style.css b/hf_demo/style.css new file mode 100644 index 0000000000000000000000000000000000000000..031f78fdb75e7c517d62f6b9e240828ee4b6a912 --- /dev/null +++ b/hf_demo/style.css @@ -0,0 +1,33 @@ +#model-3d-out { + height: 400px; +} + +#plot-out { + height: 450px; +} + +#duplicate-button { + margin-left: auto; + color: #fff; + background: #1565c0; + } + +.footer { + margin-bottom: 45px; + margin-top: 10px; + text-align: center; + border-bottom: 1px solid #e5e5e5; +} +.footer>p { + font-size: .8rem; + display: inline-block; + padding: 0 10px; + transform: translateY(15px); + background: white; +} +.dark .footer { + border-color: #303030; +} +.dark .footer>p { + background: #0b0f19; +} \ No newline at end of file diff --git a/ldm/base_utils.py b/ldm/base_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4b6843946aeae1feecccb15a7068111eb47205 --- /dev/null +++ b/ldm/base_utils.py @@ -0,0 +1,158 @@ +import pickle +import numpy as np +import cv2 +from skimage.io import imread + + +def save_pickle(data, pkl_path): + # os.system('mkdir -p {}'.format(os.path.dirname(pkl_path))) + with open(pkl_path, 'wb') as f: + pickle.dump(data, f) + +def read_pickle(pkl_path): + with open(pkl_path, 'rb') as f: + return pickle.load(f) + +def draw_epipolar_line(F, img0, img1, pt0, color): + h1,w1=img1.shape[:2] + hpt = np.asarray([pt0[0], pt0[1], 1], dtype=np.float32)[:, None] + l = F @ hpt + l = l[:, 0] + a, b, c = l[0], l[1], l[2] + pt1 = np.asarray([0, -c / b]).astype(np.int32) + pt2 = np.asarray([w1, (-a * w1 - c) / b]).astype(np.int32) + + img0 = cv2.circle(img0, tuple(pt0.astype(np.int32)), 5, color, 2) + img1 = cv2.line(img1, tuple(pt1), 
tuple(pt2), color, 2) + return img0, img1 + +def draw_epipolar_lines(F, img0, img1,num=20): + img0,img1=img0.copy(),img1.copy() + h0, w0, _ = img0.shape + h1, w1, _ = img1.shape + + for k in range(num): + color = np.random.randint(0, 255, [3], dtype=np.int32) + color = [int(c) for c in color] + pt = np.random.uniform(0, 1, 2) + pt[0] *= w0 + pt[1] *= h0 + pt = pt.astype(np.int32) + img0, img1 = draw_epipolar_line(F, img0, img1, pt, color) + + return img0, img1 + +def compute_F(K1, K2, Rt0, Rt1=None): + if Rt1 is None: + R, t = Rt0[:,:3], Rt0[:,3:] + else: + Rt = compute_dR_dt(Rt0,Rt1) + R, t = Rt[:,:3], Rt[:,3:] + A = K1 @ R.T @ t # [3,1] + C = np.asarray([[0,-A[2,0],A[1,0]], + [A[2,0],0,-A[0,0]], + [-A[1,0],A[0,0],0]]) + F = (np.linalg.inv(K2)).T @ R @ K1.T @ C + return F + +def compute_dR_dt(Rt0, Rt1): + R0, t0 = Rt0[:,:3], Rt0[:,3:] + R1, t1 = Rt1[:,:3], Rt1[:,3:] + dR = np.dot(R1, R0.T) + dt = t1 - np.dot(dR, t0) + return np.concatenate([dR, dt], -1) + +def concat_images(img0,img1,vert=False): + if not vert: + h0,h1=img0.shape[0],img1.shape[0], + if h00) + if np.sum(mask0)>0: dpt[mask0]=1e-4 + mask1=(np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0) + if np.sum(mask1)>0: dpt[mask1]=-1e-4 + pts2d = pts[:,:2]/dpt[:,None] + return pts2d, dpt + + +def draw_keypoints(img, kps, colors=None, radius=2): + out_img=img.copy() + for pi, pt in enumerate(kps): + pt = np.round(pt).astype(np.int32) + if colors is not None: + color=[int(c) for c in colors[pi]] + cv2.circle(out_img, tuple(pt), radius, color, -1) + else: + cv2.circle(out_img, tuple(pt), radius, (0,255,0), -1) + return out_img + + +def output_points(fn,pts,colors=None): + with open(fn, 'w') as f: + for pi, pt in enumerate(pts): + f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ') + if colors is not None: + f.write(f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}') + f.write('\n') + +DEPTH_MAX, DEPTH_MIN = 2.4, 0.6 +DEPTH_VALID_MAX, DEPTH_VALID_MIN = 2.37, 0.63 +def read_depth_objaverse(depth_fn): + depth = imread(depth_fn) + depth = depth.astype(np.float32) / 65535 * (DEPTH_MAX-DEPTH_MIN) + DEPTH_MIN + mask = (depth > DEPTH_VALID_MIN) & (depth < DEPTH_VALID_MAX) + return depth, mask + + +def mask_depth_to_pts(mask,depth,K,rgb=None): + hs,ws=np.nonzero(mask) + depth=depth[hs,ws] + pts=np.asarray([ws,hs,depth],np.float32).transpose() + pts[:,:2]*=pts[:,2:] + if rgb is not None: + return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs,ws] + else: + return np.dot(pts, np.linalg.inv(K).transpose()) + +def transform_points_pose(pts, pose): + R, t = pose[:, :3], pose[:, 3] + if len(pts.shape)==1: + return (R @ pts[:,None] + t[:,None])[:,0] + return pts @ R.T + t[None,:] + +def pose_apply(pose,pts): + return transform_points_pose(pts, pose) + +def downsample_gaussian_blur(img, ratio): + sigma = (1 / ratio) / 3 + # ksize=np.ceil(2*sigma) + ksize = int(np.ceil(((sigma - 0.8) / 0.3 + 1) * 2 + 1)) + ksize = ksize + 1 if ksize % 2 == 0 else ksize + img = cv2.GaussianBlur(img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101) + return img \ No newline at end of file diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/data/base.py b/ldm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..742794e631081bbfa7c44f3df6f83373ca5c15c1 --- /dev/null +++ b/ldm/data/base.py @@ -0,0 +1,40 @@ +import os +import numpy as np +from abc import abstractmethod +from torch.utils.data import Dataset, 
ConcatDataset, ChainDataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, num_records=0, valid_ids=None, size=256): + super().__init__() + self.num_records = num_records + self.valid_ids = valid_ids + self.sample_ids = valid_ids + self.size = size + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + return self.num_records + + @abstractmethod + def __iter__(self): + pass + + +class PRNGMixin(object): + """ + Adds a prng property which is a numpy RandomState which gets + reinitialized whenever the pid changes to avoid synchronized sampling + behavior when used in conjunction with multiprocessing. + """ + @property + def prng(self): + currentpid = os.getpid() + if getattr(self, "_initpid", None) != currentpid: + self._initpid = currentpid + self._prng = np.random.RandomState() + return self._prng diff --git a/ldm/data/coco.py b/ldm/data/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5e27e6ec6a51932f67b83dd88533cb39631e26 --- /dev/null +++ b/ldm/data/coco.py @@ -0,0 +1,253 @@ +import os +import json +import albumentations +import numpy as np +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset +from abc import abstractmethod + + +class CocoBase(Dataset): + """needed for (image, caption, segmentation) pairs""" + def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, + crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True,crop_type=None): + self.split = self.get_split() + self.size = size + if crop_size is None: + self.crop_size = size + else: + self.crop_size = crop_size + + assert crop_type in [None, 'random', 'center'] + self.crop_type = crop_type + self.use_segmenation = use_segmentation + self.onehot = onehot_segmentation # return segmentation as rgb or one hot + self.stuffthing = use_stuffthing # include thing in segmentation + if self.onehot and not self.stuffthing: + raise NotImplemented("One hot mode is only supported for the " + "stuffthings version because labels are stored " + "a bit different.") + + data_json = datajson + with open(data_json) as json_file: + self.json_data = json.load(json_file) + self.img_id_to_captions = dict() + self.img_id_to_filepath = dict() + self.img_id_to_segmentation_filepath = dict() + + assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json", + f"captions_val{self.year()}.json"] + # TODO currently hardcoded paths, would be better to follow logic in + # cocstuff pixelmaps + if self.use_segmenation: + if self.stuffthing: + self.segmentation_prefix = ( + f"data/cocostuffthings/val{self.year()}" if + data_json.endswith(f"captions_val{self.year()}.json") else + f"data/cocostuffthings/train{self.year()}") + else: + self.segmentation_prefix = ( + f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if + data_json.endswith(f"captions_val{self.year()}.json") else + f"data/coco/annotations/stuff_train{self.year()}_pixelmaps") + + imagedirs = self.json_data["images"] + self.labels = {"image_ids": list()} + for imgdir in tqdm(imagedirs, desc="ImgToPath"): + self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) + self.img_id_to_captions[imgdir["id"]] = list() + pngfilename = imgdir["file_name"].replace("jpg", "png") + if self.use_segmenation: + 
self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( + self.segmentation_prefix, pngfilename) + if given_files is not None: + if pngfilename in given_files: + self.labels["image_ids"].append(imgdir["id"]) + else: + self.labels["image_ids"].append(imgdir["id"]) + + capdirs = self.json_data["annotations"] + for capdir in tqdm(capdirs, desc="ImgToCaptions"): + # there are in average 5 captions per image + #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) + self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"]) + + self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) + if self.split=="validation": + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + # default option for train is random crop + if self.crop_type in [None, 'random']: + self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = albumentations.Compose( + [self.rescaler, self.cropper], + additional_targets={"segmentation": "image"}) + if force_no_crop: + self.rescaler = albumentations.Resize(height=self.size, width=self.size) + self.preprocessor = albumentations.Compose( + [self.rescaler], + additional_targets={"segmentation": "image"}) + + @abstractmethod + def year(self): + raise NotImplementedError() + + def __len__(self): + return len(self.labels["image_ids"]) + + def preprocess_image(self, image_path, segmentation_path=None): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if segmentation_path: + segmentation = Image.open(segmentation_path) + if not self.onehot and not segmentation.mode == "RGB": + segmentation = segmentation.convert("RGB") + segmentation = np.array(segmentation).astype(np.uint8) + if self.onehot: + assert self.stuffthing + # stored in caffe format: unlabeled==255. stuff and thing from + # 0-181. 
to be compatible with the labels in + # https://github.com/nightrome/cocostuff/blob/master/labels.txt + # we shift stuffthing one to the right and put unlabeled in zero + # as long as segmentation is uint8 shifting to right handles the + # latter too + assert segmentation.dtype == np.uint8 + segmentation = segmentation + 1 + + processed = self.preprocessor(image=image, segmentation=segmentation) + + image, segmentation = processed["image"], processed["segmentation"] + else: + image = self.preprocessor(image=image,)['image'] + + image = (image / 127.5 - 1.0).astype(np.float32) + if segmentation_path: + if self.onehot: + assert segmentation.dtype == np.uint8 + # make it one hot + n_labels = 183 + flatseg = np.ravel(segmentation) + onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) + onehot[np.arange(flatseg.size), flatseg] = True + onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) + segmentation = onehot + else: + segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) + return image, segmentation + else: + return image + + def __getitem__(self, i): + img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] + if self.use_segmenation: + seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] + image, segmentation = self.preprocess_image(img_path, seg_path) + else: + image = self.preprocess_image(img_path) + captions = self.img_id_to_captions[self.labels["image_ids"][i]] + # randomly draw one of all available captions per image + caption = captions[np.random.randint(0, len(captions))] + example = {"image": image, + #"caption": [str(caption[0])], + "caption": caption, + "img_path": img_path, + "filename_": img_path.split(os.sep)[-1] + } + if self.use_segmenation: + example.update({"seg_path": seg_path, 'segmentation': segmentation}) + return example + + +class CocoImagesAndCaptionsTrain2017(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,): + super().__init__(size=size, + dataroot="data/coco/train2017", + datajson="data/coco/annotations/captions_train2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) + + def get_split(self): + return "train" + + def year(self): + return '2017' + + +class CocoImagesAndCaptionsValidation2017(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, + given_files=None): + super().__init__(size=size, + dataroot="data/coco/val2017", + datajson="data/coco/annotations/captions_val2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + given_files=given_files) + + def get_split(self): + return "validation" + + def year(self): + return '2017' + + + +class CocoImagesAndCaptionsTrain2014(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'): + super().__init__(size=size, + dataroot="data/coco/train2014", + datajson="data/coco/annotations2014/annotations/captions_train2014.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + use_segmentation=False, + crop_type=crop_type) + + def get_split(self): + return 
"train" + + def year(self): + return '2014' + +class CocoImagesAndCaptionsValidation2014(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, + given_files=None,crop_type='center',**kwargs): + super().__init__(size=size, + dataroot="data/coco/val2014", + datajson="data/coco/annotations2014/annotations/captions_val2014.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + given_files=given_files, + use_segmentation=False, + crop_type=crop_type) + + def get_split(self): + return "validation" + + def year(self): + return '2014' + +if __name__ == '__main__': + with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file: + json_data = json.load(json_file) + capdirs = json_data["annotations"] + import pudb; pudb.set_trace() + #d2 = CocoImagesAndCaptionsTrain2014(size=256) + d2 = CocoImagesAndCaptionsValidation2014(size=256) + print("constructed dataset.") + print(f"length of {d2.__class__.__name__}: {len(d2)}") + + ex2 = d2[0] + # ex3 = d3[0] + # print(ex1["image"].shape) + print(ex2["image"].shape) + # print(ex3["image"].shape) + # print(ex1["segmentation"].shape) + print(ex2["caption"].__class__.__name__) diff --git a/ldm/data/dummy.py b/ldm/data/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..3b74a77fe8954686e480d28aaed19e52d3e3c9b7 --- /dev/null +++ b/ldm/data/dummy.py @@ -0,0 +1,34 @@ +import numpy as np +import random +import string +from torch.utils.data import Dataset, Subset + +class DummyData(Dataset): + def __init__(self, length, size): + self.length = length + self.size = size + + def __len__(self): + return self.length + + def __getitem__(self, i): + x = np.random.randn(*self.size) + letters = string.ascii_lowercase + y = ''.join(random.choice(string.ascii_lowercase) for i in range(10)) + return {"jpg": x, "txt": y} + + +class DummyDataWithEmbeddings(Dataset): + def __init__(self, length, size, emb_size): + self.length = length + self.size = size + self.emb_size = emb_size + + def __len__(self): + return self.length + + def __getitem__(self, i): + x = np.random.randn(*self.size) + y = np.random.randn(*self.emb_size).astype(np.float32) + return {"jpg": x, "txt": y} + diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..66231964a685cc875243018461a6aaa63a96dbf0 --- /dev/null +++ b/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + 
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = 
"http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + 
tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) 
+ self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + example["caption"] = example["human_label"] # dummy caption + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/ldm/data/inpainting/__init__.py b/ldm/data/inpainting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/data/inpainting/synthetic_mask.py b/ldm/data/inpainting/synthetic_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..bb4c38f3a79b8eb40553469d6f0656ad2f54609a --- /dev/null +++ b/ldm/data/inpainting/synthetic_mask.py @@ -0,0 +1,166 @@ +from PIL import Image, ImageDraw +import numpy as np + +settings = { + "256narrow": { + "p_irr": 1, + "min_n_irr": 4, + 
"max_n_irr": 50, + "max_l_irr": 40, + "max_w_irr": 10, + "min_n_box": None, + "max_n_box": None, + "min_s_box": None, + "max_s_box": None, + "marg": None, + }, + "256train": { + "p_irr": 0.5, + "min_n_irr": 1, + "max_n_irr": 5, + "max_l_irr": 200, + "max_w_irr": 100, + "min_n_box": 1, + "max_n_box": 4, + "min_s_box": 30, + "max_s_box": 150, + "marg": 10, + }, + "512train": { # TODO: experimental + "p_irr": 0.5, + "min_n_irr": 1, + "max_n_irr": 5, + "max_l_irr": 450, + "max_w_irr": 250, + "min_n_box": 1, + "max_n_box": 4, + "min_s_box": 30, + "max_s_box": 300, + "marg": 10, + }, + "512train-large": { # TODO: experimental + "p_irr": 0.5, + "min_n_irr": 1, + "max_n_irr": 5, + "max_l_irr": 450, + "max_w_irr": 400, + "min_n_box": 1, + "max_n_box": 4, + "min_s_box": 75, + "max_s_box": 450, + "marg": 10, + }, +} + + +def gen_segment_mask(mask, start, end, brush_width): + mask = mask > 0 + mask = (255 * mask).astype(np.uint8) + mask = Image.fromarray(mask) + draw = ImageDraw.Draw(mask) + draw.line([start, end], fill=255, width=brush_width, joint="curve") + mask = np.array(mask) / 255 + return mask + + +def gen_box_mask(mask, masked): + x_0, y_0, w, h = masked + mask[y_0:y_0 + h, x_0:x_0 + w] = 1 + return mask + + +def gen_round_mask(mask, masked, radius): + x_0, y_0, w, h = masked + xy = [(x_0, y_0), (x_0 + w, y_0 + w)] + + mask = mask > 0 + mask = (255 * mask).astype(np.uint8) + mask = Image.fromarray(mask) + draw = ImageDraw.Draw(mask) + draw.rounded_rectangle(xy, radius=radius, fill=255) + mask = np.array(mask) / 255 + return mask + + +def gen_large_mask(prng, img_h, img_w, + marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr, + min_n_box, max_n_box, min_s_box, max_s_box): + """ + img_h: int, an image height + img_w: int, an image width + marg: int, a margin for a box starting coordinate + p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask + + min_n_irr: int, min number of segments + max_n_irr: int, max number of segments + max_l_irr: max length of a segment in polygonal chain + max_w_irr: max width of a segment in polygonal chain + + min_n_box: int, min bound for the number of box primitives + max_n_box: int, max bound for the number of box primitives + min_s_box: int, min length of a box side + max_s_box: int, max length of a box side + """ + + mask = np.zeros((img_h, img_w)) + uniform = prng.randint + + if np.random.uniform(0, 1) < p_irr: # generate polygonal chain + n = uniform(min_n_irr, max_n_irr) # sample number of segments + + for _ in range(n): + y = uniform(0, img_h) # sample a starting point + x = uniform(0, img_w) + + a = uniform(0, 360) # sample angle + l = uniform(10, max_l_irr) # sample segment length + w = uniform(5, max_w_irr) # sample a segment width + + # draw segment starting from (x,y) to (x_,y_) using brush of width w + x_ = x + l * np.sin(a) + y_ = y + l * np.cos(a) + + mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w) + x, y = x_, y_ + else: # generate Box masks + n = uniform(min_n_box, max_n_box) # sample number of rectangles + + for _ in range(n): + h = uniform(min_s_box, max_s_box) # sample box shape + w = uniform(min_s_box, max_s_box) + + x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box + y_0 = uniform(marg, img_h - marg - h) + + if np.random.uniform(0, 1) < 0.5: + mask = gen_box_mask(mask, masked=(x_0, y_0, w, h)) + else: + r = uniform(0, 60) # sample radius + mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r) + return mask + + +make_lama_mask = lambda prng, h, w: 
gen_large_mask(prng, h, w, **settings["256train"]) +make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"]) +make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"]) +make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"]) + + +MASK_MODES = { + "256train": make_lama_mask, + "256narrow": make_narrow_lama_mask, + "512train": make_512_lama_mask, + "512train-large": make_512_lama_mask_large +} + +if __name__ == "__main__": + import sys + + out = sys.argv[1] + + prng = np.random.RandomState(1) + kwargs = settings["256train"] + mask = gen_large_mask(prng, 256, 256, **kwargs) + mask = (255 * mask).astype(np.uint8) + mask = Image.fromarray(mask) + mask.save(out) diff --git a/ldm/data/laion.py b/ldm/data/laion.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb608c1a4cf2b7c0215bdd7c1c81841e3a39b0c --- /dev/null +++ b/ldm/data/laion.py @@ -0,0 +1,537 @@ +import webdataset as wds +import kornia +from PIL import Image +import io +import os +import torchvision +from PIL import Image +import glob +import random +import numpy as np +import pytorch_lightning as pl +from tqdm import tqdm +from omegaconf import OmegaConf +from einops import rearrange +import torch +from webdataset.handlers import warn_and_continue + + +from ldm.util import instantiate_from_config +from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES +from ldm.data.base import PRNGMixin + + +class DataWithWings(torch.utils.data.IterableDataset): + def __init__(self, min_size, transform=None, target_transform=None): + self.min_size = min_size + self.transform = transform if transform is not None else nn.Identity() + self.target_transform = target_transform if target_transform is not None else nn.Identity() + self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee') + self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e') + self.pwatermark_threshold = 0.8 + self.punsafe_threshold = 0.5 + self.aesthetic_threshold = 5. 
+ self.total_samples = 0 + self.samples = 0 + location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -' + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode('pilrgb', handler=wds.warn_and_continue), + wds.map(self._add_tags, handler=wds.ignore_and_continue), + wds.select(self._filter_predicate), + wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue), + wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue), + ) + + @staticmethod + def _compute_hash(url, text): + if url is None: + url = '' + if text is None: + text = '' + total = (url + text).encode('utf-8') + return mmh3.hash64(total)[0] + + def _add_tags(self, x): + hsh = self._compute_hash(x['json']['url'], x['txt']) + pwatermark, punsafe = self.kv[hsh] + aesthetic = self.kv_aesthetic[hsh][0] + return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic} + + def _punsafe_to_class(self, punsafe): + return torch.tensor(punsafe >= self.punsafe_threshold).long() + + def _filter_predicate(self, x): + try: + return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size + except: + return False + + def __iter__(self): + return iter(self.inner_dataset) + + +def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True): + """Take a list of samples (as dictionary) and create a batch, preserving the keys. + If `tensors` is True, `ndarray` objects are combined into + tensor batches. 
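    Example (illustrative): for samples = [{"jpg": t1, "txt": "a"}, {"jpg": t2, "txt": "b"}]
    with t1 and t2 tensors of equal shape, the batch is
    {"jpg": torch.stack([t1, t2]), "txt": ["a", "b"]}; only keys present in every
    sample are kept.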
+ :param dict samples: list of samples + :param bool tensors: whether to turn lists of ndarrays into a single ndarray + :returns: single sample consisting of a batch + :rtype: dict + """ + keys = set.intersection(*[set(sample.keys()) for sample in samples]) + batched = {key: [] for key in keys} + + for s in samples: + [batched[key].append(s[key]) for key in batched] + + result = {} + for key in batched: + if isinstance(batched[key][0], (int, float)): + if combine_scalars: + result[key] = np.array(list(batched[key])) + elif isinstance(batched[key][0], torch.Tensor): + if combine_tensors: + result[key] = torch.stack(list(batched[key])) + elif isinstance(batched[key][0], np.ndarray): + if combine_tensors: + result[key] = np.array(list(batched[key])) + else: + result[key] = list(batched[key]) + return result + + +class WebDataModuleFromConfig(pl.LightningDataModule): + def __init__(self, tar_base, batch_size, train=None, validation=None, + test=None, num_workers=4, multinode=True, min_size=None, + max_pwatermark=1.0, + **kwargs): + super().__init__(self) + print(f'Setting tar base to {tar_base}') + self.tar_base = tar_base + self.batch_size = batch_size + self.num_workers = num_workers + self.train = train + self.validation = validation + self.test = test + self.multinode = multinode + self.min_size = min_size # filter out very small images + self.max_pwatermark = max_pwatermark # filter out watermarked images + + def make_loader(self, dataset_config, train=True): + if 'image_transforms' in dataset_config: + image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms] + else: + image_transforms = [] + + image_transforms.extend([torchvision.transforms.ToTensor(), + torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = torchvision.transforms.Compose(image_transforms) + + if 'transforms' in dataset_config: + transforms_config = OmegaConf.to_container(dataset_config.transforms) + else: + transforms_config = dict() + + transform_dict = {dkey: load_partial_from_config(transforms_config[dkey]) + if transforms_config[dkey] != 'identity' else identity + for dkey in transforms_config} + img_key = dataset_config.get('image_key', 'jpeg') + transform_dict.update({img_key: image_transforms}) + + if 'postprocess' in dataset_config: + postprocess = instantiate_from_config(dataset_config['postprocess']) + else: + postprocess = None + + shuffle = dataset_config.get('shuffle', 0) + shardshuffle = shuffle > 0 + + nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only + + if self.tar_base == "__improvedaesthetic__": + print("## Warning, loading the same improved aesthetic dataset " + "for all splits and ignoring shards parameter.") + tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -" + else: + tars = os.path.join(self.tar_base, dataset_config.shards) + + dset = wds.WebDataset( + tars, + nodesplitter=nodesplitter, + shardshuffle=shardshuffle, + handler=wds.warn_and_continue).repeat().shuffle(shuffle) + print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.') + + dset = (dset + .select(self.filter_keys) + .decode('pil', handler=wds.warn_and_continue) + .select(self.filter_size) + .map_dict(**transform_dict, handler=wds.warn_and_continue) + ) + if postprocess is not None: + dset = dset.map(postprocess) + dset = (dset + .batched(self.batch_size, partial=False, + collation_fn=dict_collation_fn) + ) + + loader = 
wds.WebLoader(dset, batch_size=None, shuffle=False, + num_workers=self.num_workers) + + return loader + + def filter_size(self, x): + try: + valid = True + if self.min_size is not None and self.min_size > 1: + try: + valid = valid and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size + except Exception: + valid = False + if self.max_pwatermark is not None and self.max_pwatermark < 1.0: + try: + valid = valid and x['json']['pwatermark'] <= self.max_pwatermark + except Exception: + valid = False + return valid + except Exception: + return False + + def filter_keys(self, x): + try: + return ("jpg" in x) and ("txt" in x) + except Exception: + return False + + def train_dataloader(self): + return self.make_loader(self.train) + + def val_dataloader(self): + return self.make_loader(self.validation, train=False) + + def test_dataloader(self): + return self.make_loader(self.test, train=False) + + +from ldm.modules.image_degradation import degradation_fn_bsr_light +import cv2 + +class AddLR(object): + def __init__(self, factor, output_size, initial_size=None, image_key="jpg"): + self.factor = factor + self.output_size = output_size + self.image_key = image_key + self.initial_size = initial_size + + def pt2np(self, x): + x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy() + return x + + def np2pt(self, x): + x = torch.from_numpy(x)/127.5-1.0 + return x + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = self.pt2np(sample[self.image_key]) + if self.initial_size is not None: + x = cv2.resize(x, (self.initial_size, self.initial_size), interpolation=2) + x = degradation_fn_bsr_light(x, sf=self.factor)['image'] + x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2) + x = self.np2pt(x) + sample['lr'] = x + return sample + +class AddBW(object): + def __init__(self, image_key="jpg"): + self.image_key = image_key + + def pt2np(self, x): + x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy() + return x + + def np2pt(self, x): + x = torch.from_numpy(x)/127.5-1.0 + return x + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = sample[self.image_key] + w = torch.rand(3, device=x.device) + w /= w.sum() + out = torch.einsum('hwc,c->hw', x, w) + + # Keep as 3ch so we can pass to encoder, also we might want to add hints + sample['lr'] = out.unsqueeze(-1).tile(1,1,3) + return sample + +class AddMask(PRNGMixin): + def __init__(self, mode="512train", p_drop=0.): + super().__init__() + assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"' + self.make_mask = MASK_MODES[mode] + self.p_drop = p_drop + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = sample['jpg'] + mask = self.make_mask(self.prng, x.shape[0], x.shape[1]) + if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]): + mask = np.ones_like(mask) + mask[mask < 0.5] = 0 + mask[mask > 0.5] = 1 + mask = torch.from_numpy(mask[..., None]) + sample['mask'] = mask + sample['masked_image'] = x * (mask < 0.5) + return sample + + +class AddEdge(PRNGMixin): + def __init__(self, mode="512train", mask_edges=True): + super().__init__() + assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"' + self.make_mask = MASK_MODES[mode] + self.n_down_choices = [0] + self.sigma_choices = [1, 2] + self.mask_edges = mask_edges + + @torch.no_grad() + def __call__(self, sample): + # sample['jpg'] 
is tensor hwc in [-1, 1] at this point + x = sample['jpg'] + + mask = self.make_mask(self.prng, x.shape[0], x.shape[1]) + mask[mask < 0.5] = 0 + mask[mask > 0.5] = 1 + mask = torch.from_numpy(mask[..., None]) + sample['mask'] = mask + + n_down_idx = self.prng.choice(len(self.n_down_choices)) + sigma_idx = self.prng.choice(len(self.sigma_choices)) + + n_choices = len(self.n_down_choices)*len(self.sigma_choices) + raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx), + (len(self.n_down_choices), len(self.sigma_choices))) + normalized_idx = raveled_idx/max(1, n_choices-1) + + n_down = self.n_down_choices[n_down_idx] + sigma = self.sigma_choices[sigma_idx] + + kernel_size = 4*sigma+1 + kernel_size = (kernel_size, kernel_size) + sigma = (sigma, sigma) + canny = kornia.filters.Canny( + low_threshold=0.1, + high_threshold=0.2, + kernel_size=kernel_size, + sigma=sigma, + hysteresis=True, + ) + y = (x+1.0)/2.0 # in 01 + y = y.unsqueeze(0).permute(0, 3, 1, 2).contiguous() + + # down + for i_down in range(n_down): + size = min(y.shape[-2], y.shape[-1])//2 + y = kornia.geometry.transform.resize(y, size, antialias=True) + + # edge + _, y = canny(y) + + if n_down > 0: + size = x.shape[0], x.shape[1] + y = kornia.geometry.transform.resize(y, size, interpolation="nearest") + + y = y.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous() + y = y*2.0-1.0 + + if self.mask_edges: + sample['masked_image'] = y * (mask < 0.5) + else: + sample['masked_image'] = y + sample['mask'] = torch.zeros_like(sample['mask']) + + # concat normalized idx + sample['smoothing_strength'] = torch.ones_like(sample['mask'])*normalized_idx + + return sample + + +def example00(): + url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -" + dataset = wds.WebDataset(url) + example = next(iter(dataset)) + for k in example: + print(k, type(example[k])) + + print(example["__key__"]) + for k in ["json", "txt"]: + print(example[k].decode()) + + image = Image.open(io.BytesIO(example["jpg"])) + outdir = "tmp" + os.makedirs(outdir, exist_ok=True) + image.save(os.path.join(outdir, example["__key__"] + ".png")) + + + def load_example(example): + return { + "key": example["__key__"], + "image": Image.open(io.BytesIO(example["jpg"])), + "text": example["txt"].decode(), + } + + + for i, example in tqdm(enumerate(dataset)): + ex = load_example(example) + print(ex["image"].size, ex["text"]) + if i >= 100: + break + + +def example01(): + # the first laion shards contain ~10k examples each + url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -" + + batch_size = 3 + shuffle_buffer = 10000 + dset = wds.WebDataset( + url, + nodesplitter=wds.shardlists.split_by_node, + shardshuffle=True, + ) + dset = (dset + .shuffle(shuffle_buffer, initial=shuffle_buffer) + .decode('pil', handler=warn_and_continue) + .batched(batch_size, partial=False, + collation_fn=dict_collation_fn) + ) + + num_workers = 2 + loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=num_workers) + + batch_sizes = list() + keys_per_epoch = list() + for epoch in range(5): + keys = list() + for batch in tqdm(loader): + batch_sizes.append(len(batch["__key__"])) + keys.append(batch["__key__"]) + + for bs in batch_sizes: + assert bs==batch_size + print(f"{len(batch_sizes)} batches of size {batch_size}.") + batch_sizes = list() + + keys_per_epoch.append(keys) + for i_batch in [0, 1, -1]: + print(f"Batch {i_batch} of epoch {epoch}:") + print(keys[i_batch]) + print("next epoch.") + + +def example02(): + from omegaconf import 
OmegaConf + from torch.utils.data.distributed import DistributedSampler + from torch.utils.data import IterableDataset + from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler + from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator + + #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml") + #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml") + config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml") + datamod = WebDataModuleFromConfig(**config["data"]["params"]) + dataloader = datamod.train_dataloader() + + for batch in dataloader: + print(batch.keys()) + print(batch["jpg"].shape) + break + + +def example03(): + # improved aesthetics + tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -" + dataset = wds.WebDataset(tars) + + def filter_keys(x): + try: + return ("jpg" in x) and ("txt" in x) + except Exception: + return False + + def filter_size(x): + try: + return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512 + except Exception: + return False + + def filter_watermark(x): + try: + return x['json']['pwatermark'] < 0.5 + except Exception: + return False + + dataset = (dataset + .select(filter_keys) + .decode('pil', handler=wds.warn_and_continue)) + n_save = 20 + n_total = 0 + n_large = 0 + n_large_nowm = 0 + for i, example in enumerate(dataset): + n_total += 1 + if filter_size(example): + n_large += 1 + if filter_watermark(example): + n_large_nowm += 1 + if n_large_nowm < n_save+1: + image = example["jpg"] + image.save(os.path.join("tmp", f"{n_large_nowm-1:06}.png")) + + if i%500 == 0: + print(i) + print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%") + if n_large > 0: + print(f"No Watermark: {n_large_nowm}/{n_large} | {n_large_nowm/n_large*100:.2f}%") + + + +def example04(): + # improved aesthetics + for i_shard in range(60208)[::-1]: + print(i_shard) + tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard) + dataset = wds.WebDataset(tars) + + def filter_keys(x): + try: + return ("jpg" in x) and ("txt" in x) + except Exception: + return False + + def filter_size(x): + try: + return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512 + except Exception: + return False + + dataset = (dataset + .select(filter_keys) + .decode('pil', handler=wds.warn_and_continue)) + try: + example = next(iter(dataset)) + except Exception: + print(f"Error @ {i_shard}") + + +if __name__ == "__main__": + #example01() + #example02() + example03() + #example04() diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py new file mode 100644 index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e --- /dev/null +++ b/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, 
l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/ldm/data/nerf_like.py b/ldm/data/nerf_like.py new file mode 100644 index 0000000000000000000000000000000000000000..84ef18288db005c72d3b5832144a7bd5cfffe9b2 --- /dev/null +++ b/ldm/data/nerf_like.py @@ -0,0 +1,165 @@ +from torch.utils.data import Dataset +import os +import json +import numpy as np +import torch +import imageio +import math +import cv2 +from torchvision import transforms + +def cartesian_to_spherical(xyz): + ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) + xy = xyz[:,0]**2 + xyz[:,1]**2 + z = np.sqrt(xy + xyz[:,2]**2) + theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down + #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + azimuth = np.arctan2(xyz[:,1], xyz[:,0]) + return np.array([theta, azimuth, z]) + + +def get_T(T_target, T_cond): + theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) + return d_T + +def get_spherical(T_target, T_cond): + theta_cond, azimuth_cond, z_cond = 
cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + d_T = torch.tensor([math.degrees(d_theta.item()), math.degrees(d_azimuth.item()), d_z.item()]) + return d_T + +class RTMV(Dataset): + def __init__(self, root_dir='datasets/RTMV/google_scanned',\ + first_K=64, resolution=256, load_target=False): + self.root_dir = root_dir + self.scene_list = sorted(next(os.walk(root_dir))[1]) + self.resolution = resolution + self.first_K = first_K + self.load_target = load_target + + def __len__(self): + return len(self.scene_list) + + def __getitem__(self, idx): + scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) + with open(os.path.join(scene_dir, 'transforms.json'), "r") as f: + meta = json.load(f) + imgs = [] + poses = [] + for i_img in range(self.first_K): + meta_img = meta['frames'][i_img] + + if i_img == 0 or self.load_target: + img_path = os.path.join(scene_dir, meta_img['file_path']) + img = imageio.imread(img_path) + img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) + imgs.append(img) + + c2w = meta_img['transform_matrix'] + poses.append(c2w) + + imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs + imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) + imgs = imgs * 2 - 1. # convert to stable diffusion range + poses = torch.tensor(np.array(poses).astype(np.float32)) + return imgs, poses + + def blend_rgba(self, img): + img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB + return img + + +class GSO(Dataset): + def __init__(self, root_dir='datasets/GoogleScannedObjects',\ + split='val', first_K=5, resolution=256, load_target=False, name='render_mvs'): + self.root_dir = root_dir + with open(os.path.join(root_dir, '%s.json' % split), "r") as f: + self.scene_list = json.load(f) + self.resolution = resolution + self.first_K = first_K + self.load_target = load_target + self.name = name + + def __len__(self): + return len(self.scene_list) + + def __getitem__(self, idx): + scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) + with open(os.path.join(scene_dir, 'transforms_%s.json' % self.name), "r") as f: + meta = json.load(f) + imgs = [] + poses = [] + for i_img in range(self.first_K): + meta_img = meta['frames'][i_img] + + if i_img == 0 or self.load_target: + img_path = os.path.join(scene_dir, meta_img['file_path']) + img = imageio.imread(img_path) + img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) + imgs.append(img) + + c2w = meta_img['transform_matrix'] + poses.append(c2w) + + imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs + mask = imgs[:, :, :, -1] + imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) + imgs = imgs * 2 - 1. # convert to stable diffusion range + poses = torch.tensor(np.array(poses).astype(np.float32)) + return imgs, poses + + def blend_rgba(self, img): + img = img[..., :3] * img[..., -1:] + (1. 
- img[..., -1:]) # blend A to RGB + return img + +class WILD(Dataset): + def __init__(self, root_dir='data/nerf_wild',\ + first_K=33, resolution=256, load_target=False): + self.root_dir = root_dir + self.scene_list = sorted(next(os.walk(root_dir))[1]) + self.resolution = resolution + self.first_K = first_K + self.load_target = load_target + + def __len__(self): + return len(self.scene_list) + + def __getitem__(self, idx): + scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) + with open(os.path.join(scene_dir, 'transforms_train.json'), "r") as f: + meta = json.load(f) + imgs = [] + poses = [] + for i_img in range(self.first_K): + meta_img = meta['frames'][i_img] + + if i_img == 0 or self.load_target: + img_path = os.path.join(scene_dir, meta_img['file_path']) + img = imageio.imread(img_path + '.png') + img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) + imgs.append(img) + + c2w = meta_img['transform_matrix'] + poses.append(c2w) + + imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs + imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) + imgs = imgs * 2 - 1. # convert to stable diffusion range + poses = torch.tensor(np.array(poses).astype(np.float32)) + return imgs, poses + + def blend_rgba(self, img): + img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB + return img \ No newline at end of file diff --git a/ldm/data/simple.py b/ldm/data/simple.py new file mode 100644 index 0000000000000000000000000000000000000000..9b48e8859047234a4ca3bd44544e647178dadec9 --- /dev/null +++ b/ldm/data/simple.py @@ -0,0 +1,526 @@ +from typing import Dict +import webdataset as wds +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +import torchvision +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset +import pytorch_lightning as pl +import copy +import csv +import cv2 +import random +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +import json +import os +import webdataset as wds +import math +from torch.utils.data.distributed import DistributedSampler + +# Some hacky things to make experimentation easier +def make_transform_multi_folder_data(paths, caption_files=None, **kwargs): + ds = make_multi_folder_data(paths, caption_files, **kwargs) + return TransformDataset(ds) + +def make_nfp_data(base_path): + dirs = list(Path(base_path).glob("*/")) + print(f"Found {len(dirs)} folders") + print(dirs) + tforms = [transforms.Resize(512), transforms.CenterCrop(512)] + datasets = [NfpDataset(x, image_transforms=copy.copy(tforms), default_caption="A view from a train window") for x in dirs] + return torch.utils.data.ConcatDataset(datasets) + + +class VideoDataset(Dataset): + def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2): + self.root_dir = Path(root_dir) + self.caption_file = caption_file + self.n = n + ext = "mp4" + self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}"))) + self.offset = offset + + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + with open(self.caption_file) as f: + reader = csv.reader(f) + rows = [row for row in reader] + self.captions = dict(rows) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + for i in range(10): + try: + return self._load_sample(index) + except Exception: + # Not really good enough but... + print("uh oh") + + def _load_sample(self, index): + n = self.n + filename = self.paths[index] + min_frame = 2*self.offset + 2 + vid = cv2.VideoCapture(str(filename)) + max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT)) + curr_frame_n = random.randint(min_frame, max_frames) + vid.set(cv2.CAP_PROP_POS_FRAMES,curr_frame_n) + _, curr_frame = vid.read() + + prev_frames = [] + for i in range(n): + prev_frame_n = curr_frame_n - (i+1)*self.offset + vid.set(cv2.CAP_PROP_POS_FRAMES,prev_frame_n) + _, prev_frame = vid.read() + prev_frame = self.tform(Image.fromarray(prev_frame[...,::-1])) + prev_frames.append(prev_frame) + + vid.release() + caption = self.captions[filename.name] + data = { + "image": self.tform(Image.fromarray(curr_frame[...,::-1])), + "prev": torch.cat(prev_frames, dim=-1), + "txt": caption + } + return data + +# end hacky things + + +def make_tranforms(image_transforms): + # if isinstance(image_transforms, ListConfig): + # image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms = [] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + return image_transforms + + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + + + +class NfpDataset(Dataset): + def __init__(self, + root_dir, + image_transforms=[], + ext="jpg", + default_caption="", + ) -> None: + """assume sequential frames and a deterministic transform""" + + self.root_dir = Path(root_dir) + self.default_caption = default_caption + + self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}"))) + self.tform = make_tranforms(image_transforms) + + def __len__(self): + return len(self.paths) - 1 + + + def __getitem__(self, index): + prev = self.paths[index] + curr = self.paths[index+1] + data = {} + data["image"] = self._load_im(curr) + data["prev"] = self._load_im(prev) + data["txt"] = self.default_caption + return data + + def _load_im(self, filename): + im = Image.open(filename).convert("RGB") + return self.tform(im) + +class ObjaverseDataModuleFromConfig(pl.LightningDataModule): + def __init__(self, root_dir, batch_size, total_view, train=None, validation=None, + test=None, num_workers=4, **kwargs): + super().__init__(self) + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.total_view = 
total_view + + if train is not None: + dataset_config = train + if validation is not None: + dataset_config = validation + + if 'image_transforms' in dataset_config: + image_transforms = [torchvision.transforms.Resize(dataset_config.image_transforms.size)] + else: + image_transforms = [] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + self.image_transforms = torchvision.transforms.Compose(image_transforms) + + + def train_dataloader(self): + dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False, \ + image_transforms=self.image_transforms) + sampler = DistributedSampler(dataset) + return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) + + def val_dataloader(self): + dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True, \ + image_transforms=self.image_transforms) + sampler = DistributedSampler(dataset) + return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + + def test_dataloader(self): + return wds.WebLoader(ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=self.validation),\ + batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + + +class ObjaverseData(Dataset): + def __init__(self, + root_dir='.objaverse/hf-objaverse-v1/views', + image_transforms=[], + ext="png", + default_trans=torch.zeros(3), + postprocess=None, + return_paths=False, + total_view=4, + validation=False + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_trans = default_trans + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + self.total_view = total_view + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + with open(os.path.join(root_dir, 'valid_paths.json')) as f: + self.paths = json.load(f) + + total_objects = len(self.paths) + if validation: + self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation + else: + self.paths = self.paths[:math.floor(total_objects / 100. 
* 99.)] # used first 99% as training + print('============= length of dataset %d =============' % len(self.paths)) + self.tform = image_transforms + + def __len__(self): + return len(self.paths) + + def cartesian_to_spherical(self, xyz): + ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) + xy = xyz[:,0]**2 + xyz[:,1]**2 + z = np.sqrt(xy + xyz[:,2]**2) + theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down + #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + azimuth = np.arctan2(xyz[:,1], xyz[:,0]) + return np.array([theta, azimuth, z]) + + def get_T(self, target_RT, cond_RT): + R, T = target_RT[:3, :3], target_RT[:, -1] + T_target = -R.T @ T + + R, T = cond_RT[:3, :3], cond_RT[:, -1] + T_cond = -R.T @ T + + theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) + return d_T + + def load_im(self, path, color): + ''' + replace background pixel with random color in rendering + ''' + try: + img = plt.imread(path) + except: + print(path) + sys.exit() + img[img[:, :, -1] == 0.] = color + img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)) + return img + + def __getitem__(self, index): + + data = {} + if self.paths[index][-2:] == '_1': # dirty fix for rendering dataset twice + total_view = 8 + else: + total_view = 4 + index_target, index_cond = random.sample(range(total_view), 2) # without replacement + filename = os.path.join(self.root_dir, self.paths[index]) + + # print(self.paths[index]) + + if self.return_paths: + data["path"] = str(filename) + + color = [1., 1., 1., 1.] + + try: + target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color)) + cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color)) + target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target)) + cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond)) + except: + # very hacky solution, sorry about this + filename = os.path.join(self.root_dir, '692db5f2d3a04bb286cb977a7dba903e_1') # this one we know is valid + target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color)) + cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color)) + target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target)) + cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond)) + target_im = torch.zeros_like(target_im) + cond_im = torch.zeros_like(cond_im) + + data["image_target"] = target_im + data["image_cond"] = cond_im + data["T"] = self.get_T(target_RT, cond_RT) + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. 
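        A caption_file, if given, may be a .json file mapping file names to
        caption strings or a .jsonl file whose lines contain "file_name" and
        "text" fields; images without a caption fall back to default_caption.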
+ If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(sorted(list(self.root_dir.rglob(f"*.{e}")))) + self.tform = make_tranforms(image_transforms) + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename).convert("RGB") + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) +import random + +class TransformDataset(): + def __init__(self, ds, extra_label="sksbspic"): + self.ds = ds + self.extra_label = extra_label + self.transforms = { + "align": transforms.Resize(768), + "centerzoom": transforms.CenterCrop(768), + "randzoom": transforms.RandomCrop(768), + } + + + def __getitem__(self, index): + data = self.ds[index] + + im = data['image'] + im = im.permute(2,0,1) + # In case data is smaller than expected + im = transforms.Resize(1024)(im) + + tform_name = random.choice(list(self.transforms.keys())) + im = self.transforms[tform_name](im) + + im = im.permute(1,2,0) + + data['image'] = im + data['txt'] = data['txt'] + f" {self.extra_label} {tform_name}" + + return data + + def __len__(self): + return len(self.ds) + +def hf_dataset( + name, + image_transforms=[], + image_column="image", + text_column="text", + split='train', + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(name, split=split) + tform = make_tranforms(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + processed[caption_key] = examples[text_column] + return processed + + ds.set_transform(pre_process) + return ds + +class TextOnly(Dataset): + def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): + """Returns only captions with dummy images""" + 
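        # Minimal usage sketch (values are illustrative; captions may also be a
        # Path to a newline-separated text file):
        #
        #   ds = TextOnly(captions=["a red chair", "a blue car"], output_size=256)
        #   item = ds[0]
        #   item["image"].shape   # torch.Size([256, 256, 3]), all values -1.0 (dummy)
        #   item["txt"]           # "a red chair"
        #
        # The n_gpus repeat below duplicates the caption list so every caption
        # still appears on each GPU when the data is sharded across devices.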
self.output_size = output_size + self.image_key = image_key + self.caption_key = caption_key + if isinstance(captions, Path): + self.captions = self._load_caption_file(captions) + else: + self.captions = captions + + if n_gpus > 1: + # hack to make sure that all the captions appear on each gpu + repeated = [n_gpus*[x] for x in self.captions] + self.captions = [] + [self.captions.extend(x) for x in repeated] + + def __len__(self): + return len(self.captions) + + def __getitem__(self, index): + dummy_im = torch.zeros(3, self.output_size, self.output_size) + dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + return {self.image_key: dummy_im, self.caption_key: self.captions[index]} + + def _load_caption_file(self, filename): + with open(filename, 'rt') as f: + captions = f.readlines() + return [x.strip('\n') for x in captions] + + + +import random +import json +class IdRetreivalDataset(FolderData): + def __init__(self, ret_file, *args, **kwargs): + super().__init__(*args, **kwargs) + with open(ret_file, "rt") as f: + self.ret = json.load(f) + + def __getitem__(self, index): + data = super().__getitem__(index) + key = self.paths[index].name + matches = self.ret[key] + if len(matches) > 0: + retreived = random.choice(matches) + else: + retreived = key + filename = self.root_dir/retreived + im = Image.open(filename).convert("RGB") + im = self.process_im(im) + # data["match"] = im + data["match"] = torch.cat((data["image"], im), dim=-1) + return data diff --git a/ldm/data/sync_dreamer.py b/ldm/data/sync_dreamer.py new file mode 100644 index 0000000000000000000000000000000000000000..df74e6c2c9b5b16866b5fc1a8caf2371f7bdb2ee --- /dev/null +++ b/ldm/data/sync_dreamer.py @@ -0,0 +1,132 @@ +import pytorch_lightning as pl +import numpy as np +import torch +import PIL +import os +from skimage.io import imread +import webdataset as wds +import PIL.Image as Image +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from pathlib import Path + +from ldm.base_utils import read_pickle, pose_inverse +import torchvision.transforms as transforms +import torchvision +from einops import rearrange + +from ldm.util import prepare_inputs + + +class SyncDreamerTrainData(Dataset): + def __init__(self, target_dir, input_dir, uid_set_pkl, image_size=256): + self.default_image_size = 256 + self.image_size = image_size + self.target_dir = Path(target_dir) + self.input_dir = Path(input_dir) + + self.uids = read_pickle(uid_set_pkl) + print('============= length of dataset %d =============' % len(self.uids)) + + image_transforms = [] + image_transforms.extend([transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) + self.image_transforms = torchvision.transforms.Compose(image_transforms) + self.num_images = 16 + + def __len__(self): + return len(self.uids) + + def load_im(self, path): + img = imread(path) + img = img.astype(np.float32) / 255.0 + mask = img[:,:,3:] + img[:,:,:3] = img[:,:,:3] * mask + 1 - mask # white background + img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)) + return img, mask + + def process_im(self, im): + im = im.convert("RGB") + im = im.resize((self.image_size, self.image_size), resample=PIL.Image.BICUBIC) + return self.image_transforms(im) + + def load_index(self, filename, index): + img, _ = self.load_im(os.path.join(filename, '%03d.png' % index)) + img = self.process_im(img) + return img + + def get_data_for_index(self, index): + target_dir = os.path.join(self.target_dir, self.uids[index]) + input_dir = os.path.join(self.input_dir, self.uids[index]) + + views = np.arange(0, self.num_images) + start_view_index = np.random.randint(0, self.num_images) + views = (views + start_view_index) % self.num_images + + target_images = [] + for si, target_index in enumerate(views): + img = self.load_index(target_dir, target_index) + target_images.append(img) + target_images = torch.stack(target_images, 0) + input_img = self.load_index(input_dir, start_view_index) + + K, azimuths, elevations, distances, cam_poses = read_pickle(os.path.join(input_dir, f'meta.pkl')) + input_elevation = torch.from_numpy(elevations[start_view_index:start_view_index+1].astype(np.float32)) + return {"target_image": target_images, "input_image": input_img, "input_elevation": input_elevation} + + def __getitem__(self, index): + data = self.get_data_for_index(index) + return data + +class SyncDreamerEvalData(Dataset): + def __init__(self, image_dir): + self.image_size = 256 + self.image_dir = Path(image_dir) + self.crop_size = 20 + + self.fns = [] + for fn in Path(image_dir).iterdir(): + if fn.suffix=='.png': + self.fns.append(fn) + print('============= length of dataset %d =============' % len(self.fns)) + + def __len__(self): + return len(self.fns) + + def get_data_for_index(self, index): + input_img_fn = self.fns[index] + elevation = int(Path(input_img_fn).stem.split('-')[-1]) + return prepare_inputs(input_img_fn, elevation, 200) + + def __getitem__(self, index): + return self.get_data_for_index(index) + +class SyncDreamerDataset(pl.LightningDataModule): + def __init__(self, target_dir, input_dir, validation_dir, batch_size, uid_set_pkl, image_size=256, num_workers=4, seed=0, **kwargs): + super().__init__() + self.target_dir = target_dir + self.input_dir = input_dir + self.validation_dir = validation_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.uid_set_pkl = uid_set_pkl + self.seed = seed + self.additional_args = kwargs + self.image_size = image_size + + def setup(self, stage): + if stage in ['fit']: + self.train_dataset = SyncDreamerTrainData(self.target_dir, self.input_dir, uid_set_pkl=self.uid_set_pkl, image_size=256) + self.val_dataset = SyncDreamerEvalData(image_dir=self.validation_dir) + else: + raise NotImplementedError + + def train_dataloader(self): + sampler = DistributedSampler(self.train_dataset, seed=self.seed) + return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) + + def val_dataloader(self): + loader = wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + return loader + + def test_dataloader(self): + 
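        # The test split below simply reuses the validation images. A hedged
        # sketch of how this datamodule might be referenced from a training
        # config via instantiate_from_config (paths are placeholders, not the
        # project's actual data locations):
        #
        #   data:
        #     target: ldm.data.sync_dreamer.SyncDreamerDataset
        #     params:
        #       target_dir: data/target_views
        #       input_dir: data/input_views
        #       validation_dir: data/eval_images   # *.png files named like name-<elevation>.png
        #       uid_set_pkl: data/train_uids.pkl
        #       batch_size: 4
        #
        # During training, SyncDreamerTrainData rolls the 16 target views so the
        # sequence starts at the randomly sampled input view index.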
return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. 
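        # These schedulers return a multiplier on the base learning rate (hence
        # the "use with a base_lr of 1.0" notes). A minimal sketch for the
        # single-cycle class above, with illustrative hyperparameters and an
        # assumed existing `optimizer`:
        #
        #   from torch.optim.lr_scheduler import LambdaLR
        #   sched = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=0.1,
        #                                       lr_max=1.0, lr_start=1e-6,
        #                                       max_decay_steps=100000)
        #   lr_scheduler = LambdaLR(optimizer, lr_lambda=sched)  # stepped once per batch
        #
        # The multiplier ramps linearly from lr_start to lr_max over the warm-up
        # steps, then follows a cosine decay from lr_max down to lr_min.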
+ self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6a9c4f45498561953b8085981609b2a3298a5473 --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + 
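        # Data flow of this autoencoder: encoder -> quant_conv (z_channels ->
        # embed_dim) -> VectorQuantizer (nearest of n_embed codebook entries) ->
        # post_quant_conv (embed_dim -> z_channels) -> decoder. The quantizer
        # also returns a codebook/commitment loss (`diff` in forward()) and the
        # selected code indices, which the training loss uses for codebook-usage
        # statistics. Note that LitEma, np, LambdaLR and `version` are referenced
        # further down but are not imported at the top of this file as committed.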
self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with 
self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
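        # `colorize` is a fixed random 3 x C x 1 x 1 projection registered as a
        # buffer, so an arbitrary-channel segmentation map is reduced to three
        # channels consistently across calls; the min/max rescale above maps the
        # result back into the [-1, 1] range used for logged images.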
+ return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, 
reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/models/diffusion/sync_dreamer.py b/ldm/models/diffusion/sync_dreamer.py new file mode 100644 index 0000000000000000000000000000000000000000..16d0cbafa38fe808ad1acc1a0b463f61e545f3b2 --- /dev/null +++ b/ldm/models/diffusion/sync_dreamer.py @@ -0,0 +1,661 @@ +from pathlib import Path + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from skimage.io import imsave +from torch.optim.lr_scheduler import LambdaLR +from tqdm import tqdm + +from ldm.base_utils import read_pickle, concat_images_list +from ldm.models.diffusion.sync_dreamer_utils import get_warp_coordinates, create_target_volume +from ldm.models.diffusion.sync_dreamer_network import NoisyTargetViewEncoder, SpatialTime3DNet, FrustumTV3DNet +from ldm.modules.diffusionmodules.util import make_ddim_timesteps, timestep_embedding +from ldm.modules.encoders.modules import FrozenCLIPImageEmbedder +from ldm.util import instantiate_from_config + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + +def disable_training_module(module: nn.Module): + module = module.eval() + module.train = 
disabled_train + for para in module.parameters(): + para.requires_grad = False + return module + +def repeat_to_batch(tensor, B, VN): + t_shape = tensor.shape + ones = [1 for _ in range(len(t_shape)-1)] + tensor_new = tensor.view(B,1,*t_shape[1:]).repeat(1,VN,*ones).view(B*VN,*t_shape[1:]) + return tensor_new + +class UNetWrapper(nn.Module): + def __init__(self, diff_model_config, drop_conditions=False, drop_scheme='default', use_zero_123=True): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.drop_conditions = drop_conditions + self.drop_scheme=drop_scheme + self.use_zero_123 = use_zero_123 + + + def drop(self, cond, mask): + shape = cond.shape + B = shape[0] + cond = mask.view(B,*[1 for _ in range(len(shape)-1)]) * cond + return cond + + def get_trainable_parameters(self): + return self.diffusion_model.get_trainable_parameters() + + def get_drop_scheme(self, B, device): + if self.drop_scheme=='default': + random = torch.rand(B, dtype=torch.float32, device=device) + drop_clip = (random > 0.15) & (random <= 0.2) + drop_volume = (random > 0.1) & (random <= 0.15) + drop_concat = (random > 0.05) & (random <= 0.1) + drop_all = random <= 0.05 + else: + raise NotImplementedError + return drop_clip, drop_volume, drop_concat, drop_all + + def forward(self, x, t, clip_embed, volume_feats, x_concat, is_train=False): + """ + + @param x: B,4,H,W + @param t: B, + @param clip_embed: B,M,768 + @param volume_feats: B,C,D,H,W + @param x_concat: B,C,H,W + @param is_train: + @return: + """ + if self.drop_conditions and is_train: + B = x.shape[0] + drop_clip, drop_volume, drop_concat, drop_all = self.get_drop_scheme(B, x.device) + + clip_mask = 1.0 - (drop_clip | drop_all).float() + clip_embed = self.drop(clip_embed, clip_mask) + + volume_mask = 1.0 - (drop_volume | drop_all).float() + for k, v in volume_feats.items(): + volume_feats[k] = self.drop(v, mask=volume_mask) + + concat_mask = 1.0 - (drop_concat | drop_all).float() + x_concat = self.drop(x_concat, concat_mask) + + if self.use_zero_123: + # zero123 does not multiply this when encoding, maybe a bug for zero123 + first_stage_scale_factor = 0.18215 + x_concat_ = x_concat * 1.0 + x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor + else: + x_concat_ = x_concat + + x = torch.cat([x, x_concat_], 1) + pred = self.diffusion_model(x, t, clip_embed, source_dict=volume_feats) + return pred + + def predict_with_unconditional_scale(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scale): + x_ = torch.cat([x] * 2, 0) + t_ = torch.cat([t] * 2, 0) + clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed)], 0) + + v_ = {} + for k, v in volume_feats.items(): + v_[k] = torch.cat([v, torch.zeros_like(v)], 0) + + x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0) + if self.use_zero_123: + # zero123 does not multiply this when encoding, maybe a bug for zero123 + first_stage_scale_factor = 0.18215 + x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor + x_ = torch.cat([x_, x_concat_], 1) + s, s_uc = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(2) + s = s_uc + unconditional_scale * (s - s_uc) + return s + + +class SpatialVolumeNet(nn.Module): + def __init__(self, time_dim, view_dim, view_num, + input_image_size=256, frustum_volume_depth=48, + spatial_volume_size=32, spatial_volume_length=0.5, + frustum_volume_length=0.86603 # sqrt(3)/2 + ): + super().__init__() + self.target_encoder = NoisyTargetViewEncoder(time_dim, view_dim, output_dim=16) + 
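+ # each of the `view_num` noisy target views is encoded into a 16-channel feature map, unprojected
+ # into a shared 3D grid, and the per-view volumes are concatenated channel-wise before being
+ # fused by the 3D network below (hence input_dim = 16 * view_num)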
self.spatial_volume_feats = SpatialTime3DNet(input_dim=16 * view_num, time_dim=time_dim, dims=(64, 128, 256, 512)) + self.frustum_volume_feats = FrustumTV3DNet(64, time_dim, view_dim, dims=(64, 128, 256, 512)) + + self.frustum_volume_length = frustum_volume_length + self.input_image_size = input_image_size + self.spatial_volume_size = spatial_volume_size + self.spatial_volume_length = spatial_volume_length + + self.frustum_volume_size = self.input_image_size // 8 + self.frustum_volume_depth = frustum_volume_depth + self.time_dim = time_dim + self.view_dim = view_dim + self.default_origin_depth = 1.5 # our rendered images are 1.5 away from the origin, we assume camera is 1.5 away from the origin + + def construct_spatial_volume(self, x, t_embed, v_embed, target_poses, target_Ks): + """ + @param x: B,N,4,H,W + @param t_embed: B,t_dim + @param v_embed: B,N,v_dim + @param target_poses: N,3,4 + @param target_Ks: N,3,3 + @return: + """ + B, N, _, H, W = x.shape + V = self.spatial_volume_size + device = x.device + + spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device) + spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1) + spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)] + spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1) + + # encode source features + t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim) + # v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim) + v_embed_ = v_embed + target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1) + target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1) + + # extract 2D image features + spatial_volume_feats = [] + # project source features + for ni in range(0, N): + pose_source_ = target_poses[:, ni] + K_source_ = target_Ks[:, ni] + x_ = self.target_encoder(x[:, ni], t_embed_[:, ni], v_embed_[:, ni]) + C = x_.shape[1] + + coords_source = get_warp_coordinates(spatial_volume_verts, x_.shape[-1], self.input_image_size, K_source_, pose_source_).view(B, V, V * V, 2) + unproj_feats_ = F.grid_sample(x_, coords_source, mode='bilinear', padding_mode='zeros', align_corners=True) + unproj_feats_ = unproj_feats_.view(B, C, V, V, V) + spatial_volume_feats.append(unproj_feats_) + + spatial_volume_feats = torch.stack(spatial_volume_feats, 1) # B,N,C,V,V,V + N = spatial_volume_feats.shape[1] + spatial_volume_feats = spatial_volume_feats.view(B, N*C, V, V, V) + + spatial_volume_feats = self.spatial_volume_feats(spatial_volume_feats, t_embed) # b,64,32,32,32 + return spatial_volume_feats + + def construct_view_frustum_volume(self, spatial_volume, t_embed, v_embed, poses, Ks, target_indices): + """ + @param spatial_volume: B,C,V,V,V + @param t_embed: B,t_dim + @param v_embed: B,N,v_dim + @param poses: N,3,4 + @param Ks: N,3,3 + @param target_indices: B,TN + @return: B*TN,C,H,W + """ + B, TN = target_indices.shape + H, W = self.frustum_volume_size, self.frustum_volume_size + D = self.frustum_volume_depth + V = self.spatial_volume_size + + near = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth - self.frustum_volume_length + far = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth + self.frustum_volume_length + + target_indices = 
target_indices.view(B*TN) # B*TN + poses_ = poses[target_indices] # B*TN,3,4 + Ks_ = Ks[target_indices] # B*TN,3,4 + volume_xyz, volume_depth = create_target_volume(D, self.frustum_volume_size, self.input_image_size, poses_, Ks_, near, far) # B*TN,3 or 1,D,H,W + + volume_xyz_ = volume_xyz / self.spatial_volume_length # since the spatial volume is constructed in [-spatial_volume_length,spatial_volume_length] + volume_xyz_ = volume_xyz_.permute(0, 2, 3, 4, 1) # B*TN,D,H,W,3 + spatial_volume_ = spatial_volume.unsqueeze(1).repeat(1, TN, 1, 1, 1, 1).view(B * TN, -1, V, V, V) + volume_feats = F.grid_sample(spatial_volume_, volume_xyz_, mode='bilinear', padding_mode='zeros', align_corners=True) # B*TN,C,D,H,W + + v_embed_ = v_embed[torch.arange(B)[:,None], target_indices.view(B,TN)].view(B*TN, -1) # B*TN + t_embed_ = t_embed.unsqueeze(1).repeat(1,TN,1).view(B*TN,-1) + volume_feats_dict = self.frustum_volume_feats(volume_feats, t_embed_, v_embed_) + return volume_feats_dict, volume_depth + +class SyncMultiviewDiffusion(pl.LightningModule): + def __init__(self, unet_config, scheduler_config, + finetune_unet=False, finetune_projection=True, + view_num=16, image_size=256, + cfg_scale=3.0, output_num=8, batch_view_num=4, + drop_conditions=False, drop_scheme='default', + clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"): + super().__init__() + + self.finetune_unet = finetune_unet + self.finetune_projection = finetune_projection + + self.view_num = view_num + self.viewpoint_dim = 4 + self.output_num = output_num + self.image_size = image_size + + self.batch_view_num = batch_view_num + self.cfg_scale = cfg_scale + + self.clip_image_encoder_path = clip_image_encoder_path + + self._init_time_step_embedding() + self._init_first_stage() + self._init_schedule() + self._init_multiview() + self._init_clip_image_encoder() + self._init_clip_projection() + + self.spatial_volume = SpatialVolumeNet(self.time_embed_dim, self.viewpoint_dim, self.view_num) + self.model = UNetWrapper(unet_config, drop_conditions=drop_conditions, drop_scheme=drop_scheme) + self.scheduler_config = scheduler_config + + latent_size = image_size//8 + self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size) + + def _init_clip_projection(self): + self.cc_projection = nn.Linear(772, 768) + nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768]) + nn.init.zeros_(list(self.cc_projection.parameters())[1]) + self.cc_projection.requires_grad_(True) + + if not self.finetune_projection: + disable_training_module(self.cc_projection) + + def _init_multiview(self): + K, azs, _, _, poses = read_pickle(f'meta_info/camera-{self.view_num}.pkl') + default_image_size = 256 + ratio = self.image_size/default_image_size + K = np.diag([ratio,ratio,1]) @ K + K = torch.from_numpy(K.astype(np.float32)) # [3,3] + K = K.unsqueeze(0).repeat(self.view_num,1,1) # N,3,3 + poses = torch.from_numpy(poses.astype(np.float32)) # N,3,4 + self.register_buffer('poses', poses) + self.register_buffer('Ks', K) + azs = (azs + np.pi) % (np.pi * 2) - np.pi # scale to [-pi,pi] and the index=0 has az=0 + self.register_buffer('azimuth', torch.from_numpy(azs.astype(np.float32))) + + def get_viewpoint_embedding(self, batch_size, elevation_ref): + """ + @param batch_size: + @param elevation_ref: B + @return: + """ + azimuth_input = self.azimuth[0].unsqueeze(0) # 1 + azimuth_target = self.azimuth # N + elevation_input = -elevation_ref # note that zero123 use a negative elevation here!!! 
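+ # all target views share a fixed 30-degree elevation; the embedding built below encodes the
+ # relative viewpoint of each target view as (d_elevation, sin(d_azimuth), cos(d_azimuth), 0)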
+ elevation_target = -np.deg2rad(30) + d_e = elevation_target - elevation_input # B + N = self.azimuth.shape[0] + B = batch_size + d_e = d_e.unsqueeze(1).repeat(1, N) + d_a = azimuth_target - azimuth_input # N + d_a = d_a.unsqueeze(0).repeat(B, 1) + d_z = torch.zeros_like(d_a) + embedding = torch.stack([d_e, torch.sin(d_a), torch.cos(d_a), d_z], -1) # B,N,4 + return embedding + + def _init_first_stage(self): + first_stage_config={ + "target": "ldm.models.autoencoder.AutoencoderKL", + "params": { + "embed_dim": 4, + "monitor": "val/rec_loss", + "ddconfig":{ + "double_z": True, + "z_channels": 4, + "resolution": self.image_size, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1,2,4,4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0 + }, + "lossconfig": {"target": "torch.nn.Identity"}, + } + } + self.first_stage_scale_factor = 0.18215 + self.first_stage_model = instantiate_from_config(first_stage_config) + self.first_stage_model = disable_training_module(self.first_stage_model) + + def _init_clip_image_encoder(self): + self.clip_image_encoder = FrozenCLIPImageEmbedder(model=self.clip_image_encoder_path) + self.clip_image_encoder = disable_training_module(self.clip_image_encoder) + + def _init_schedule(self): + self.num_timesteps = 1000 + linear_start = 0.00085 + linear_end = 0.0120 + num_timesteps = 1000 + betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2 # T + assert betas.shape[0] == self.num_timesteps + + # all in float64 first + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) # T + alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # T + posterior_log_variance_clipped = torch.log(torch.clamp(posterior_variance, min=1e-20)) + posterior_log_variance_clipped = torch.clamp(posterior_log_variance_clipped, min=-10) + + self.register_buffer("betas", betas.float()) + self.register_buffer("alphas", alphas.float()) + self.register_buffer("alphas_cumprod", alphas_cumprod.float()) + self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod).float()) + self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod).float()) + self.register_buffer("posterior_variance", posterior_variance.float()) + self.register_buffer('posterior_log_variance_clipped', posterior_log_variance_clipped.float()) + + def _init_time_step_embedding(self): + self.time_embed_dim = 256 + self.time_embed = nn.Sequential( + nn.Linear(self.time_embed_dim, self.time_embed_dim), + nn.SiLU(True), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + + def encode_first_stage(self, x, sample=True): + with torch.no_grad(): + posterior = self.first_stage_model.encode(x) # b,4,h//8,w//8 + if sample: + return posterior.sample().detach() * self.first_stage_scale_factor + else: + return posterior.mode().detach() * self.first_stage_scale_factor + + def decode_first_stage(self, z): + with torch.no_grad(): + z = 1. 
/ self.first_stage_scale_factor * z + return self.first_stage_model.decode(z) + + def prepare(self, batch): + # encode target + if 'target_image' in batch: + image_target = batch['target_image'].permute(0, 1, 4, 2, 3) # b,n,3,h,w + N = image_target.shape[1] + x = [self.encode_first_stage(image_target[:,ni], True) for ni in range(N)] + x = torch.stack(x, 1) # b,n,4,h//8,w//8 + else: + x = None + + image_input = batch['input_image'].permute(0, 3, 1, 2) + elevation_input = batch['input_elevation'][:, 0] # b + x_input = self.encode_first_stage(image_input) + input_info = {'image': image_input, 'elevation': elevation_input, 'x': x_input} + with torch.no_grad(): + clip_embed = self.clip_image_encoder.encode(image_input) + return x, clip_embed, input_info + + def embed_time(self, t): + t_embed = timestep_embedding(t, self.time_embed_dim, repeat_only=False) # B,TED + t_embed = self.time_embed(t_embed) # B,TED + return t_embed + + def get_target_view_feats(self, x_input, spatial_volume, clip_embed, t_embed, v_embed, target_index): + """ + @param x_input: B,4,H,W + @param spatial_volume: B,C,V,V,V + @param clip_embed: B,1,768 + @param t_embed: B,t_dim + @param v_embed: B,N,v_dim + @param target_index: B,TN + @return: + tensors of size B*TN,* + """ + B, _, H, W = x_input.shape + frustum_volume_feats, frustum_volume_depth = self.spatial_volume.construct_view_frustum_volume(spatial_volume, t_embed, v_embed, self.poses, self.Ks, target_index) + + # clip + TN = target_index.shape[1] + v_embed_ = v_embed[torch.arange(B)[:,None], target_index].view(B*TN, self.viewpoint_dim) # B*TN,v_dim + clip_embed_ = clip_embed.unsqueeze(1).repeat(1,TN,1,1).view(B*TN,1,768) + clip_embed_ = self.cc_projection(torch.cat([clip_embed_, v_embed_.unsqueeze(1)], -1)) # B*TN,1,768 + + x_input_ = x_input.unsqueeze(1).repeat(1, TN, 1, 1, 1).view(B * TN, 4, H, W) + + x_concat = x_input_ + return clip_embed_, frustum_volume_feats, x_concat + + def training_step(self, batch): + B = batch['target_image'].shape[0] + time_steps = torch.randint(0, self.num_timesteps, (B,), device=self.device).long() + + x, clip_embed, input_info = self.prepare(batch) + x_noisy, noise = self.add_noise(x, time_steps) # B,N,4,H,W + + N = self.view_num + target_index = torch.randint(0, N, (B, 1), device=self.device).long() # B, 1 + v_embed = self.get_viewpoint_embedding(B, input_info['elevation']) # N,v_dim + + t_embed = self.embed_time(time_steps) + spatial_volume = self.spatial_volume.construct_spatial_volume(x_noisy, t_embed, v_embed, self.poses, self.Ks) + + clip_embed, volume_feats, x_concat = self.get_target_view_feats(input_info['x'], spatial_volume, clip_embed, t_embed, v_embed, target_index) + + x_noisy_ = x_noisy[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W + noise_predict = self.model(x_noisy_, time_steps, clip_embed, volume_feats, x_concat, is_train=True) # B,4,H,W + + noise_target = noise[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W + # loss simple for diffusion + loss_simple = torch.nn.functional.mse_loss(noise_target, noise_predict, reduction='none') + loss = loss_simple.mean() + self.log('sim', loss_simple.mean(), prog_bar=True, logger=True, on_step=True, on_epoch=True, rank_zero_only=True) + + # log others + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True) + self.log("step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True) + return loss + + def add_noise(self, x_start, t): + """ + 
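+ Standard DDPM forward process: x_noisy = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise.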
@param x_start: B,* + @param t: B, + @return: + """ + B = x_start.shape[0] + noise = torch.randn_like(x_start) # B,* + + sqrt_alphas_cumprod_ = self.sqrt_alphas_cumprod[t] # B, + sqrt_one_minus_alphas_cumprod_ = self.sqrt_one_minus_alphas_cumprod[t] # B + sqrt_alphas_cumprod_ = sqrt_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)]) + sqrt_one_minus_alphas_cumprod_ = sqrt_one_minus_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)]) + x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise + return x_noisy, noise + + def sample(self, sampler, batch, cfg_scale, batch_view_num, return_inter_results=False, inter_interval=50, inter_view_interval=2): + _, clip_embed, input_info = self.prepare(batch) + x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num) + + N = x_sample.shape[1] + x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1) + if return_inter_results: + torch.cuda.synchronize() + torch.cuda.empty_cache() + inter = torch.stack(inter['x_inter'], 2) # # B,N,T,C,H,W + B,N,T,C,H,W = inter.shape + inter_results = [] + for ni in tqdm(range(0, N, inter_view_interval)): + inter_results_ = [] + for ti in range(T): + inter_results_.append(self.decode_first_stage(inter[:, ni, ti])) + inter_results.append(torch.stack(inter_results_, 1)) # B,T,3,H,W + inter_results = torch.stack(inter_results,1) # B,N,T,3,H,W + return x_sample, inter_results + else: + return x_sample + + def log_image(self, x_sample, batch, step, output_dir): + process = lambda x: ((torch.clip(x, min=-1, max=1).cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8) + B = x_sample.shape[0] + N = x_sample.shape[1] + image_cond = [] + for bi in range(B): + img_pr_ = concat_images_list(process(batch['input_image'][bi]),*[process(x_sample[bi, ni].permute(1, 2, 0)) for ni in range(N)]) + image_cond.append(img_pr_) + + output_dir = Path(output_dir) + imsave(str(output_dir/f'{step}.jpg'), concat_images_list(*image_cond, vert=True)) + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + if batch_idx==0 and self.global_rank==0: + self.eval() + step = self.global_step + batch_ = {} + for k, v in batch.items(): batch_[k] = v[:self.output_num] + x_sample = self.sample(batch_, self.cfg_scale, self.batch_view_num) + output_dir = Path(self.image_dir) / 'images' / 'val' + output_dir.mkdir(exist_ok=True, parents=True) + self.log_image(x_sample, batch, step, output_dir=output_dir) + + def configure_optimizers(self): + lr = self.learning_rate + print(f'setting learning rate to {lr:.4f} ...') + paras = [] + if self.finetune_projection: + paras.append({"params": self.cc_projection.parameters(), "lr": lr},) + if self.finetune_unet: + paras.append({"params": self.model.parameters(), "lr": lr},) + else: + paras.append({"params": self.model.get_trainable_parameters(), "lr": lr},) + + paras.append({"params": self.time_embed.parameters(), "lr": lr*10.0},) + paras.append({"params": self.spatial_volume.parameters(), "lr": lr*10.0},) + + opt = torch.optim.AdamW(paras, lr=lr) + + scheduler = instantiate_from_config(self.scheduler_config) + print("Setting up LambdaLR scheduler...") + scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] + return [opt], scheduler + +class SyncDDIMSampler: + def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=1.0, latent_size=32): + 
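+ # precompute a DDIM schedule over `ddim_num_steps` of the model's 1000 DDPM timesteps; sampling
+ # then denoises all target views jointly, processing them in chunks of `batch_view_num` per step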
super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.latent_size = latent_size + self._make_schedule(ddim_num_steps, ddim_discretize, ddim_eta) + self.eta = ddim_eta + + def _make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) # DT + ddim_timesteps_ = torch.from_numpy(self.ddim_timesteps.astype(np.int64)) # DT + + alphas_cumprod = self.model.alphas_cumprod # T + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + self.ddim_alphas = alphas_cumprod[ddim_timesteps_].double() # DT + self.ddim_alphas_prev = torch.cat([alphas_cumprod[0:1], alphas_cumprod[ddim_timesteps_[:-1]]], 0) # DT + self.ddim_sigmas = ddim_eta * torch.sqrt((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * (1 - self.ddim_alphas / self.ddim_alphas_prev)) + + self.ddim_alphas_raw = self.model.alphas[ddim_timesteps_].float() # DT + self.ddim_sigmas = self.ddim_sigmas.float() + self.ddim_alphas = self.ddim_alphas.float() + self.ddim_alphas_prev = self.ddim_alphas_prev.float() + self.ddim_sqrt_one_minus_alphas = torch.sqrt(1. - self.ddim_alphas).float() + + + @torch.no_grad() + def denoise_apply_impl(self, x_target_noisy, index, noise_pred, is_step0=False): + """ + @param x_target_noisy: B,N,4,H,W + @param index: index + @param noise_pred: B,N,4,H,W + @param is_step0: bool + @return: + """ + device = x_target_noisy.device + B,N,_,H,W = x_target_noisy.shape + + # apply noise + a_t = self.ddim_alphas[index].to(device).float().view(1,1,1,1,1) + a_prev = self.ddim_alphas_prev[index].to(device).float().view(1,1,1,1,1) + sqrt_one_minus_at = self.ddim_sqrt_one_minus_alphas[index].to(device).float().view(1,1,1,1,1) + sigma_t = self.ddim_sigmas[index].to(device).float().view(1,1,1,1,1) + + pred_x0 = (x_target_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() + dir_xt = torch.clamp(1. 
- a_prev - sigma_t**2, min=1e-7).sqrt() * noise_pred + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + if not is_step0: + noise = sigma_t * torch.randn_like(x_target_noisy) + x_prev = x_prev + noise + return x_prev + + @torch.no_grad() + def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False): + """ + @param x_target_noisy: B,N,4,H,W + @param input_info: + @param clip_embed: B,M,768 + @param time_steps: B, + @param index: int + @param unconditional_scale: + @param batch_view_num: int + @param is_step0: bool + @return: + """ + x_input, elevation_input = input_info['x'], input_info['elevation'] + B, N, C, H, W = x_target_noisy.shape + + # construct source data + v_embed = self.model.get_viewpoint_embedding(B, elevation_input) # B,N,v_dim + t_embed = self.model.embed_time(time_steps) # B,t_dim + spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks) + + e_t = [] + target_indices = torch.arange(N) # N + for ni in range(0, N, batch_view_num): + x_target_noisy_ = x_target_noisy[:, ni:ni + batch_view_num] + VN = x_target_noisy_.shape[1] + x_target_noisy_ = x_target_noisy_.reshape(B*VN,C,H,W) + + time_steps_ = repeat_to_batch(time_steps, B, VN) + target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1) + clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_) + if unconditional_scale!=1.0: + noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale) + else: + noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False) + e_t.append(noise.view(B,VN,4,H,W)) + + e_t = torch.cat(e_t, 1) + x_prev = self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0) + return x_prev + + @torch.no_grad() + def sample(self, input_info, clip_embed, unconditional_scale=1.0, log_every_t=50, batch_view_num=1): + """ + @param input_info: x, elevation + @param clip_embed: B,M,768 + @param unconditional_scale: + @param log_every_t: + @param batch_view_num: + @return: + """ + print(f"unconditional scale {unconditional_scale:.1f}") + C, H, W = 4, self.latent_size, self.latent_size + B = clip_embed.shape[0] + N = self.model.view_num + device = self.model.device + x_target_noisy = torch.randn([B, N, C, H, W], device=device) + + timesteps = self.ddim_timesteps + intermediates = {'x_inter': []} + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + for i, step in enumerate(iterator): + index = total_steps - i - 1 # index in ddim state + time_steps = torch.full((B,), step, device=device, dtype=torch.long) + x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=batch_view_num, is_step0=index==0) + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(x_target_noisy) + + return x_target_noisy, intermediates \ No newline at end of file diff --git a/ldm/models/diffusion/sync_dreamer_attention.py b/ldm/models/diffusion/sync_dreamer_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..11a7d870ef33e46e60383031218117e82857f579 --- /dev/null +++ b/ldm/models/diffusion/sync_dreamer_attention.py @@ -0,0 +1,142 
@@ +import torch +import torch.nn as nn + +from ldm.modules.attention import default, zero_module, checkpoint +from ldm.modules.diffusionmodules.openaimodel import UNetModel +from ldm.modules.diffusionmodules.util import timestep_embedding + +class DepthAttention(nn.Module): + def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Conv2d(query_dim, inner_dim, 1, 1, bias=False) + self.to_k = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False) + self.to_v = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False) + if output_bias: + self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1) + else: + self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1, bias=False) + + def forward(self, x, context): + """ + + @param x: b,f0,h,w + @param context: b,f1,d,h,w + @return: + """ + hn, hd = self.heads, self.dim_head + b, _, h, w = x.shape + b, _, d, h, w = context.shape + + q = self.to_q(x).reshape(b,hn,hd,h,w) # b,t,h,w + k = self.to_k(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w + v = self.to_v(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w + + sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w + attn = sim.softmax(dim=2) + + # b,hn,hd,d,h,w * b,hn,1,d,h,w + out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w + out = out.reshape(b,hn*hd,h,w) + return self.to_out(out) + + +class DepthTransformer(nn.Module): + def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True): + super().__init__() + inner_dim = n_heads * d_head + self.proj_in = nn.Sequential( + nn.Conv2d(dim, inner_dim, 1, 1), + nn.GroupNorm(8, inner_dim), + nn.SiLU(True), + ) + self.proj_context = nn.Sequential( + nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias + nn.GroupNorm(8, context_dim), + nn.ReLU(True), # only relu, because we want input is 0, output is 0 + ) + self.depth_attn = DepthAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, output_bias=False) # is a self-attention if not self.disable_self_attn + self.proj_out = nn.Sequential( + nn.GroupNorm(8, inner_dim), + nn.ReLU(True), + nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False), + nn.GroupNorm(8, inner_dim), + nn.ReLU(True), + zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)), + ) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context): + x_in = x + x = self.proj_in(x) + context = self.proj_context(context) + x = self.depth_attn(x, context) + x = self.proj_out(x) + x_in + return x + + +class DepthWiseAttention(UNetModel): + def __init__(self, volume_dims=(5,16,32,64), *args, **kwargs): + super().__init__(*args, **kwargs) + # num_heads = 4 + model_channels = kwargs['model_channels'] + channel_mult = kwargs['channel_mult'] + d0,d1,d2,d3 = volume_dims + + # 4 + ch = model_channels*channel_mult[2] + self.middle_conditions = DepthTransformer(ch, 4, d3 // 2, context_dim=d3) + + self.output_conditions=nn.ModuleList() + self.output_b2c = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8} + # 8 + ch = model_channels*channel_mult[2] + self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0 + self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1 + # 16 + self.output_conditions.append(DepthTransformer(ch, 4, 
d1 // 2, context_dim=d1)) # 2 + ch = model_channels*channel_mult[1] + self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3 + self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4 + # 32 + self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5 + ch = model_channels*channel_mult[0] + self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6 + self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7 + self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8 + + def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs): + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + h = x.type(self.dtype) + for index, module in enumerate(self.input_blocks): + h = module(h, emb, context) + hs.append(h) + + h = self.middle_block(h, emb, context) + h = self.middle_conditions(h, context=source_dict[h.shape[-1]]) + + for index, module in enumerate(self.output_blocks): + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + if index in self.output_b2c: + layer = self.output_conditions[self.output_b2c[index]] + h = layer(h, context=source_dict[h.shape[-1]]) + + h = h.type(x.dtype) + return self.out(h) + + def get_trainable_parameters(self): + paras = [para for para in self.middle_conditions.parameters()] + [para for para in self.output_conditions.parameters()] + return paras diff --git a/ldm/models/diffusion/sync_dreamer_network.py b/ldm/models/diffusion/sync_dreamer_network.py new file mode 100644 index 0000000000000000000000000000000000000000..c03b3ddfba02781beb0a196f55472567e55ac627 --- /dev/null +++ b/ldm/models/diffusion/sync_dreamer_network.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn + +class Image2DResBlockWithTV(nn.Module): + def __init__(self, dim, tdim, vdim): + super().__init__() + norm = lambda c: nn.GroupNorm(8, c) + self.time_embed = nn.Conv2d(tdim, dim, 1, 1) + self.view_embed = nn.Conv2d(vdim, dim, 1, 1) + self.conv = nn.Sequential( + norm(dim), + nn.SiLU(True), + nn.Conv2d(dim, dim, 3, 1, 1), + norm(dim), + nn.SiLU(True), + nn.Conv2d(dim, dim, 3, 1, 1), + ) + + def forward(self, x, t, v): + return x+self.conv(x+self.time_embed(t)+self.view_embed(v)) + + +class NoisyTargetViewEncoder(nn.Module): + def __init__(self, time_embed_dim, viewpoint_dim, run_dim=16, output_dim=8): + super().__init__() + + self.init_conv = nn.Conv2d(4, run_dim, 3, 1, 1) + self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim) + self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim) + self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim) + self.final_out = nn.Sequential( + nn.GroupNorm(8, run_dim), + nn.SiLU(True), + nn.Conv2d(run_dim, output_dim, 3, 1, 1) + ) + + def forward(self, x, t, v): + B, DT = t.shape + t = t.view(B, DT, 1, 1) + B, DV = v.shape + v = v.view(B, DV, 1, 1) + + x = self.init_conv(x) + x = self.out_conv0(x, t, v) + x = self.out_conv1(x, t, v) + x = self.out_conv2(x, t, v) + x = self.final_out(x) + return x + +class SpatialUpTimeBlock(nn.Module): + def __init__(self, x_in_dim, t_in_dim, out_dim): + super().__init__() + norm_act = lambda c: nn.GroupNorm(8, c) + self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16 + self.norm = norm_act(x_in_dim) + self.silu = nn.SiLU(True) + self.conv = nn.ConvTranspose3d(x_in_dim, out_dim, kernel_size=3, 
padding=1, output_padding=1, stride=2) + + def forward(self, x, t): + x = x + self.t_conv(t) + return self.conv(self.silu(self.norm(x))) + +class SpatialTimeBlock(nn.Module): + def __init__(self, x_in_dim, t_in_dim, out_dim, stride): + super().__init__() + norm_act = lambda c: nn.GroupNorm(8, c) + self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16 + self.bn = norm_act(x_in_dim) + self.silu = nn.SiLU(True) + self.conv = nn.Conv3d(x_in_dim, out_dim, 3, stride=stride, padding=1) + + def forward(self, x, t): + x = x + self.t_conv(t) + return self.conv(self.silu(self.bn(x))) + +class SpatialTime3DNet(nn.Module): + def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)): + super().__init__() + d0, d1, d2, d3 = dims + dt = time_dim + + self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32 + self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1) + + self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2) + self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1) + self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1) + + self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2) + self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1) + self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1) + + self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2) + self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1) + self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1) + + self.conv7 = SpatialUpTimeBlock(d3, dt, d2) + self.conv8 = SpatialUpTimeBlock(d2, dt, d1) + self.conv9 = SpatialUpTimeBlock(d1, dt, d0) + + def forward(self, x, t): + B, C = t.shape + t = t.view(B, C, 1, 1, 1) + + x = self.init_conv(x) + conv0 = self.conv0(x, t) + + x = self.conv1(conv0, t) + x = self.conv2_0(x, t) + conv2 = self.conv2_1(x, t) + + x = self.conv3(conv2, t) + x = self.conv4_0(x, t) + conv4 = self.conv4_1(x, t) + + x = self.conv5(conv4, t) + x = self.conv6_0(x, t) + x = self.conv6_1(x, t) + + x = conv4 + self.conv7(x, t) + x = conv2 + self.conv8(x, t) + x = conv0 + self.conv9(x, t) + return x + +class FrustumTVBlock(nn.Module): + def __init__(self, x_dim, t_dim, v_dim, out_dim, stride): + super().__init__() + norm_act = lambda c: nn.GroupNorm(8, c) + self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16 + self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16 + self.bn = norm_act(x_dim) + self.silu = nn.SiLU(True) + self.conv = nn.Conv3d(x_dim, out_dim, 3, stride=stride, padding=1) + + def forward(self, x, t, v): + x = x + self.t_conv(t) + self.v_conv(v) + return self.conv(self.silu(self.bn(x))) + +class FrustumTVUpBlock(nn.Module): + def __init__(self, x_dim, t_dim, v_dim, out_dim): + super().__init__() + norm_act = lambda c: nn.GroupNorm(8, c) + self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16 + self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16 + self.norm = norm_act(x_dim) + self.silu = nn.SiLU(True) + self.conv = nn.ConvTranspose3d(x_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2) + + def forward(self, x, t, v): + x = x + self.t_conv(t) + self.v_conv(v) + return self.conv(self.silu(self.norm(x))) + +class FrustumTV3DNet(nn.Module): + def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)): + super().__init__() + self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32 + + self.conv1 = FrustumTVBlock(dims[0], t_dim, v_dim, dims[1], 2) + self.conv2 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[1], 1) + + self.conv3 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[2], 2) + self.conv4 = FrustumTVBlock(dims[2], t_dim, v_dim, dims[2], 1) + + self.conv5 = FrustumTVBlock(dims[2], t_dim, v_dim, dims[3], 2) + self.conv6 
= FrustumTVBlock(dims[3], t_dim, v_dim, dims[3], 1) + + self.up0 = FrustumTVUpBlock(dims[3], t_dim, v_dim, dims[2]) + self.up1 = FrustumTVUpBlock(dims[2], t_dim, v_dim, dims[1]) + self.up2 = FrustumTVUpBlock(dims[1], t_dim, v_dim, dims[0]) + + def forward(self, x, t, v): + B,DT = t.shape + t = t.view(B,DT,1,1,1) + B,DV = v.shape + v = v.view(B,DV,1,1,1) + + b, _, d, h, w = x.shape + x0 = self.conv0(x) + x1 = self.conv2(self.conv1(x0, t, v), t, v) + x2 = self.conv4(self.conv3(x1, t, v), t, v) + x3 = self.conv6(self.conv5(x2, t, v), t, v) + + x2 = self.up0(x3, t, v) + x2 + x1 = self.up1(x2, t, v) + x1 + x0 = self.up2(x1, t, v) + x0 + return {w: x0, w//2: x1, w//4: x2, w//8: x3} diff --git a/ldm/models/diffusion/sync_dreamer_utils.py b/ldm/models/diffusion/sync_dreamer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c401c745f498d4fe5435a0e6bea3eedf95c46e29 --- /dev/null +++ b/ldm/models/diffusion/sync_dreamer_utils.py @@ -0,0 +1,103 @@ +import torch +from kornia import create_meshgrid + + +def project_and_normalize(ref_grid, src_proj, length): + """ + + @param ref_grid: b 3 n + @param src_proj: b 4 4 + @param length: int + @return: b, n, 2 + """ + src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n + div_val = src_grid[:, -1:] + div_val[div_val<1e-4] = 1e-4 + src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n) + src_grid[:, 0] = src_grid[:, 0]/((length - 1) / 2) - 1 # scale to -1~1 + src_grid[:, 1] = src_grid[:, 1]/((length - 1) / 2) - 1 # scale to -1~1 + src_grid = src_grid.permute(0, 2, 1) # (b, n, 2) + return src_grid + + +def construct_project_matrix(x_ratio, y_ratio, Ks, poses): + """ + @param x_ratio: float + @param y_ratio: float + @param Ks: b,3,3 + @param poses: b,3,4 + @return: + """ + rfn = Ks.shape[0] + scale_m = torch.tensor([x_ratio, y_ratio, 1.0], dtype=torch.float32, device=Ks.device) + scale_m = torch.diag(scale_m) + ref_prj = scale_m[None, :, :] @ Ks @ poses # rfn,3,4 + pad_vals = torch.zeros([rfn, 1, 4], dtype=torch.float32, device=ref_prj.device) + pad_vals[:, :, 3] = 1.0 + ref_prj = torch.cat([ref_prj, pad_vals], 1) # rfn,4,4 + return ref_prj + +def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose): + B, _, D, H, W = volume_xyz.shape + ratio = warp_size / input_size + warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4 + warp_coords = project_and_normalize(volume_xyz.view(B,3,D*H*W), warp_proj, warp_size).view(B, D, H, W, 2) + return warp_coords + + +def create_target_volume(depth_size, volume_size, input_image_size, pose_target, K, near=None, far=None): + device, dtype = pose_target.device, pose_target.dtype + + # compute a depth range on the unit sphere + H, W, D, B = volume_size, volume_size, depth_size, pose_target.shape[0] + if near is not None and far is not None : + # near, far b,1,h,w + depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d + depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1 + depth_values = depth_values * (far - near) + near # b d h w + depth_values = depth_values.view(B, 1, D, H * W) + else: + near, far = near_far_from_unit_sphere_using_camera_poses(pose_target) # b 1 + depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d + depth_values = depth_values[None,:,None] * (far[:,None,:] - near[:,None,:]) + near[:,None,:] # b d 1 + depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H*W) + + ratio = volume_size / input_image_size + + # creat a grid on the target 
(reference) view + # H, W, D, B = volume_size, volume_size, depth_values.shape[1], depth_values.shape[0] + + # creat mesh grid: note reference also means target + ref_grid = create_meshgrid(H, W, normalized_coordinates=False) # (1, H, W, 2) + ref_grid = ref_grid.to(device).to(dtype) + ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W) + ref_grid = ref_grid.reshape(1, 2, H*W) # (1, 2, H*W) + ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W) + ref_grid = torch.cat((ref_grid, torch.ones(B, 1, H*W, dtype=ref_grid.dtype, device=ref_grid.device)), dim=1) # (B, 3, H*W) + ref_grid = ref_grid.unsqueeze(2) * depth_values # (B, 3, D, H*W) + + # unproject to space and transfer to world coordinates. + Ks = K + ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4 + ref_proj_inv = torch.inverse(ref_proj) # B,4,4 + ref_grid = ref_proj_inv[:,:3,:3] @ ref_grid.view(B,3,D*H*W) + ref_proj_inv[:,:3,3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW + return ref_grid.reshape(B,3,D,H,W), depth_values.view(B,1,D,H,W) + +def near_far_from_unit_sphere_using_camera_poses(camera_poses): + """ + @param camera_poses: b 3 4 + @return: + near: b,1 + far: b,1 + """ + R_w2c = camera_poses[..., :3, :3] # b 3 3 + t_w2c = camera_poses[..., :3, 3:] # b 3 1 + camera_origin = -R_w2c.permute(0,2,1) @ t_w2c # b 3 1 + # R_w2c.T @ (0,0,1) = z_dir + camera_orient = R_w2c.permute(0,2,1)[...,:3,2:3] # b 3 1 + camera_origin, camera_orient = camera_origin[...,0], camera_orient[..., 0] # b 3 + a = torch.sum(camera_orient ** 2, dim=-1, keepdim=True) # b 1 + b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1 + mid = b / a # b 1 + near, far = mid - 1.0, mid + 1.0 + return near, far \ No newline at end of file diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..e04837a20c8d97ef11786f08d4ddc477b0a1c35c --- /dev/null +++ b/ldm/modules/attention.py @@ -0,0 +1,336 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) +# feedforward +class ConvGEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Conv2d(dim_in, dim_out * 2, 1, 1, 0) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ 
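+ (zero-initialization is used for the output projections of the transformer blocks in this file,
+ so that a newly added branch initially contributes nothing to its residual connection)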
+ Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = mask>0 + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + +class BasicSpatialTransformer(nn.Module): + def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True): + super().__init__() + inner_dim = n_heads * d_head + self.proj_in = nn.Sequential( + nn.GroupNorm(8, dim), + nn.Conv2d(dim, inner_dim, kernel_size=1, stride=1, padding=0), + nn.GroupNorm(8, inner_dim), + nn.ReLU(True), + ) + self.attn = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim) # is a self-attention if not self.disable_self_attn + self.out_conv = nn.Sequential( + nn.GroupNorm(8, inner_dim), + nn.ReLU(True), + nn.Conv2d(inner_dim, inner_dim, 1, 1), + ) + self.proj_out = nn.Sequential( + nn.GroupNorm(8, inner_dim), + nn.ReLU(True), + zero_module(nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)), + ) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context): + # input + b,_,h,w = x.shape + x_in = x + x = self.proj_in(x) + + # attention + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + context = rearrange(context, 'b c h w -> b (h w) c').contiguous() + x = self.attn(x, context) + x + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + + # output + x = self.out_conv(x) + x + x = self.proj_out(x) + x_in + return x + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + +class ConvFeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Conv2d(dim, inner_dim, 1, 1, 0), + nn.GELU() + ) if not glu else ConvGEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Conv2d(inner_dim, dim_out, 1, 1, 0) + ) + + def forward(self, x): + return self.net(x) + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
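+ (each block applies self-attention, cross-attention to `context` when it is given, and a feed-forward layer)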
+ Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, + disable_self_attn=disable_self_attn) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = self.proj_out(x) + return x + x_in diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..533e589a2024f1d7c52093d8c472c3b1b6617e26 --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,835 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
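    Concretely (for an even embedding_dim d, with h = d // 2), the embedding built below is
        emb(t) = [sin(t * w_0), ..., sin(t * w_{h-1}), cos(t * w_0), ..., cos(t * w_{h-1})],
    with frequencies w_k = 10000^(-k / (h - 1)); an odd embedding_dim is zero-padded by one column.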
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = 
torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + 
if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + 
self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = 
self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, 
+ out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, 
mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..1e0dc94e240f927985d8edbf2f38aa5ac28641e2 --- /dev/null +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,996 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], 
dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. 
+ :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: # False + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
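    Attention here is computed over the flattened spatial positions, so compute and
    memory grow quadratically with H*W (see count_flops_attn below).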
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) # 0 + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: # always True + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + 
num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), 
+ nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # N + emb = self.time_embed(t_emb) # + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) # conv + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. 
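    A minimal instantiation sketch for the full UNetModel above (all hyperparameters
    here are illustrative assumptions, not values taken from this repository's configs):

        >>> model = UNetModel(image_size=32, in_channels=4, model_channels=64,
        ...                   out_channels=4, num_res_blocks=2,
        ...                   attention_resolutions=(4, 2), channel_mult=(1, 2, 4),
        ...                   num_heads=4)
        >>> x = th.randn(2, 4, 32, 32)
        >>> t = th.randint(0, 1000, (2,))
        >>> model(x, timesteps=t).shape   # denoiser output has the same shape as x
        torch.Size([2, 4, 32, 32])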
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + 
) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a952e6c40308c33edd422da0ce6a60f47e73661b --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,267 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
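As a usage sketch of the EncoderUNetModel defined above (hyperparameters, shapes and
the "adaptive" pooling choice are illustrative assumptions, not values used by this
repository):

    import torch as th
    from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel

    # hypothetical half-UNet feature extractor with adaptive average pooling
    enc = EncoderUNetModel(image_size=32, in_channels=3, model_channels=64,
                           out_channels=512, num_res_blocks=2,
                           attention_resolutions=(4,), channel_mult=(1, 2, 4),
                           pool="adaptive")
    x = th.randn(2, 3, 32, 32)      # batch of 2 RGB images
    t = th.randint(0, 1000, (2,))   # one diffusion timestep per sample
    feats = enc(x, t)               # pooled features of shape [2, 512]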
+ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
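    For example, the widely used cosine schedule (Nichol & Dhariwal, 2021) corresponds to
        alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2,
    which is essentially the schedule produced by the "cosine" branch of make_beta_schedule above.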
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. 
+ """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + 
self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c94f4f87a866b174f96aafdf3fcfa50f04e6cbeb --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,550 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial +import kornia + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? 
--> test +from ldm.util import default +import clip + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + +class FaceClipEncoder(AbstractEncoder): + def __init__(self, augment=True, retreival_key=None): + super().__init__() + self.encoder = FrozenCLIPImageEmbedder() + self.augment = augment + self.retreival_key = retreival_key + + def forward(self, img): + encodings = [] + with torch.no_grad(): + x_offset = 125 + if self.retreival_key: + # Assumes retrieved image are packed into the second half of channels + face = img[:,3:,190:440,x_offset:(512-x_offset)] + other = img[:,:3,...].clone() + else: + face = img[:,:,190:440,x_offset:(512-x_offset)] + other = img.clone() + + if self.augment: + face = K.RandomHorizontalFlip()(face) + + other[:,:,190:440,x_offset:(512-x_offset)] *= 0 + encodings = [ + self.encoder.encode(face), + self.encoder.encode(other), + ] + + return torch.cat(encodings, dim=1) + + def encode(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device) + + return self(img) + +class FaceIdClipEncoder(AbstractEncoder): + def __init__(self): + super().__init__() + self.encoder = FrozenCLIPImageEmbedder() + for p in self.encoder.parameters(): + p.requires_grad = False + self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True) + + def forward(self, img): + encodings = [] + with torch.no_grad(): + face = kornia.geometry.resize(img, (256, 256), + interpolation='bilinear', align_corners=True) + + other = img.clone() + other[:,:,184:452,122:396] *= 0 + encodings = [ + self.id.encode(face), + self.encoder.encode(other), + ] + + return torch.cat(encodings, dim=1) + + def encode(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device) + + return self(img) + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') + self.transformer = T5EncoderModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') + self.device = device + self.max_length = max_length # TODO: typical value? 
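+        # freeze() below puts the T5 encoder in eval mode with requires_grad=False, so it is
+        # used as a fixed text feature extractor; max_length=77 mirrors the CLIP-style context
+        # length used elsewhere in this codebase rather than any limit of T5 itself.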
+ self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + +from ldm.thirdp.psp.id_loss import IDFeatures +import kornia.augmentation as K + +class FrozenFaceEncoder(AbstractEncoder): + def __init__(self, model_path, augment=False): + super().__init__() + self.loss_fn = IDFeatures(model_path) + # face encoder is frozen + for p in self.loss_fn.parameters(): + p.requires_grad = False + # Mapper is trainable + self.mapper = torch.nn.Linear(512, 768) + p = 0.25 + if augment: + self.augment = K.AugmentationSequential( + K.RandomHorizontalFlip(p=0.5), + K.RandomEqualize(p=p), + # K.RandomPlanckianJitter(p=p), + # K.RandomPlasmaBrightness(p=p), + # K.RandomPlasmaContrast(p=p), + # K.ColorJiggle(0.02, 0.2, 0.2, p=p), + ) + else: + self.augment = False + + def forward(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros((1, 1, 768), device=self.mapper.weight.device) + + if self.augment is not None: + # Transforms require 0-1 + img = self.augment((img + 1)/2) + img = 2*img - 1 + + feat = self.loss_fn(img, crop=True) + feat = self.mapper(feat.unsqueeze(1)) + return feat + + def encode(self, img): + return self(img) + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32 + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') + self.transformer = CLIPTextModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models') + self.device = device + self.max_length = max_length # TODO: typical value? + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + +import torch.nn.functional as F +from transformers import CLIPVisionModel +class ClipImageProjector(AbstractEncoder): + """ + Uses the CLIP image encoder. + """ + def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32 + super().__init__() + self.model = CLIPVisionModel.from_pretrained(version) + self.model.train() + self.max_length = max_length # TODO: typical value? 
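+        # The CLIP ViT-L/14 vision tower emits 1024-d hidden states; self.mapper projects them
+        # to the 768-d conditioning width, and forward() zero-pads the token axis up to
+        # max_length so the result can be consumed like a text-embedding sequence.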
+ self.antialias = True + self.mapper = torch.nn.Linear(1024, 768) + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + null_cond = self.get_null_cond(version, max_length) + self.register_buffer('null_cond', null_cond) + + @torch.no_grad() + def get_null_cond(self, version, max_length): + device = self.mean.device + embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) + null_cond = embedder([""]) + return null_cond + + def preprocess(self, x): + # Expects inputs in the range -1, 1 + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + if isinstance(x, list): + return self.null_cond + # x is assumed to be in range [-1,1] + x = self.preprocess(x) + outputs = self.model(pixel_values=x) + last_hidden_state = outputs.last_hidden_state + last_hidden_state = self.mapper(last_hidden_state) + return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0]) + + def encode(self, im): + return self(im) + +class ProjectedFrozenCLIPEmbedder(AbstractEncoder): + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32 + super().__init__() + self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) + self.projection = torch.nn.Linear(768, 768) + + def forward(self, text): + z = self.embedder(text) + return self.projection(z) + + def encode(self, text): + return self(text) + +class FrozenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the CLIP image encoder. + Not actually frozen... If you want that set cond_stage_trainable=False in cfg + """ + def __init__( + self, + model='ViT-L/14', + jit=False, + device='cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + # We don't use the text part so delete it + del self.model.transformer + self.antialias = antialias + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # Expects inputs in the range -1, 1 + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + if isinstance(x, list): + # [""] denotes condition dropout for ucg + device = self.model.visual.conv1.weight.device + return torch.zeros(1, 768, device=device) + return self.model.encode_image(self.preprocess(x)).float() + + def encode(self, im): + return self(im).unsqueeze(1) + +from torchvision import transforms +import random + +class FrozenCLIPImageMutliEmbedder(AbstractEncoder): + """ + Uses the CLIP image encoder. + Not actually frozen... 
If you want that set cond_stage_trainable=False in cfg + """ + def __init__( + self, + model='ViT-L/14', + jit=False, + device='cpu', + antialias=True, + max_crops=5, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + # We don't use the text part so delete it + del self.model.transformer + self.antialias = antialias + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.max_crops = max_crops + + def preprocess(self, x): + + # Expects inputs in the range -1, 1 + randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1)) + max_crops = self.max_crops + patches = [] + crops = [randcrop(x) for _ in range(max_crops)] + patches.extend(crops) + x = torch.cat(patches, dim=0) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + if isinstance(x, list): + # [""] denotes condition dropout for ucg + device = self.model.visual.conv1.weight.device + return torch.zeros(1, self.max_crops, 768, device=device) + batch_tokens = [] + for im in x: + patches = self.preprocess(im.unsqueeze(0)) + tokens = self.model.encode_image(patches).float() + for t in tokens: + if random.random() < 0.1: + t *= 0 + batch_tokens.append(tokens.unsqueeze(0)) + + return torch.cat(batch_tokens, dim=0) + + def encode(self, im): + return self(im) + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +from ldm.util import instantiate_from_config +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like + + +class LowScaleEncoder(nn.Module): + def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64, + scale_factor=1.0): + super().__init__() + self.max_noise_level = max_noise_level + self.model = instantiate_from_config(model_config) + self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start, + linear_end=linear_end) + self.out_size = output_size + self.scale_factor = scale_factor + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + z = self.model.encode(x).sample() + z = z * self.scale_factor + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + z = self.q_sample(z, noise_level) + if self.out_size is not None: + z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode + # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1) + return z, noise_level + + def decode(self, z): + z = z / self.scale_factor + return self.model.decode(z) + + +if __name__ == "__main__": + from ldm.util import count_params + sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! 
: 'xx-y/@s'"] + model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda() + count_params(model, True) + z = model(sentences) + print(z.shape) + + model = FrozenCLIPEmbedder().cuda() + count_params(model, True) + z = model(sentences) + print(z.shape) + + print("done.") diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576 --- /dev/null +++ b/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, 
x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. 
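+# Shape sketch for the Attention block below (batch b, sequence length n, heads h,
+# per-head dim dh = dim_head; masks, talking-heads and memory key/values omitted):
+#
+#   q, k, v : (b, n, h*dh) -> rearranged to (b, h, n, dh)
+#   dots    = einsum('b h i d, b h j d -> b h i j', q, k) * dh ** -0.5   # (b, h, n, n)
+#   attn    = softmax(dots, dim=-1)
+#   out     = einsum('b h i j, b h j d -> b h i d', attn, v) -> (b, n, h*dh) -> to_out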
+class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, 
device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = 
default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: 
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/ldm/thirdp/psp/helpers.py b/ldm/thirdp/psp/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..983baaa50ea9df0cbabe09aba80293ddf7709845 --- /dev/null +++ b/ldm/thirdp/psp/helpers.py @@ -0,0 +1,121 @@ +# https://github.com/eladrich/pixel2style2pixel + +from collections import namedtuple +import torch +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut \ No newline at end of file diff --git a/ldm/thirdp/psp/id_loss.py b/ldm/thirdp/psp/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e08ee095bd20ff664dcf470de15ff54f839b38e2 --- /dev/null +++ b/ldm/thirdp/psp/id_loss.py @@ -0,0 +1,23 @@ +# https://github.com/eladrich/pixel2style2pixel +import torch +from torch import nn +from ldm.thirdp.psp.model_irse import Backbone + + +class IDFeatures(nn.Module): + def __init__(self, model_path): + super(IDFeatures, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + + def forward(self, x, crop=False): + # Not sure of the image range here + if crop: + x = torch.nn.functional.interpolate(x, (256, 256), mode="area") + x = x[:, :, 35:223, 32:220] + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats diff --git a/ldm/thirdp/psp/model_irse.py b/ldm/thirdp/psp/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..21cedd2994a6eed5a0afd451b08dd09801fe60c0 --- /dev/null +++ b/ldm/thirdp/psp/model_irse.py @@ -0,0 +1,86 @@ +# https://github.com/eladrich/pixel2style2pixel + +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from 
[TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model \ No newline at end of file diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a24d4d7dd313111da2bbde8546d58ff43e48b92d --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,302 @@ +import importlib + +import torchvision +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + +import os +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image +import torch +import time +import cv2 +import PIL + +def pil_rectangle_crop(im): + width, height = im.size # Get dimensions + + if width <= height: + left = 0 + right = width + top = (height - width)/2 + bottom = (height + width)/2 + else: + + top = 0 + bottom = height + left = (width - height) / 2 + bottom = (width + height) / 2 + + # Crop the center of the image + im = im.crop((left, top, right, bottom)) + return im + +def add_margin(pil_img, color=0, size=256): + width, height = pil_img.size + result = Image.new(pil_img.mode, (size, size), color) + result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) + return result + + +def create_carvekit_interface(): + from carvekit.api.high import 
HiInterface + # Check doc strings for more information + interface = HiInterface(object_type="object", # Can be "object" or "hairs-like". + batch_size_seg=5, + batch_size_matting=1, + device='cuda' if torch.cuda.is_available() else 'cpu', + seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net + matting_mask_size=2048, + trimap_prob_threshold=231, + trimap_dilation=30, + trimap_erosion_iters=5, + fp16=False) + + return interface + + +def load_and_preprocess(interface, input_im): + ''' + :param input_im (PIL Image). + :return image (H, W, 3) array in [0, 1]. + ''' + # See https://github.com/Ir1d/image-background-remove-tool + image = input_im.convert('RGB') + + image_without_background = interface([image])[0] + image_without_background = np.array(image_without_background) + est_seg = image_without_background > 127 + image = np.array(image) + foreground = est_seg[:, : , -1].astype(np.bool_) + image[~foreground] = [255., 255., 255.] + x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8)) + image = image[y:y+h, x:x+w, :] + image = PIL.Image.fromarray(np.array(image)) + + # resize image such that long edge is 512 + image.thumbnail([200, 200], Image.LANCZOS) + image = add_margin(image, (255, 255, 255), size=256) + image = np.array(image) + + return image + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. 
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
+                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
+                 ema_power=1., param_names=()):
+        """AdamW that saves EMA versions of the parameters."""
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= ema_decay <= 1.0:
+            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+                        ema_power=ema_power, param_names=param_names)
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            ema_params_with_grad = []
+            state_sums = []
+            max_exp_avg_sqs = []
+            state_steps = []
+            amsgrad = group['amsgrad']
+            beta1, beta2 = group['betas']
+            ema_decay = group['ema_decay']
+            ema_power = group['ema_power']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
+                    raise RuntimeError('AdamW does not support sparse gradients')
+                grads.append(p.grad)
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of parameter values
+                    state['param_exp_avg'] = p.detach().float().clone()
+
+                exp_avgs.append(state['exp_avg'])
+                exp_avg_sqs.append(state['exp_avg_sq'])
+                ema_params_with_grad.append(state['param_exp_avg'])
+
+                if amsgrad:
+                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+                # update the steps for each param group update
+                state['step'] += 1
+                # record the step after step update
+                state_steps.append(state['step'])
+
+            optim._functional.adamw(params_with_grad,
+                                    grads,
+                                    exp_avgs,
+                                    exp_avg_sqs,
+                                    max_exp_avg_sqs,
+                                    state_steps,
+                                    amsgrad=amsgrad,
+                                    beta1=beta1,
+                                    beta2=beta2,
+                                    lr=group['lr'],
+                                    weight_decay=group['weight_decay'],
+                                    eps=group['eps'],
+                                    maximize=False)
+
+            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+        return loss
+
+def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
+    image_input = Image.open(image_path)
+
+    if crop_size != -1:
+        # Crop to the alpha-channel bounding box, rescale so the long edge equals
+        # `crop_size`, then pad to a square `image_size` canvas.
+        alpha_np = np.asarray(image_input)[:, :, 3]
+        coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
+        min_x, min_y = np.min(coords, 0)
+        max_x, max_y = np.max(coords, 0)
+        ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
+        h, w = ref_img_.height, ref_img_.width
+        scale = crop_size / max(h, w)
+        h_, w_ = int(scale * h), int(scale * w)
+        ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
+        image_input = add_margin(ref_img_, size=image_size)
+    else:
+        image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
+        image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
+
+    image_input = np.asarray(image_input)
+    image_input = image_input.astype(np.float32) / 255.0
+    ref_mask = image_input[:, :, 3:]
+    image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask  # white background
+    image_input = image_input[:, :, :3] * 2.0 - 1.0  # scale to [-1, 1]
+    image_input = torch.from_numpy(image_input.astype(np.float32))
+    elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
+    return {"input_image": image_input, "input_elevation": elevation_input}
\ No newline at end of file
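# A minimal sketch of calling `prepare_inputs` above; 'example_rgba.png' is a stand-in
# path for any RGBA image (the alpha channel is required, since it is used both for the
# crop and for compositing onto the white background).
if __name__ == "__main__":
    data = prepare_inputs('example_rgba.png', elevation_input=30, crop_size=200)
    print(data["input_image"].shape)  # torch.Size([256, 256, 3]), values in [-1, 1]
    print(data["input_elevation"])    # tensor([0.5236]) -- 30 degrees in radians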
diff --git a/meta_info/camera-16.pkl b/meta_info/camera-16.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..21aaaffb8a44edff95b1c0d0a1216911341ad772
--- /dev/null
+++ b/meta_info/camera-16.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d46537ff53982fd57a7b987673cd759e3a892f11b9f7ee44566f5612d6da6357
+size 2142
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..38e2a2af2e5b61daf7476aa520867dec8dd0901d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch
+pytorch_lightning==1.9.0
+Pillow==10.0.0
+opencv-python
+transformers
+taming-transformers-rom1504
+tqdm
+numpy
+kornia
+webdataset
+omegaconf
+einops
+scikit-image
+pymcubes
+carvekit-colab
+open3d
+trimesh
+easydict
+nerfacc
+imageio-ffmpeg==0.4.7
+fire
+segment_anything
+git+https://github.com/openai/CLIP.git
\ No newline at end of file
diff --git a/sam_utils.py b/sam_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..94095e37216f946f5a832f038f8e6da92806b7c6
--- /dev/null
+++ b/sam_utils.py
@@ -0,0 +1,50 @@
+import os
+import numpy as np
+import torch
+from PIL import Image
+import time
+
+from segment_anything import sam_model_registry, SamPredictor
+
+def sam_init(device_id=0):
+    sam_checkpoint = os.path.join(os.path.dirname(__file__), "ckpt/sam_vit_h_4b8939.pth")
+    model_type = "vit_h"
+
+    device = "cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu"
+
+    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
+    predictor = SamPredictor(sam)
+    return predictor
+
+def sam_out_nosave(predictor, input_image, bbox):
+    bbox = np.array(bbox)
+    image = np.asarray(input_image)
+
+    start_time = time.time()
+    predictor.set_image(image)
+
+    h, w, _ = image.shape
+    input_point = np.array([[w // 2, h // 2]])  # image-center prompt; point_coords are (x, y)
+    input_label = np.array([1])
+
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        multimask_output=True,
+    )
+
+    masks_bbox, scores_bbox, logits_bbox = predictor.predict(
+        box=bbox,
+        multimask_output=True
+    )
+
+    print(f"SAM Time: {time.time() - start_time:.3f}s")
+    opt_idx = np.argmax(scores)
+    mask = masks[opt_idx]
+    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
+    out_image[:, :, :3] = image
+    out_image_bbox = out_image.copy()
+    out_image[:, :, 3] = mask.astype(np.uint8) * 255
+    out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255  # np.argmax(scores_bbox)
+    torch.cuda.empty_cache()
+    return Image.fromarray(out_image_bbox, mode='RGBA')
\ No newline at end of file
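# A minimal sketch of how the SAM helpers above might be used, assuming the checkpoint has
# been downloaded to ckpt/sam_vit_h_4b8939.pth; 'example.png' and the box coordinates are
# stand-ins. `bbox` is (x_min, y_min, x_max, y_max) in pixel coordinates.
if __name__ == "__main__":
    from PIL import Image

    predictor = sam_init(device_id=0)
    image = Image.open('example.png').convert('RGB')
    rgba = sam_out_nosave(predictor, image, bbox=[64, 64, 192, 192])
    rgba.save('example_sam_rgba.png')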