Update core/models/ddim/ddim_vd.py
Browse files- core/models/ddim/ddim_vd.py +16 -4
core/models/ddim/ddim_vd.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6 |
import numpy as np
|
7 |
from tqdm import tqdm
|
8 |
from functools import partial
|
|
|
9 |
|
10 |
from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
11 |
|
@@ -27,7 +28,8 @@ class DDIMSampler_VD(DDIMSampler):
|
|
27 |
mix_weight=None,
|
28 |
noise_dropout=0.,
|
29 |
verbose=True,
|
30 |
-
log_every_t=100,
|
|
|
31 |
|
32 |
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
33 |
print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
|
@@ -42,7 +44,8 @@ class DDIMSampler_VD(DDIMSampler):
|
|
42 |
noise_dropout=noise_dropout,
|
43 |
temperature=temperature,
|
44 |
log_every_t=log_every_t,
|
45 |
-
mix_weight=mix_weight,
|
|
|
46 |
return samples, intermediates
|
47 |
|
48 |
@torch.no_grad()
|
@@ -58,7 +61,8 @@ class DDIMSampler_VD(DDIMSampler):
|
|
58 |
noise_dropout=0.,
|
59 |
temperature=1.,
|
60 |
mix_weight=None,
|
61 |
-
log_every_t=100,
|
|
|
62 |
|
63 |
device = self.model.device
|
64 |
dtype = condition[0][0].dtype
|
@@ -86,7 +90,12 @@ class DDIMSampler_VD(DDIMSampler):
|
|
86 |
|
87 |
pred_xt = xt
|
88 |
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
|
|
|
|
|
|
89 |
for i, step in enumerate(iterator):
|
|
|
|
|
90 |
index = total_steps - i - 1
|
91 |
ts = torch.full((bs,), step, device=device, dtype=torch.long)
|
92 |
|
@@ -107,6 +116,9 @@ class DDIMSampler_VD(DDIMSampler):
|
|
107 |
intermediates['pred_xt'].append(pred_xt)
|
108 |
intermediates['pred_x0'].append(pred_x0)
|
109 |
|
|
|
|
|
|
|
110 |
return pred_xt, intermediates
|
111 |
|
112 |
@torch.no_grad()
|
@@ -172,4 +184,4 @@ class DDIMSampler_VD(DDIMSampler):
|
|
172 |
x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
|
173 |
x_prev.append(x_prev_i)
|
174 |
pred_x0.append(pred_x0_i)
|
175 |
-
return x_prev, pred_x0
|
|
|
6 |
import numpy as np
|
7 |
from tqdm import tqdm
|
8 |
from functools import partial
|
9 |
+
import streamlit as st
|
10 |
|
11 |
from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
12 |
|
|
|
28 |
mix_weight=None,
|
29 |
noise_dropout=0.,
|
30 |
verbose=True,
|
31 |
+
log_every_t=100,
|
32 |
+
progress_bar=False, ):
|
33 |
|
34 |
self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
|
35 |
print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
|
|
|
44 |
noise_dropout=noise_dropout,
|
45 |
temperature=temperature,
|
46 |
log_every_t=log_every_t,
|
47 |
+
mix_weight=mix_weight,
|
48 |
+
progress_bar=progress_bar, )
|
49 |
return samples, intermediates
|
50 |
|
51 |
@torch.no_grad()
|
|
|
61 |
noise_dropout=0.,
|
62 |
temperature=1.,
|
63 |
mix_weight=None,
|
64 |
+
log_every_t=100,
|
65 |
+
progress_bar=False,):
|
66 |
|
67 |
device = self.model.device
|
68 |
dtype = condition[0][0].dtype
|
|
|
90 |
|
91 |
pred_xt = xt
|
92 |
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
93 |
+
if progress_bar is not None:
|
94 |
+
progress_bar.progress(0)
|
95 |
+
progress_bar.text("Generating samples...")
|
96 |
for i, step in enumerate(iterator):
|
97 |
+
if progress_bar is not None:
|
98 |
+
progress_bar.progress(i/total_steps)
|
99 |
index = total_steps - i - 1
|
100 |
ts = torch.full((bs,), step, device=device, dtype=torch.long)
|
101 |
|
|
|
116 |
intermediates['pred_xt'].append(pred_xt)
|
117 |
intermediates['pred_x0'].append(pred_x0)
|
118 |
|
119 |
+
if progress_bar is not None:
|
120 |
+
progress_bar.success("Sampling complete.")
|
121 |
+
|
122 |
return pred_xt, intermediates
|
123 |
|
124 |
@torch.no_grad()
|
|
|
184 |
x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
|
185 |
x_prev.append(x_prev_i)
|
186 |
pred_x0.append(pred_x0_i)
|
187 |
+
return x_prev, pred_x0
|