import json
import random

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers.schedulers import DDIMScheduler
from peft import LoraConfig, get_peft_model
from pycocotools import mask as coco_mask
from torchvision.transforms import InterpolationMode, Resize
from torchvision.transforms.functional import pil_to_tensor

from depth_anything_v2.dpt import DepthAnythingV2
from marigold.duplicate_unet import DoubleUNet2DConditionModel
from marigold.marigold_inpaint_pipeline import MarigoldInpaintPipeline

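# Inference script: joint RGB-D inpainting with a two-branch (duplicated) SD2
# UNet. Steps: build the pipeline and LoRA-wrapped UNet, load the fine-tuned
# checkpoint, estimate depth with DepthAnythingV2, assemble inpainting masks
# from SA-1B annotations, then run the pipeline's _rgbd_inpaint.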
model = MarigoldInpaintPipeline.from_pretrained('stabilityai/stable-diffusion-2')
unet_config_path = '/home/aiops/wangzh/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2/snapshots/1e128c8891e52218b74cde8f26dbfc701cb99d79/unet/config.json'
# Swap in the double UNet, then duplicate it and widen the conv_in layers of
# both the RGB and depth branches for inpainting (13 input channels).
model.unet = DoubleUNet2DConditionModel(**json.load(open(unet_config_path)))

model.unet.config["in_channels"] = 13
model.unet.duplicate_model()
model.unet.inpaint_rgb_conv_in()
model.unet.inpaint_depth_conv_in()

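# Add rank-128 LoRA adapters on the attention projections; the fine-tuned
# checkpoint loaded below contains weights for these adapters.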
unet_lora_config = LoraConfig(
    r=128,
    lora_alpha=128,
    init_lora_weights="gaussian",
    target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
)
model.unet = get_peft_model(model.unet, unet_lora_config)

sd2inpaint_ckpt = torch.load(
    '/home/aiops/wangzh/marigold/output/512-inpaint-0.5-128-vitl-partition/checkpoint/latest/pytorch_model.bin',
    map_location='cpu')
model.unet.load_state_dict(sd2inpaint_ckpt)
model.to('cuda')

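# DepthAnythingV2 encoder presets (matching the official repository).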
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}

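# Use DDIM for both the RGB and depth denoising branches.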
model.rgb_scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
model.depth_scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")

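# Monocular depth estimator that produces the depth maps fed to the pipeline.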
depth_model = DepthAnythingV2(**model_configs['vitl'])
depth_model.load_state_dict(
    torch.load('/home/aiops/wangzh/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth', map_location='cpu'))
depth_model = depth_model.to('cuda').eval()

image_path = ['/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg',
              '/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg',
              '/dataset/~sa-1b/data/sa_000045/sa_457934.jpg']

prompt = ['A white car is parked in front of the factory',
          'church with cemetery next to it',
          'A house with a red brick roof']

imgs = [pil_to_tensor(Image.open(p)) for p in image_path]
# DepthAnythingV2.infer_image expects a raw BGR HxWx3 uint8 image (as returned
# by cv2.imread) and returns an HxW numpy depth map; wrap it in a tensor so
# the torchvision resize below applies.
depth_imgs = [torch.from_numpy(depth_model.infer_image(cv2.imread(p))) for p in image_path]

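# Build one inpainting mask per image by unioning a random subset of 5-10
# SA-1B instance masks (RLE-decoded from the image's sibling .json file).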
masks = []
for rgb_path in image_path:
    anno = json.load(open(rgb_path.replace('.jpg', '.json')))['annotations']
    random.shuffle(anno)
    object_num = random.randint(5, 10)
    mask = np.zeros_like(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)
    for single_anno in anno[:object_num]:
        mask |= coco_mask.decode(single_anno['segmentation'])
    # (1, H, W) tensor; nonzero marks the region to inpaint.
    masks.append((torch.tensor(mask) * 3).unsqueeze(0))

    # Alternative: a fixed rectangular mask for debugging.
    # mask = torch.zeros((512, 512))
    # mask[100:300, 200:400] = 1
    # masks.append(mask)

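# Resize everything to 512x512; NEAREST_EXACT avoids blending mask labels and
# depth values across neighboring pixels.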
resize_transform = Resize(size=[512, 512], interpolation=InterpolationMode.NEAREST_EXACT)
imgs = [resize_transform(img) for img in imgs]
depth_imgs = [resize_transform(depth_img.unsqueeze(0)) for depth_img in depth_imgs]
masks = [resize_transform(mask.unsqueeze(0)) for mask in masks]

# Run joint RGB-D inpainting; other supported modes are
# 'full_rgb_depth_inpaint' and 'full_depth_rgb_inpaint'.
for i in range(len(imgs)):
    output_image = model._rgbd_inpaint(
        imgs[i], depth_imgs[i].unsqueeze(0), masks[i], [prompt[i]],
        processing_res=512, guidance_scale=3, mode='joint_inpaint',
    )
    output_image.save(f'./joint-{i}.jpg')