dmolino commited on
Commit
8773294
·
verified ·
1 Parent(s): 78b8f59

Update core/models/ddim/ddim_vd.py

Browse files
Files changed (1) hide show
  1. core/models/ddim/ddim_vd.py +16 -4
core/models/ddim/ddim_vd.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import numpy as np
7
  from tqdm import tqdm
8
  from functools import partial
 
9
 
10
  from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
11
 
@@ -27,7 +28,8 @@ class DDIMSampler_VD(DDIMSampler):
27
  mix_weight=None,
28
  noise_dropout=0.,
29
  verbose=True,
30
- log_every_t=100, ):
 
31
 
32
  self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
33
  print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
@@ -42,7 +44,8 @@ class DDIMSampler_VD(DDIMSampler):
42
  noise_dropout=noise_dropout,
43
  temperature=temperature,
44
  log_every_t=log_every_t,
45
- mix_weight=mix_weight, )
 
46
  return samples, intermediates
47
 
48
  @torch.no_grad()
@@ -58,7 +61,8 @@ class DDIMSampler_VD(DDIMSampler):
58
  noise_dropout=0.,
59
  temperature=1.,
60
  mix_weight=None,
61
- log_every_t=100, ):
 
62
 
63
  device = self.model.device
64
  dtype = condition[0][0].dtype
@@ -86,7 +90,12 @@ class DDIMSampler_VD(DDIMSampler):
86
 
87
  pred_xt = xt
88
  iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
 
 
 
89
  for i, step in enumerate(iterator):
 
 
90
  index = total_steps - i - 1
91
  ts = torch.full((bs,), step, device=device, dtype=torch.long)
92
 
@@ -107,6 +116,9 @@ class DDIMSampler_VD(DDIMSampler):
107
  intermediates['pred_xt'].append(pred_xt)
108
  intermediates['pred_x0'].append(pred_x0)
109
 
 
 
 
110
  return pred_xt, intermediates
111
 
112
  @torch.no_grad()
@@ -172,4 +184,4 @@ class DDIMSampler_VD(DDIMSampler):
172
  x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
173
  x_prev.append(x_prev_i)
174
  pred_x0.append(pred_x0_i)
175
- return x_prev, pred_x0
 
6
  import numpy as np
7
  from tqdm import tqdm
8
  from functools import partial
9
+ import streamlit as st
10
 
11
  from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
12
 
 
28
  mix_weight=None,
29
  noise_dropout=0.,
30
  verbose=True,
31
+ log_every_t=100,
32
+ progress_bar=False, ):
33
 
34
  self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
35
  print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
 
44
  noise_dropout=noise_dropout,
45
  temperature=temperature,
46
  log_every_t=log_every_t,
47
+ mix_weight=mix_weight,
48
+ progress_bar=progress_bar, )
49
  return samples, intermediates
50
 
51
  @torch.no_grad()
 
61
  noise_dropout=0.,
62
  temperature=1.,
63
  mix_weight=None,
64
+ log_every_t=100,
65
+ progress_bar=False,):
66
 
67
  device = self.model.device
68
  dtype = condition[0][0].dtype
 
90
 
91
  pred_xt = xt
92
  iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
93
+ if progress_bar is not None:
94
+ progress_bar.progress(0)
95
+ progress_bar.text("Generating samples...")
96
  for i, step in enumerate(iterator):
97
+ if progress_bar is not None:
98
+ progress_bar.progress(i/total_steps)
99
  index = total_steps - i - 1
100
  ts = torch.full((bs,), step, device=device, dtype=torch.long)
101
 
 
116
  intermediates['pred_xt'].append(pred_xt)
117
  intermediates['pred_x0'].append(pred_x0)
118
 
119
+ if progress_bar is not None:
120
+ progress_bar.success("Sampling complete.")
121
+
122
  return pred_xt, intermediates
123
 
124
  @torch.no_grad()
 
184
  x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
185
  x_prev.append(x_prev_i)
186
  pred_x0.append(pred_x0_i)
187
+ return x_prev, pred_x0