Wryley1234 committed
Commit ec70094 • 1 parent: c956a18
Coding key
Unknown ADDED
@@ -0,0 +1,227 @@
from torchvision.datasets.utils import download_url
from ldm.util import instantiate_from_config
import torch
import os
# todo ?
from google.colab import files
from IPython.display import display, Image as ipyimg
import ipywidgets as widgets
from PIL import Image
from numpy import asarray
from einops import rearrange, repeat
import torchvision
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import time
from omegaconf import OmegaConf
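
# Note: torchvision's download_url() treats its second argument as a directory
# and names the saved file after the URL's last path component ("?dl=1"), which
# is why download_models() below appends '/?dl=1' to the paths it returns.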


def download_models(mode):
    if mode == "superresolution":
        # this is the small BSR light model
        url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
        url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'

        path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml'
        path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt'

        download_url(url_conf, path_conf)
        download_url(url_ckpt, path_ckpt)

        # point at the files download_url actually wrote (see note above)
        path_conf = path_conf + '/?dl=1'
        path_ckpt = path_ckpt + '/?dl=1'
        return path_conf, path_ckpt

    else:
        raise NotImplementedError(f"no pretrained model available for mode '{mode}'")
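

# Restore a LatentDiffusion model from an OmegaConf config and a PyTorch
# Lightning checkpoint; strict=False tolerates missing/unexpected keys.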
def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd["global_step"]
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)  # m: missing keys, u: unexpected keys
    model.cuda()
    model.eval()
    return {"model": model}, global_step


def get_model(mode):
    path_conf, path_ckpt = download_models(mode)
    config = OmegaConf.load(path_conf)
    model, step = load_model_from_config(config, path_ckpt)
    return model
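

# Minimal usage sketch (assumes a Colab session with the latent-diffusion repo
# importable; note that get_model returns the dict built by
# load_model_from_config, not the bare module):
#
#   model = get_model("superresolution")["model"]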


def get_custom_cond(mode):
    dest = "data/example_conditioning"
    os.makedirs(f"{dest}/{mode}", exist_ok=True)

    if mode == "superresolution":
        uploaded_img = files.upload()
        filename = next(iter(uploaded_img))
        name, filetype = filename.rsplit(".", 1)  # split on the last dot only
        os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")

    elif mode == "text_conditional":
        w = widgets.Text(value='A cake with cream!', disabled=True)
        display(w)

        with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f:
            f.write(w.value)

    elif mode == "class_conditional":
        w = widgets.IntSlider(min=0, max=1000)
        display(w)
        with open(f"{dest}/{mode}/custom.txt", 'w') as f:
            f.write(str(w.value))  # IntSlider yields an int; write() needs a str

    else:
        raise NotImplementedError(f"cond not implemented for mode {mode}")


def get_cond_options(mode):
    path = "data/example_conditioning"
    path = os.path.join(path, mode)
    onlyfiles = sorted(os.listdir(path))
    return path, onlyfiles


def select_cond_path(mode):
    path = "data/example_conditioning"  # todo
    path = os.path.join(path, mode)
    onlyfiles = sorted(os.listdir(path))

    selected = widgets.RadioButtons(
        options=onlyfiles,
        description='Select conditioning:',
        disabled=False
    )
    display(selected)
    # note: the value is read immediately, so without user interaction this
    # returns the first option
    selected_path = os.path.join(path, selected.value)
    return selected_path
+
def get_cond(mode, selected_path):
|
109 |
+
example = dict()
|
110 |
+
if mode == "superresolution":
|
111 |
+
up_f = 4
|
112 |
+
visualize_cond_img(selected_path)
|
113 |
+
|
114 |
+
c = Image.open(selected_path)
|
115 |
+
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
116 |
+
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
|
117 |
+
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
118 |
+
c = rearrange(c, '1 c h w -> 1 h w c')
|
119 |
+
c = 2. * c - 1.
|
120 |
+
|
121 |
+
c = c.to(torch.device("cuda"))
|
122 |
+
example["LR_image"] = c
|
123 |
+
example["image"] = c_up
|
124 |
+
|
125 |
+
return example
|
126 |
+
|
127 |
+
|
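

# Hedged sketch of the resulting dict for a 64x64 RGB input (path hypothetical):
#
#   example = get_cond("superresolution", "data/example_conditioning/superresolution/inp.png")
#   example["LR_image"].shape  # [1, 64, 64, 3], values in [-1, 1], on CUDA
#   example["image"].shape     # [1, 256, 256, 3], values in [0, 1], on CPU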

def visualize_cond_img(path):
    display(ipyimg(filename=path))
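

# run() wires everything together. Inputs of 128x128 or larger are processed
# patch-wise ("split input": 128-px patches with stride 64, blend weights
# bounded by the clip_* settings) to keep first-stage memory in check.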
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):

    example = get_cond(task, selected_path)

    save_intermediate_vid = False
    n_runs = 1
    masked = False
    guider = None
    ckwargs = None
    mode = 'ddim'
    ddim_use_x0_pred = False
    temperature = 1.
    eta = 1.
    make_progrow = True
    custom_shape = None

    height, width = example["image"].shape[1:3]
    split_input = height >= 128 and width >= 128

    if split_input:
        ks = 128     # patch (kernel) size
        stride = 64  # patch stride, i.e. 50% overlap
        vqf = 4      # first-stage downsampling factor
        model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                    "vqf": vqf,
                                    "patch_distributed_vq": True,
                                    "tie_braker": False,
                                    "clip_max_weight": 0.5,
                                    "clip_min_weight": 0.01,
                                    "clip_max_tie_weight": 0.5,
                                    "clip_min_tie_weight": 0.01}
    else:
        if hasattr(model, "split_input_params"):
            delattr(model, "split_input_params")

    invert_mask = False

    x_T = None
    for n in range(n_runs):
        if custom_shape is not None:
            x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
            x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])

        logs = make_convolutional_sample(example, model,
                                         mode=mode, custom_steps=custom_steps,
                                         eta=eta, swap_mode=False, masked=masked,
                                         invert_mask=invert_mask, quantize_x0=False,
                                         custom_schedule=None, decode_interval=10,
                                         resize_enabled=resize_enabled, custom_shape=custom_shape,
                                         temperature=temperature, noise_dropout=0.,
                                         corrector=guider, corrector_kwargs=ckwargs, x_T=x_T,
                                         save_intermediate_vid=save_intermediate_vid,
                                         make_progrow=make_progrow, ddim_use_x0_pred=ddim_use_x0_pred)
    return logs
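

# Hedged end-to-end sketch (file path hypothetical; run() expects the bare
# model, i.e. the "model" entry of the dict get_model() returns):
#
#   model = get_model("superresolution")["model"]
#   logs = run(model, select_cond_path("superresolution"),
#              task="superresolution", custom_steps=100)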


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, img_callback=None,
                    temperature=1., noise_dropout=0., score_corrector=None,
                    corrector_kwargs=None, x_T=None, log_every_t=None
                    ):

    ddim = DDIMSampler(model)
    bs = shape[0]      # first entry of shape is the batch size
    shape = shape[1:]  # cut batch dim
    print(f"Sampling with eta = {eta}; steps: {steps}")
    # eta = 0 gives deterministic DDIM sampling; eta = 1 recovers DDPM-like noise
    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
                                         mask=mask, x0=x0, temperature=temperature, verbose=False,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs, x_T=x_T)

    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False,
                              invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,
                              resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True, ddim_use_x0_pred=False):
    log = dict()

    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)

    log_every_t = 1 if save_intermediate_vid else None

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    z0 = None