rliu committed on
Commit 096f631
1 Parent(s): b94d219

Delete app.py

Files changed (1)
  1. app.py +0 -219
app.py DELETED
@@ -1,219 +0,0 @@
import math
import fire
import gradio as gr
import numpy as np
import rich
import torch
from contextlib import nullcontext
from einops import rearrange
from functools import partial
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import load_and_preprocess, instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from rich import print
from torch import autocast
from torchvision import transforms


_SHOW_INTERMEDIATE = True
_GPU_INDEX = 0
# _GPU_INDEX = 2


def load_model_from_config(config, ckpt, device, verbose=False):
    print(f'Loading model from {ckpt}')
    pl_sd = torch.load(ckpt, map_location=device)
    if 'global_step' in pl_sd:
        print(f'Global Step: {pl_sd["global_step"]}')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print('missing keys:')
        print(m)
    if len(u) > 0 and verbose:
        print('unexpected keys:')
        print(u)

    model.to(device)
    model.eval()
    return model


@torch.no_grad()
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale,
                 ddim_eta, x, y, z):
    precision_scope = autocast if precision == 'autocast' else nullcontext
    with precision_scope('cuda'):
        with model.ema_scope():
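            # Embed the input view once, then tile the conditioning across the n_samples batch.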
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
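            # Camera viewpoint conditioning: polar angle in radians, sin/cos of the azimuth angle, and radius offset.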
            T = torch.tensor([math.radians(x), math.sin(
                math.radians(y)), math.cos(math.radians(y)), z])
            T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device)
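            # Append the pose to the image embedding and project it with cc_projection for cross-attention conditioning.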
            c = torch.cat([c, T], dim=-1)
            c = model.cc_projection(c)
            cond = {}
            cond['c_crossattn'] = [c]
            c_concat = model.encode_first_stage((input_im.to(c.device))).mode().detach()
            cond['c_concat'] = [model.encode_first_stage((input_im.to(c.device))).mode().detach()
                                .repeat(n_samples, 1, 1, 1)]
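            # Unconditional inputs for classifier-free guidance: zeroed latent concat and zeroed cross-attention conditioning.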
            if scale != 1.0:
                uc = {}
                uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)]
                uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
            else:
                uc = None

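            # DDIM sampling in the 4-channel latent space at 1/8 of the output resolution.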
            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None)
            print(samples_ddim.shape)
            # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()


def main(
        model,
        device,
        input_im,
        preprocess=True,
        x=0.,
        y=0.,
        z=0.,
        scale=3.0,
        n_samples=4,
        ddim_steps=50,
        ddim_eta=1.0,
        precision='fp32',
        h=256,
        w=256,
        ):
    # input_im[input_im == [0., 0., 0.]] = [1., 1., 1., 1.]
    print('old input_im:', input_im.size)

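    # Either remove the background and center the object automatically, or use the uploaded RGBA image as-is and composite it onto white.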
    if preprocess:
        input_im = load_and_preprocess(input_im)
        input_im = (input_im / 255.0).astype(np.float32)
        # (H, W, 3) array in [0, 1].

    else:
        input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
        input_im = np.asarray(input_im, dtype=np.float32) / 255.0
        # (H, W, 4) array in [0, 1].

        # old method: very important, thresholding background
        # input_im[input_im[:, :, -1] <= 0.9] = [1., 1., 1., 1.]

        # new method: apply correct method of compositing to avoid sudden transitions / thresholding
        # (smoothly transition foreground to white background based on alpha values)
        alpha = input_im[:, :, 3:4]
        white_im = np.ones_like(input_im)
        input_im = alpha * input_im + (1.0 - alpha) * white_im

        input_im = input_im[:, :, 0:3]
        # (H, W, 3) array in [0, 1].

    print('new input_im:', input_im.shape, input_im.dtype, input_im.min(), input_im.max())
    show_in_im = Image.fromarray((input_im * 255).astype(np.uint8))

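    # Convert to a (1, 3, H, W) tensor, rescale from [0, 1] to [-1, 1], and resize to the model resolution.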
    input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
    input_im = input_im * 2 - 1
    input_im = transforms.functional.resize(input_im, [h, w])

    sampler = DDIMSampler(model)
    x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w,
                                  ddim_steps, n_samples, scale, ddim_eta, x, y, z)

    output_ims = []
    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        output_ims.append(Image.fromarray(x_sample.astype(np.uint8)))

    if _SHOW_INTERMEDIATE:
        return (output_ims, show_in_im)
    else:
        return output_ims


description = '''
Generate novel viewpoints of an object depicted in one input image using a fine-tuned version of Stable Diffusion.
'''

article = '''
## How to use this?
TBD
## How does this work?
TBD
'''


def run_demo(
        device_idx=_GPU_INDEX,
        ckpt='last.ckpt',
        config='configs/sd-objaverse-finetune-c_concat-256.yaml',
        ):

    device = f'cuda:{device_idx}'
    config = OmegaConf.load(config)
    model = load_model_from_config(config, ckpt, device=device)

    inputs = [
        gr.Image(type='pil', image_mode='RGBA'),  # shape=[512, 512]
        gr.Checkbox(True, label='Preprocess image (remove background and center)',
                    info='If enabled, the uploaded image will be preprocessed to remove the background and center the object by cropping and/or padding as necessary. '
                    'If disabled, the image will be used as-is, *BUT* a fully transparent or white background is required.'),
        # gr.Number(label='polar (between axis z+)'),
        # gr.Number(label='azimuth (between axis x+)'),
        # gr.Number(label='z (distance from center)'),
        gr.Slider(-90, 90, value=0, step=5, label='Polar angle (vertical rotation in degrees)',
                  info='Positive values move the camera down, while negative values move the camera up.'),
        gr.Slider(-90, 90, value=0, step=5, label='Azimuth angle (horizontal rotation in degrees)',
                  info='Positive values move the camera right, while negative values move the camera left.'),
        gr.Slider(-2, 2, value=0, step=0.5, label='Radius (distance from center)',
                  info='Positive values move the camera further away, while negative values move the camera closer.'),
        gr.Slider(0, 30, value=3, step=1, label='cfg scale'),
        gr.Slider(1, 8, value=4, step=1, label='Number of samples to generate'),
        gr.Slider(5, 200, value=100, step=5, label='Number of steps'),
    ]
    output = [gr.Gallery(label='Generated images from specified new viewpoint')]
    output[0].style(grid=2)

    if _SHOW_INTERMEDIATE:
        output += [gr.Image(type='pil', image_mode='RGB', label='Preprocessed input image')]

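    # Bind the loaded model and device so the Gradio callback only receives the UI inputs.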
    fn_with_model = partial(main, model, device)
    fn_with_model.__name__ = 'fn_with_model'

    examples = [
        # ['assets/zero-shot/bear.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/car.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/elephant.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/pikachu.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/spyro.png', 0, 0, 0, 3, 4, 100],
        # ['assets/zero-shot/taxi.png', 0, 0, 0, 3, 4, 100],
    ]

    demo = gr.Interface(
        fn=fn_with_model,
        title='Demo for Zero-Shot Control of Camera Viewpoints within a Single Image',
        description=description,
        article=article,
        inputs=inputs,
        outputs=output,
        examples=examples,
        allow_flagging='never',
    )
    demo.launch(enable_queue=True, share=True)


if __name__ == '__main__':
    fire.Fire(run_demo)