alvan committed on
Commit 0f0e0b1
1 Parent(s): 2706768

Added gradio app

Files changed (5)
  1. app.py +151 -0
  2. cool_models.py +132 -0
  3. requirements.txt +13 -0
  4. run_edit.py +288 -0
  5. weights/rd64-uni.pth +3 -0
app.py ADDED
@@ -0,0 +1,151 @@
+ from __future__ import annotations
+
+ import math
+ import random
+
+ import gradio as gr
+ import torch
+ from PIL import Image, ImageOps
+ from run_edit import run_model
+ from cool_models import make_models
+
+ help_text = """"""
+
+ example_instructions = [
+     "Make it a picasso painting",
+     "as if it were by modigliani",
+     "convert to a bronze statue",
+     "Turn it into an anime.",
+     "have it look like a graphic novel",
+     "make him gain weight",
+     "what would he look like bald?",
+     "Have him smile",
+     "Put him in a cocktail party.",
+     "move him at the beach.",
+     "add dramatic lighting",
+     "Convert to black and white",
+     "What if it were snowing?",
+     "Give him a leather jacket",
+     "Turn him into a cyborg!",
+     "make him wear a beanie",
+ ]
+
+ model_id = "timbrooks/instruct-pix2pix"
+
+ def main():
+     # pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cuda")
+     segmodel, model, diffusion, ldm, bert, clip_model, model_params = make_models()
+
+     def generate(
+         input_image: Image.Image,
+         from_text: str,
+         instruction: str,
+         negative_prompt: str,
+         randomize_seed: bool,
+         seed: int,
+         guidance_scale: float,
+         clip_guidance_scale: float,
+         cutn: int,
+         l2_sim_lambda: float
+     ):
+         seed = random.randint(0, 100000) if randomize_seed else seed
+
+         if instruction == "":
+             return [seed, input_image]
+
+         generator = torch.manual_seed(seed)
+
+         edited_image_1 = run_model(
+             segmodel, model, diffusion, ldm, bert, clip_model, model_params,
+             from_text, instruction, negative_prompt, input_image.convert('RGB'), seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
+         )
+
+         # edited_image = input_image
+         return [seed, edited_image_1]
+
+     def reset():
+         return [
+             "Randomize Seed", 1371, None, 5.0,
+             150, 16, 10000
+         ]
+
+     with gr.Blocks() as demo:
+         gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">
+             RDM: Region-Aware Diffusion for Zero-shot Text-driven Image Editing
+         </h1>
+         <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
+         <br/>
+         <a href="https://huggingface.co/spaces/timbrooks/instruct-pix2pix?duplicate=true">
+         <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+         <p/>""")
+         with gr.Row():
+             with gr.Column(scale=1, min_width=100):
+                 generate_button = gr.Button("Generate")
+             # with gr.Column(scale=1, min_width=100):
+             #     load_button = gr.Button("Load Example")
+             with gr.Column(scale=1, min_width=100):
+                 reset_button = gr.Button("Reset")
+             with gr.Column(scale=3):
+                 from_text = gr.Textbox(lines=1, label="From Text", interactive=True)
+                 instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
+                 negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", interactive=True)
+
+         with gr.Row():
+             input_image = gr.Image(label="Input Image", type="pil", interactive=True)
+             edited_image_1 = gr.Image(label=f"Edited Image", type="pil", interactive=False)
+             # edited_image_2 = gr.Image(label=f"Edited Image", type="pil", interactive=False)
+             input_image.style(height=512, width=512)
+             edited_image_1.style(height=512, width=512)
+             # edited_image_2.style(height=512, width=512)
+
+         with gr.Row():
+             # steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
+             seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True)
+             guidance_scale = gr.Number(value=5.0, precision=1, label="Guidance Scale", interactive=True)
+             clip_guidance_scale = gr.Number(value=150, precision=1, label="Clip Guidance Scale", interactive=True)
+             cutn = gr.Number(value=16, precision=1, label="Number of Cuts", interactive=True)
+             l2_sim_lambda = gr.Number(value=10000, precision=1, label="L2 similarity to original image")
+
+         randomize_seed = gr.Radio(
+             ["Fix Seed", "Randomize Seed"],
+             value="Randomize Seed",
+             type="index",
+             show_label=False,
+             interactive=True,
+         )
+         # use_ddim = gr.Checkbox(label="Use 50-step DDIM?", value=True)
+         # use_ddpm = gr.Checkbox(label="Use 50-step DDPM?", value=True)
+
+         gr.Markdown(help_text)
+
+         generate_button.click(
+             fn=generate,
+             inputs=[
+                 input_image,
+                 from_text,
+                 instruction,
+                 negative_prompt,
+                 randomize_seed,
+                 seed,
+                 guidance_scale,
+                 clip_guidance_scale,
+                 cutn,
+                 l2_sim_lambda
+             ],
+             outputs=[seed, edited_image_1],
+         )
+         reset_button.click(
+             fn=reset,
+             inputs=[],
+             outputs=[
+                 randomize_seed, seed, edited_image_1, guidance_scale,
+                 clip_guidance_scale, cutn, l2_sim_lambda
+             ],
+         )
+
+     demo.queue(concurrency_count=1)
+     demo.launch(share=False, server_name="0.0.0.0")
+
+
+ if __name__ == "__main__":
+     main()
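
For reference, the Space's edit path can also be driven without the Gradio UI by wiring make_models() into run_model() directly, exactly as generate() does above. A minimal sketch (not part of this commit; the image path and prompt strings are placeholders), using the default values exposed by the UI:

    # Hypothetical standalone driver mirroring the call inside generate().
    from PIL import Image

    from cool_models import make_models
    from run_edit import run_model

    segmodel, model, diffusion, ldm, bert, clip_model, model_params = make_models()

    edited = run_model(
        segmodel, model, diffusion, ldm, bert, clip_model, model_params,
        from_text="a dog",               # placeholder: CLIPSeg query selecting the region to edit
        instruction="make it a husky",   # placeholder: edit prompt
        negative_prompt="",
        original_img=Image.open("input.jpg").convert("RGB"),  # placeholder path
        seed=1371,
        guidance_scale=5.0,
        clip_guidance_scale=150,
        cutn=16,
        l2_sim_lambda=10000,
    )
    edited.save("edited.png")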
cool_models.py ADDED
@@ -0,0 +1,132 @@
+ import torch
+ from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
+ import lpips
+ import clip
+
+
+ from encoders.modules import BERTEmbedder
+ from models.clipseg import CLIPDensePredT
+
+ from huggingface_hub import hf_hub_download
+
+ STEPS = 100
+ USE_DDPM = False
+ USE_DDIM = False
+ USE_CPU = False
+ BERT_PATH = "./weights/bert.pt"
+ KL_PATH = "./weights/kl-f8.pt"
+ INPAINT_PATH = "./weights/inpaint.pt"
+ CLIP_SEG_PATH = './weights/rd64-uni.pth'
+ CLIP_GUIDANCE = False
+
+ def make_models():
+     segmodel = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
+     segmodel.eval()
+
+     # non-strict, because we only stored decoder weights (not CLIP weights)
+     segmodel.load_state_dict(torch.load(CLIP_SEG_PATH, map_location=torch.device('cpu')), strict=False)
+     # segmodel.save_pretrained("./weights/hf_clipseg")
+
+     device = torch.device('cuda:0' if (torch.cuda.is_available() and not USE_CPU) else 'cpu')
+     print('Using device:', device)
+
+     hf_inpaint_path = hf_hub_download("alvanlii/rdm_inpaint", "inpaint.pt")
+     model_state_dict = torch.load(hf_inpaint_path, map_location='cpu')
+
+     # print(
+     #     'hey',
+     #     'clip_proj.weight' in model_state_dict, # True
+     #     model_state_dict['input_blocks.0.0.weight'].shape[1] == 8, # True
+     #     'external_block.0.0.weight' in model_state_dict # False
+     # )
+
+     model_params = {
+         'attention_resolutions': '32,16,8',
+         'class_cond': False,
+         'diffusion_steps': 1000,
+         'rescale_timesteps': True,
+         'timestep_respacing': STEPS,  # Modify this value to decrease the number of
+                                       # timesteps.
+         'image_size': 32,
+         'learn_sigma': False,
+         'noise_schedule': 'linear',
+         'num_channels': 320,
+         'num_heads': 8,
+         'num_res_blocks': 2,
+         'resblock_updown': False,
+         'use_fp16': False,
+         'use_scale_shift_norm': False,
+         'clip_embed_dim': 768,
+         'image_condition': True,
+         'super_res_condition': False,
+     }
+
+     if USE_DDPM:
+         model_params['timestep_respacing'] = '1000'
+     if USE_DDIM:
+         if STEPS:
+             model_params['timestep_respacing'] = 'ddim'+str(STEPS)
+         else:
+             model_params['timestep_respacing'] = 'ddim50'
+     elif STEPS:
+         model_params['timestep_respacing'] = str(STEPS)
+
+     model_config = model_and_diffusion_defaults()
+     model_config.update(model_params)
+
+     if USE_CPU:
+         model_config['use_fp16'] = False
+
+
+     model, diffusion = create_model_and_diffusion(**model_config)
+
+     # model.from_pretrained("alvanlii/rdm_inpaint")
+     model.load_state_dict(model_state_dict, strict=False)
+     # model.save_pretrained("./weights/hf_inpaint")
+
+     model.requires_grad_(CLIP_GUIDANCE).eval().to(device)
+
+     if model_config['use_fp16']:
+         model.convert_to_fp16()
+     else:
+         model.convert_to_fp32()
+
+     def set_requires_grad(model, value):
+         for param in model.parameters():
+             param.requires_grad = value
+
+
+     lpips_model = lpips.LPIPS(net="vgg").to(device)
+     hf_kl_path = hf_hub_download("alvanlii/rdm_kl", "kl-f8.pt")
+
+     # kl_model_url = hf_hub_url("alvanlii/rdm_kl", "kl-f8.pt")
+     # kl_cache_path = cached_download(kl_model_url, cache_dir=".")
+
+     ldm = torch.load(hf_kl_path, map_location="cpu")
+
+     # torch.save(ldm, "./weights/hf_ldm")
+     ldm.to(device)
+     ldm.eval()
+     ldm.requires_grad_(CLIP_GUIDANCE)
+     set_requires_grad(ldm, CLIP_GUIDANCE)
+
+     bert = BERTEmbedder(1280, 32)
+     hf_bert_path = hf_hub_download("alvanlii/rdm_bert", 'bert.pt')
+     # bert = BERTEmbedder.from_pretrained("alvanlii/rdm_bert")
+     sd = torch.load(hf_bert_path, map_location="cpu")
+     bert.load_state_dict(sd)
+     # bert.save_pretrained("./weights/hf_bert")
+
+     bert.to(device)
+     bert.half().eval()
+     set_requires_grad(bert, False)
+
+
+     clip_model, clip_preprocess = clip.load('ViT-L/14', device=device, jit=False)
+     clip_model.eval().requires_grad_(False)
+
+     return segmodel, model, diffusion, ldm, bert, clip_model, model_params
+
+
+ if __name__ == "__main__":
+     make_models()
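
The USE_DDPM / USE_DDIM / STEPS flags above only affect the timestep_respacing string handed to guided_diffusion. A small illustration of how the branches resolve (a sketch for clarity, not code from the repo; it assumes STEPS is non-zero and at most one flag is set, as in this file):

    # How make_models() picks the respacing string (here STEPS = 100, both flags False).
    def resolve_respacing(steps: int, use_ddpm: bool, use_ddim: bool) -> str:
        if use_ddpm:
            return '1000'              # full 1000-step DDPM chain
        if use_ddim:
            return f'ddim{steps}'      # e.g. 'ddim100'
        return str(steps)              # plain respacing, e.g. '100' (the path taken here)

    assert resolve_respacing(100, False, False) == '100'
    assert resolve_respacing(100, False, True) == 'ddim100'
    assert resolve_respacing(100, True, False) == '1000'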
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ einops==0.6.0
+ lpips
+ gradio
+ opencv-python
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torch
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torchvision
+ transformers
+ pytorch-lightning
+ git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+ git+https://github.com/openai/CLIP.git@main#egg=clip
+ git+https://github.com/alvanli/RDM-Region-Aware-Diffusion-Model.git@main#egg=guided_diffusion
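
torch and torchvision are resolved through the cu116 extra index, so a GPU install should end up with CUDA 11.6 wheels. A quick post-install sanity check (standard torch attributes only, nothing project-specific):

    import torch

    print(torch.__version__)          # cu116 wheels report a version suffixed with '+cu116'
    print(torch.version.cuda)         # expected to be '11.6' for those wheels
    print(torch.cuda.is_available())  # False here falls back to the CPU path in cool_models.py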
run_edit.py ADDED
@@ -0,0 +1,288 @@
+ import gc
+ import os
+ import io
+ import math
+ import sys
+ import tempfile
+
+ from PIL import Image, ImageOps
+ import requests
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from tqdm.notebook import tqdm
+
+ import numpy as np
+
+ from math import log2, sqrt
+
+ import argparse
+ import pickle
+
+
+
+
+ ################################### mask_fusion ######################################
+ from util.metrics_accumulator import MetricsAccumulator
+ metrics_accumulator = MetricsAccumulator()
+
+ from pathlib import Path
+ from PIL import Image
+ ################################### mask_fusion ######################################
+
+ import clip
+ import lpips
+ from torch.nn.functional import mse_loss
+
+ ################################### CLIPseg ######################################
+ from torchvision import utils as vutils
+ import cv2
+
+ ################################### CLIPseg ######################################
+
+ def str2bool(x):
+     return x.lower() in ('true')
+
+ USE_CPU = False
+ device = torch.device('cuda:0' if (torch.cuda.is_available() and not USE_CPU) else 'cpu')
+
+
+ def fetch(url_or_path):
+     if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+         r = requests.get(url_or_path)
+         r.raise_for_status()
+         fd = io.BytesIO()
+         fd.write(r.content)
+         fd.seek(0)
+         return fd
+     return open(url_or_path, 'rb')
+
+
+ class MakeCutouts(nn.Module):
+     def __init__(self, cut_size, cutn, cut_pow=1.):
+         super().__init__()
+
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.cut_pow = cut_pow
+
+     def forward(self, input):
+         sideY, sideX = input.shape[2:4]
+         max_size = min(sideX, sideY)
+         min_size = min(sideX, sideY, self.cut_size)
+         cutouts = []
+         for _ in range(self.cutn):
+             size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+             offsetx = torch.randint(0, sideX - size + 1, ())
+             offsety = torch.randint(0, sideY - size + 1, ())
+             cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+             cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
+         return torch.cat(cutouts)
+
+ def spherical_dist_loss(x, y):
+     x = F.normalize(x, dim=-1)
+     y = F.normalize(y, dim=-1)
+     return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+ def do_run(
+     arg_seed, arg_text, arg_batch_size, arg_num_batches, arg_negative, arg_cutn, arg_edit, arg_height, arg_width,
+     arg_edit_y, arg_edit_x, arg_edit_width, arg_edit_height, mask, arg_guidance_scale, arg_background_preservation_loss,
+     arg_lpips_sim_lambda, arg_l2_sim_lambda, arg_ddpm, arg_ddim, arg_enforce_background, arg_clip_guidance_scale,
+     arg_clip_guidance, model_params, model, diffusion, ldm, bert, clip_model
+ ):
+     normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
+
+     if arg_seed >= 0:
+         torch.manual_seed(arg_seed)
+
+     text_emb = bert.encode([arg_text] * arg_batch_size).to(device).float()
+     text_blank = bert.encode([arg_negative] * arg_batch_size).to(device).float()
+
+     text = clip.tokenize([arg_text] * arg_batch_size, truncate=True).to(device)
+     text_clip_blank = clip.tokenize([arg_negative] * arg_batch_size, truncate=True).to(device)
+
+
+
+     text_emb_clip = clip_model.encode_text(text)
+     text_emb_clip_blank = clip_model.encode_text(text_clip_blank)
+     make_cutouts = MakeCutouts(clip_model.visual.input_resolution, arg_cutn)
+     text_emb_norm = text_emb_clip[0] / text_emb_clip[0].norm(dim=-1, keepdim=True)
+     image_embed = None
+
+     if arg_edit:
+         w = arg_edit_width if arg_edit_width else arg_width
+         h = arg_edit_height if arg_edit_height else arg_height
+
+         arg_edit = arg_edit.convert('RGB')
+         input_image_pil = arg_edit
+
+         init_image_pil = input_image_pil.resize((arg_height, arg_width), Image.Resampling.LANCZOS)
+
+         input_image_pil = ImageOps.fit(input_image_pil, (w, h))
+
+         im = transforms.ToTensor()(input_image_pil).unsqueeze(0).to(device)
+
+         init_image = (TF.to_tensor(init_image_pil).to(device).unsqueeze(0).mul(2).sub(1))
+
+         im = 2*im-1
+         im = ldm.encode(im).sample()
+
+         y = arg_edit_y//8
+         x = arg_edit_x//8
+
+         input_image = torch.zeros(1, 4, arg_height//8, arg_width//8, device=device)
+
+         ycrop = y + im.shape[2] - input_image.shape[2]
+         xcrop = x + im.shape[3] - input_image.shape[3]
+
+         ycrop = ycrop if ycrop > 0 else 0
+         xcrop = xcrop if xcrop > 0 else 0
+
+         input_image[0,:,y if y >=0 else 0:y+im.shape[2],x if x >=0 else 0:x+im.shape[3]] = im[:,:,0 if y > 0 else -y:im.shape[2]-ycrop,0 if x > 0 else -x:im.shape[3]-xcrop]
+
+         input_image_pil = ldm.decode(input_image)
+         input_image_pil = TF.to_pil_image(input_image_pil.squeeze(0).add(1).div(2).clamp(0, 1))
+
+         input_image *= 0.18215
+
+         new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_width//8, arg_height//8))
+
+         mask1 = (new_mask > 0.5)
+         mask1 = mask1.float()
+
+         input_image *= mask1
+
+         image_embed = torch.cat(arg_batch_size*2*[input_image], dim=0).float()
+     elif model_params['image_condition']:
+         # using inpaint model but no image is provided
+         image_embed = torch.zeros(arg_batch_size*2, 4, arg_height//8, arg_width//8, device=device)
+
+     kwargs = {
+         "context": torch.cat([text_emb, text_blank], dim=0).float(),
+         "clip_embed": torch.cat([text_emb_clip, text_emb_clip_blank], dim=0).float() if model_params['clip_embed_dim'] else None,
+         "image_embed": image_embed
+     }
+
+     # Create a classifier-free guidance sampling function
+     def model_fn(x_t, ts, **kwargs):
+         half = x_t[: len(x_t) // 2]
+         combined = torch.cat([half, half], dim=0)
+         model_out = model(combined, ts, **kwargs)
+         eps, rest = model_out[:, :3], model_out[:, 3:]
+         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+         half_eps = uncond_eps + arg_guidance_scale * (cond_eps - uncond_eps)
+         eps = torch.cat([half_eps, half_eps], dim=0)
+         return torch.cat([eps, rest], dim=1)
+
+     cur_t = None
+
+     @torch.no_grad()
+     def postprocess_fn(out, t):
+         if mask is not None:
+             background_stage_t = diffusion.q_sample(init_image, t[0])
+             background_stage_t = torch.tile(
+                 background_stage_t, dims=(arg_batch_size, 1, 1, 1)
+             )
+             out["sample"] = out["sample"] * mask + background_stage_t * (1 - mask)
+         return out
+
+     # if arg_ddpm:
+     #     sample_fn = diffusion.p_sample_loop_progressive
+     # elif arg_ddim:
+     #     sample_fn = diffusion.ddim_sample_loop_progressive
+     # else:
+     sample_fn = diffusion.plms_sample_loop_progressive
+
+     def save_sample(i, sample):
+         out_ims = []
+         for k, image in enumerate(sample['pred_xstart'][:arg_batch_size]):
+             image /= 0.18215
+             im = image.unsqueeze(0)
+             out = ldm.decode(im)
+         metrics_accumulator.print_average_metric()
+
+         for b in range(arg_batch_size):
+             pred_image = sample["pred_xstart"][b]
+
+             if arg_enforce_background:
+                 new_mask = TF.resize(mask.unsqueeze(0).unsqueeze(0).to(device), (arg_width, arg_height))
+                 pred_image = (
+                     init_image[0] * new_mask[0] + out * (1 - new_mask[0])
+                 )
+
+             pred_image_pil = TF.to_pil_image(pred_image.squeeze(0).add(1).div(2).clamp(0, 1))
+             out_ims.append(pred_image_pil)
+         return out_ims
+
+
+     all_saved_ims = []
+     for i in range(arg_num_batches):
+         cur_t = diffusion.num_timesteps - 1
+
+         samples = sample_fn(
+             model_fn,
+             (arg_batch_size*2, 4, int(arg_height//8), int(arg_width//8)),
+             clip_denoised=False,
+             model_kwargs=kwargs,
+             cond_fn=None,
+             device=device,
+             progress=True,
+         )
+
+         for j, sample in enumerate(samples):
+             cur_t -= 1
+             if j % 5 == 0 and j != diffusion.num_timesteps - 1:
+                 all_saved_ims += save_sample(i, sample)
+         all_saved_ims += save_sample(i, sample)
+
+     return all_saved_ims
+
+ def run_model(
+     segmodel, model, diffusion, ldm, bert, clip_model, model_params,
+     from_text, instruction, negative_prompt, original_img, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
+ ):
+     input_image = original_img
+
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         transforms.Resize((256, 256)),
+     ])
+     img = transform(input_image).unsqueeze(0)
+
+     with torch.no_grad():
+         preds = segmodel(img.repeat(1,1,1,1), from_text)[0]
+
+     mask = torch.sigmoid(preds[0][0])
+     image = (mask.detach().cpu().numpy() * 255).astype(np.uint8)  # cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+     ret, thresh = cv2.threshold(image, 100, 255, cv2.THRESH_TRUNC, image)
+     timg = np.array(thresh)
+     x, y = timg.shape
+     for row in range(x):
+         for col in range(y):
+             if (timg[row][col]) == 100:
+                 timg[row][col] = 255
+             if (timg[row][col]) < 100:
+                 timg[row][col] = 0
+
+     fulltensor = torch.full_like(mask, fill_value=255)
+     bgtensor = fulltensor-timg
+     mask = bgtensor / 255.0
+
+     gc.collect()
+     use_ddim = False
+     use_ddpm = False
+     all_saved_ims = do_run(
+         seed, instruction, 1, 1, negative_prompt, cutn, input_image, 256, 256,
+         0, 0, 0, 0, mask, guidance_scale, True,
+         1000, l2_sim_lambda, use_ddpm, use_ddim, True, clip_guidance_scale, False,
+         model_params, model, diffusion, ldm, bert, clip_model
+     )
+
+     return all_saved_ims[-1]
+
+
+
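
The nested pixel loop in run_model() binarizes the CLIPSeg response at a threshold of 100/255 and then inverts it, so the mask passed to do_run() is 1 over the background to preserve and 0 over the region selected by from_text. The same mask can be computed with tensor ops (an equivalent sketch, not part of this commit):

    # Equivalent to the cv2.threshold + nested-loop block above.
    seg = torch.sigmoid(preds[0][0])            # same CLIPSeg response as `mask` above
    edit_region = (seg * 255 >= 100).float()    # 1 on the region matching `from_text`
    mask = 1.0 - edit_region                    # 1 on the background that do_run() preserves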
weights/rd64-uni.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:13845f6cee4d54ca46f62ee19dd354822094a26e0efccc64e606be93d6a7e26f
+ size 4306645