haodongli commited on
Commit
dc78df8
β€’
1 Parent(s): 74af050
Files changed (9) hide show
  1. .gitattributes +1 -0
  2. .gitignore +3 -0
  3. README.md +1 -1
  4. app.py +218 -217
  5. files/videos/K_0005_IN.mp4 +3 -0
  6. files/videos/obama.mp4 +0 -0
  7. infer.py +134 -28
  8. pipeline.py +0 -1
  9. requirements.txt +3 -2
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  files/images/01.jpg filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  files/images/01.jpg filter=lfs diff=lfs merge=lfs -text
37
+ files/videos/K_0005_IN.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ output/
3
+ gradio_cached_examples/
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸš€
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.21.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -1,4 +1,3 @@
1
- from __future__ import annotations
2
  from gradio_imageslider import ImageSlider
3
  import functools
4
  import os
@@ -14,224 +13,226 @@ from tqdm import tqdm
14
  from pathlib import Path
15
  import gradio
16
  from gradio.utils import get_cache_folder
17
- from infer import lotus
18
-
19
- # def process_image_check(path_input):
20
- # if path_input is None:
21
- # raise gr.Error(
22
- # "Missing image in the first pane: upload a file or use one from the gallery below."
23
- # )
24
-
25
- # def infer(path_input, seed=0):
26
- # print(f"==> Processing image {path_input}")
27
- # return path_input
28
- # return [path_input, path_input]
29
- # # name_base, name_ext = os.path.splitext(os.path.basename(path_input))
30
- # # print(f"==> Processing image {name_base}{name_ext}")
31
- # # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
- # # print(f"==> Device: {device}")
33
- # # output_g, output_d = lotus(path_input, 'depth', seed, device)
34
- # # if not os.path.exists("files/output"):
35
- # # os.makedirs("files/output")
36
- # # g_save_path = os.path.join("files/output", f"{name_base}_g{name_ext}")
37
- # # d_save_path = os.path.join("files/output", f"{name_base}_d{name_ext}")
38
- # # output_g.save(g_save_path)
39
- # # output_d.save(d_save_path)
40
- # # yield [path_input, g_save_path], [path_input, d_save_path]
41
-
42
- # def run_demo_server():
43
- # gradio_theme = gr.themes.Default()
44
-
45
- # with gr.Blocks(
46
- # theme=gradio_theme,
47
- # title="LOTUS (Depth)",
48
- # css="""
49
- # #download {
50
- # height: 118px;
51
- # }
52
- # .slider .inner {
53
- # width: 5px;
54
- # background: #FFF;
55
- # }
56
- # .viewport {
57
- # aspect-ratio: 4/3;
58
- # }
59
- # .tabs button.selected {
60
- # font-size: 20px !important;
61
- # color: crimson !important;
62
- # }
63
- # h1 {
64
- # text-align: center;
65
- # display: block;
66
- # }
67
- # h2 {
68
- # text-align: center;
69
- # display: block;
70
- # }
71
- # h3 {
72
- # text-align: center;
73
- # display: block;
74
- # }
75
- # .md_feedback li {
76
- # margin-bottom: 0px !important;
77
- # }
78
- # """,
79
- # head="""
80
- # <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
81
- # <script>
82
- # window.dataLayer = window.dataLayer || [];
83
- # function gtag() {dataLayer.push(arguments);}
84
- # gtag('js', new Date());
85
- # gtag('config', 'G-1FWSVCGZTG');
86
- # </script>
87
- # """,
88
- # ) as demo:
89
- # gr.Markdown(
90
- # """
91
- # # LOTUS: Diffusion-based Visual Foundation Model for High-quality Dense Prediction
92
- # <p align="center">
93
- # <a title="Page" href="https://lotus3d.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
94
- # <img src="https://img.shields.io/badge/Project-Website-pink?logo=googlechrome&logoColor=white">
95
- # </a>
96
- # <a title="arXiv" href="https://arxiv.org/abs/2409.18124" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
97
- # <img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white">
98
- # </a>
99
- # <a title="Github" href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
100
- # <img src="https://img.shields.io/github/stars/EnVision-Research/Lotus?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
101
- # </a>
102
- # <a title="Social" href="https://x.com/haodongli00/status/1839524569058582884" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
103
- # <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
104
- # </a>
105
- # """
106
- # )
107
- # with gr.Tabs(elem_classes=["tabs"]):
108
- # with gr.Tab("IMAGE"):
109
- # with gr.Row():
110
- # with gr.Column():
111
- # image_input = gr.Image(
112
- # label="Input Image",
113
- # type="filepath",
114
- # )
115
- # seed = gr.Number(
116
- # label="Seed",
117
- # minimum=0,
118
- # maximum=999999,
119
- # )
120
- # with gr.Row():
121
- # image_submit_btn = gr.Button(
122
- # value="Predict Depth!", variant="primary"
123
- # )
124
- # # image_reset_btn = gr.Button(value="Reset")
125
- # with gr.Column():
126
- # image_output_g = gr.Image(
127
- # label="Output (Generative)",
128
- # type="filepath",
129
- # )
130
- # # image_output_g = ImageSlider(
131
- # # label="Output (Generative)",
132
- # # type="filepath",
133
- # # show_download_button=True,
134
- # # show_share_button=True,
135
- # # interactive=False,
136
- # # elem_classes="slider",
137
- # # position=0.25,
138
- # # )
139
- # # with gr.Row():
140
- # # image_output_d = gr.Image(
141
- # # label="Output (Generative)",
142
- # # type="filepath",
143
- # # )
144
- # # image_output_d = ImageSlider(
145
- # # label="Output (Discriminative)",
146
- # # type="filepath",
147
- # # show_download_button=True,
148
- # # show_share_button=True,
149
- # # interactive=False,
150
- # # elem_classes="slider",
151
- # # position=0.25,
152
- # # )
153
-
154
- # # gr.Examples(
155
- # # fn=infer,
156
- # # examples=sorted([
157
- # # os.path.join("files", "images", name)
158
- # # for name in os.listdir(os.path.join("files", "images"))
159
- # # ]),
160
- # # inputs=[image_input],
161
- # # outputs=[image_output_g],
162
- # # cache_examples=True,
163
- # # )
164
-
165
- # with gr.Tab("VIDEO"):
166
- # with gr.Column():
167
- # gr.Markdown("Coming soon")
168
-
169
- # ### Image
170
- # image_submit_btn.click(
171
- # fn=infer,
172
- # inputs=[
173
- # image_input
174
- # ],
175
- # outputs=image_output_g,
176
- # concurrency_limit=1,
177
- # )
178
- # # image_reset_btn.click(
179
- # # fn=lambda: (
180
- # # None,
181
- # # None,
182
- # # None,
183
- # # ),
184
- # # inputs=[],
185
- # # outputs=image_output_g,
186
- # # queue=False,
187
- # # )
188
-
189
- # ### Video
190
-
191
- # ### Server launch
192
- # demo.queue(
193
- # api_open=False,
194
- # ).launch(
195
- # server_name="0.0.0.0",
196
- # server_port=7860,
197
- # )
198
-
199
- # def main():
200
- # os.system("pip freeze")
201
- # run_demo_server()
202
-
203
- # if __name__ == "__main__":
204
- # main()
205
-
206
- def flip_text(x):
207
- return x[::-1]
208
-
209
- def flip_image(x):
210
- return np.fliplr(x)
 
 
 
 
 
 
211
 
212
- with gr.Blocks() as demo:
213
- gr.Markdown("Flip text or image files using this demo.")
214
- with gr.Tab("Flip Text"):
215
- text_input = gr.Textbox()
216
- text_output = gr.Textbox()
217
- text_button = gr.Button("Flip")
218
- with gr.Tab("Flip Image"):
219
- with gr.Row():
220
- image_input = gr.Image()
221
- image_output = gr.Image()
222
- image_button = gr.Button("Flip")
223
 
224
- with gr.Accordion("Open for More!", open=False):
225
- gr.Markdown("Look at me...")
226
- temp_slider = gr.Slider(
227
- 0, 1,
228
- value=0.1,
229
- step=0.1,
230
- interactive=True,
231
- label="Slide me",
232
  )
233
 
234
- text_button.click(flip_text, inputs=text_input, outputs=text_output)
235
- image_button.click(flip_image, inputs=image_input, outputs=image_output)
 
236
 
237
- demo.launch(share=True)
 
 
 
1
  from gradio_imageslider import ImageSlider
2
  import functools
3
  import os
 
13
  from pathlib import Path
14
  import gradio
15
  from gradio.utils import get_cache_folder
16
+ from infer import lotus, lotus_video
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ def infer(path_input, seed=0):
21
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
22
+ output_g, output_d = lotus(path_input, 'depth', seed, device)
23
+ if not os.path.exists("files/output"):
24
+ os.makedirs("files/output")
25
+ g_save_path = os.path.join("files/output", f"{name_base}_g{name_ext}")
26
+ d_save_path = os.path.join("files/output", f"{name_base}_d{name_ext}")
27
+ output_g.save(g_save_path)
28
+ output_d.save(d_save_path)
29
+ return [path_input, g_save_path], [path_input, d_save_path]
30
+
31
+ def infer_video(path_input, seed=0):
32
+ frames_g, frames_d = lotus_video(path_input, 'depth', seed, device)
33
+ if not os.path.exists("files/output"):
34
+ os.makedirs("files/output")
35
+ name_base, _ = os.path.splitext(os.path.basename(path_input))
36
+ g_save_path = os.path.join("files/output", f"{name_base}_g.mp4")
37
+ d_save_path = os.path.join("files/output", f"{name_base}_d.mp4")
38
+ imageio.mimsave(g_save_path, frames_g)
39
+ imageio.mimsave(d_save_path, frames_d)
40
+ return [g_save_path, d_save_path]
41
+
42
+ def run_demo_server():
43
+ gradio_theme = gr.themes.Default()
44
+
45
+ with gr.Blocks(
46
+ theme=gradio_theme,
47
+ title="LOTUS (Depth)",
48
+ css="""
49
+ #download {
50
+ height: 118px;
51
+ }
52
+ .slider .inner {
53
+ width: 5px;
54
+ background: #FFF;
55
+ }
56
+ .viewport {
57
+ aspect-ratio: 4/3;
58
+ }
59
+ .tabs button.selected {
60
+ font-size: 20px !important;
61
+ color: crimson !important;
62
+ }
63
+ h1 {
64
+ text-align: center;
65
+ display: block;
66
+ }
67
+ h2 {
68
+ text-align: center;
69
+ display: block;
70
+ }
71
+ h3 {
72
+ text-align: center;
73
+ display: block;
74
+ }
75
+ .md_feedback li {
76
+ margin-bottom: 0px !important;
77
+ }
78
+ """,
79
+ head="""
80
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
81
+ <script>
82
+ window.dataLayer = window.dataLayer || [];
83
+ function gtag() {dataLayer.push(arguments);}
84
+ gtag('js', new Date());
85
+ gtag('config', 'G-1FWSVCGZTG');
86
+ </script>
87
+ """,
88
+ ) as demo:
89
+ gr.Markdown(
90
+ """
91
+ # LOTUS: Diffusion-based Visual Foundation Model for High-quality Dense Prediction
92
+ <p align="center">
93
+ <a title="Page" href="https://lotus3d.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
94
+ <img src="https://img.shields.io/badge/Project-Website-pink?logo=googlechrome&logoColor=white">
95
+ </a>
96
+ <a title="arXiv" href="https://arxiv.org/abs/2409.18124" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
97
+ <img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white">
98
+ </a>
99
+ <a title="Github" href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
100
+ <img src="https://img.shields.io/github/stars/EnVision-Research/Lotus?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
101
+ </a>
102
+ <a title="Social" href="https://x.com/haodongli00/status/1839524569058582884" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
103
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
104
+ </a>
105
+ """
106
+ )
107
+ with gr.Tabs(elem_classes=["tabs"]):
108
+ with gr.Tab("IMAGE"):
109
+ with gr.Row():
110
+ with gr.Column():
111
+ image_input = gr.Image(
112
+ label="Input Image",
113
+ type="filepath",
114
+ )
115
+ seed = gr.Number(
116
+ label="Seed (only for Generative mode)",
117
+ minimum=0,
118
+ maximum=999999999,
119
+ )
120
+ with gr.Row():
121
+ image_submit_btn = gr.Button(
122
+ value="Predict Depth!", variant="primary"
123
+ )
124
+ image_reset_btn = gr.Button(value="Reset")
125
+ with gr.Column():
126
+ image_output_g = ImageSlider(
127
+ label="Output (Generative)",
128
+ type="filepath",
129
+ interactive=False,
130
+ elem_classes="slider",
131
+ position=0.25,
132
+ )
133
+ with gr.Row():
134
+ image_output_d = ImageSlider(
135
+ label="Output (Discriminative)",
136
+ type="filepath",
137
+ interactive=False,
138
+ elem_classes="slider",
139
+ position=0.25,
140
+ )
141
+
142
+ gr.Examples(
143
+ fn=infer,
144
+ examples=sorted([
145
+ os.path.join("files", "images", name)
146
+ for name in os.listdir(os.path.join("files", "images"))
147
+ ]),
148
+ inputs=[image_input],
149
+ outputs=[image_output_g, image_output_d],
150
+ cache_examples=True,
151
+ )
152
+
153
+ with gr.Tab("VIDEO"):
154
+ with gr.Row():
155
+ with gr.Column():
156
+ input_video = gr.Video(
157
+ label="Input Video",
158
+ autoplay=True,
159
+ loop=True,
160
+ )
161
+ seed = gr.Number(
162
+ label="Seed (only for Generative mode)",
163
+ minimum=0,
164
+ maximum=999999999,
165
+ )
166
+ with gr.Row():
167
+ video_submit_btn = gr.Button(
168
+ value="Compute Depth!", variant="primary"
169
+ )
170
+ video_reset_btn = gr.Button(value="Reset")
171
+ with gr.Column():
172
+ video_output_g = gr.Video(
173
+ label="Output (Generative)",
174
+ interactive=False,
175
+ autoplay=True,
176
+ loop=True,
177
+ show_share_button=True,
178
+ )
179
+ with gr.Row():
180
+ video_output_d = gr.Video(
181
+ label="Output (Discriminative)",
182
+ interactive=False,
183
+ autoplay=True,
184
+ loop=True,
185
+ show_share_button=True,
186
+ )
187
+
188
+ gr.Examples(
189
+ fn=infer_video,
190
+ examples=sorted([
191
+ os.path.join("files", "videos", name)
192
+ for name in os.listdir(os.path.join("files", "videos"))
193
+ ]),
194
+ inputs=[input_video],
195
+ outputs=[video_output_g, video_output_d],
196
+ cache_examples=True,
197
+ )
198
+
199
+ ### Image
200
+ image_submit_btn.click(
201
+ fn=infer,
202
+ inputs=[image_input, seed],
203
+ outputs=[image_output_g, image_output_d],
204
+ concurrency_limit=1,
205
+ )
206
+ image_reset_btn.click(
207
+ fn=lambda: (
208
+ None,
209
+ None,
210
+ None,
211
+ ),
212
+ inputs=[],
213
+ outputs=[image_output_g, image_output_d],
214
+ queue=False,
215
+ )
216
 
217
+ ### Video
218
+ video_submit_btn.click(
219
+ fn=infer_video,
220
+ inputs=[input_video, seed],
221
+ outputs=[video_output_g, video_output_d],
222
+ queue=True,
223
+ )
 
 
 
 
224
 
225
+ ### Server launch
226
+ demo.queue(
227
+ api_open=False,
228
+ ).launch(
229
+ server_name="0.0.0.0",
230
+ server_port=7860,
 
 
231
  )
232
 
233
+ def main():
234
+ os.system("pip freeze")
235
+ run_demo_server()
236
 
237
+ if __name__ == "__main__":
238
+ main()
files/videos/K_0005_IN.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a532ba2738716dbb244e0d7172cf681879218cbbdad09980404fa08ef6b9ecc
3
+ size 3095352
files/videos/obama.mp4 DELETED
Binary file (320 kB)
 
infer.py CHANGED
@@ -14,6 +14,9 @@ from pipeline import LotusGPipeline, LotusDPipeline
14
  from utils.image_utils import colorize_depth_map
15
  from utils.seed_all import seed_all
16
 
 
 
 
17
  check_min_version('0.28.0.dev0')
18
 
19
  def infer_pipe(pipe, image_input, task_name, seed, device):
@@ -22,36 +25,137 @@ def infer_pipe(pipe, image_input, task_name, seed, device):
22
  else:
23
  generator = torch.Generator(device=device).manual_seed(seed)
24
 
25
- test_image = Image.open(image_input).convert('RGB')
26
- test_image = np.array(test_image).astype(np.float32)
27
- test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
28
- test_image = test_image / 127.5 - 1.0
29
- test_image = test_image.to(device)
30
 
31
- task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
32
- task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- # Run
35
- pred = pipe(
36
- rgb_in=test_image,
37
- prompt='',
38
- num_inference_steps=1,
39
- generator=generator,
40
- # guidance_scale=0,
41
- output_type='np',
42
- timesteps=[999],
43
- task_emb=task_emb,
44
- ).images[0]
45
-
46
- # Post-process the prediction
47
  if task_name == 'depth':
48
- output_npy = pred.mean(axis=-1)
49
- output_color = colorize_depth_map(output_npy)
50
  else:
51
- output_npy = pred
52
- output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
53
 
54
- return output_color
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def lotus(image_input, task_name, seed, device):
57
  if task_name == 'depth':
@@ -61,7 +165,7 @@ def lotus(image_input, task_name, seed, device):
61
  model_g = 'jingheya/lotus-normal-g-v1-0'
62
  model_d = 'jingheya/lotus-normal-d-v1-0'
63
 
64
- dtype = torch.float32
65
  pipe_g = LotusGPipeline.from_pretrained(
66
  model_g,
67
  torch_dtype=dtype,
@@ -72,6 +176,8 @@ def lotus(image_input, task_name, seed, device):
72
  )
73
  pipe_g.to(device)
74
  pipe_d.to(device)
 
 
75
  logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
76
  output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
77
  output_d = infer_pipe(pipe_d, image_input, task_name, seed, device)
@@ -158,7 +264,7 @@ def main():
158
  dtype = torch.float16
159
  logging.info(f"Running with half precision ({dtype}).")
160
  else:
161
- dtype = torch.float32
162
 
163
  # -------------------- Device --------------------
164
  if torch.cuda.is_available():
@@ -206,7 +312,7 @@ def main():
206
  for i in tqdm(range(len(test_images))):
207
  # Preprocess validation image
208
  test_image = Image.open(test_images[i]).convert('RGB')
209
- test_image = np.array(test_image).astype(np.float32)
210
  test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
211
  test_image = test_image / 127.5 - 1.0
212
  test_image = test_image.to(device)
 
14
  from utils.image_utils import colorize_depth_map
15
  from utils.seed_all import seed_all
16
 
17
+ from contextlib import nullcontext
18
+ import cv2
19
+
20
  check_min_version('0.28.0.dev0')
21
 
22
  def infer_pipe(pipe, image_input, task_name, seed, device):
 
25
  else:
26
  generator = torch.Generator(device=device).manual_seed(seed)
27
 
28
+ if torch.backends.mps.is_available():
29
+ autocast_ctx = nullcontext()
30
+ else:
31
+ autocast_ctx = torch.autocast(pipe.device.type)
32
+ with autocast_ctx:
33
 
34
+ test_image = Image.open(image_input).convert('RGB')
35
+ test_image = np.array(test_image).astype(np.float16)
36
+ test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
37
+ test_image = test_image / 127.5 - 1.0
38
+ test_image = test_image.to(device)
39
+
40
+ task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
41
+ task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
42
+
43
+ # Run
44
+ pred = pipe(
45
+ rgb_in=test_image,
46
+ prompt='',
47
+ num_inference_steps=1,
48
+ generator=generator,
49
+ # guidance_scale=0,
50
+ output_type='np',
51
+ timesteps=[999],
52
+ task_emb=task_emb,
53
+ ).images[0]
54
+
55
+ # Post-process the prediction
56
+ if task_name == 'depth':
57
+ output_npy = pred.mean(axis=-1)
58
+ output_color = colorize_depth_map(output_npy)
59
+ else:
60
+ output_npy = pred
61
+ output_color = Image.fromarray((output_npy * 255).astype(np.uint8))
62
+
63
+ return output_color
64
 
65
+ def lotus_video(input_video, task_name, seed, device):
 
 
 
 
 
 
 
 
 
 
 
 
66
  if task_name == 'depth':
67
+ model_g = 'jingheya/lotus-depth-g-v1-0'
68
+ model_d = 'jingheya/lotus-depth-d-v1-0'
69
  else:
70
+ model_g = 'jingheya/lotus-normal-g-v1-0'
71
+ model_d = 'jingheya/lotus-normal-d-v1-0'
72
 
73
+ dtype = torch.float16
74
+ pipe_g = LotusGPipeline.from_pretrained(
75
+ model_g,
76
+ torch_dtype=dtype,
77
+ )
78
+ pipe_d = LotusDPipeline.from_pretrained(
79
+ model_d,
80
+ torch_dtype=dtype,
81
+ )
82
+ pipe_g.to(device)
83
+ pipe_d.to(device)
84
+ pipe_g.set_progress_bar_config(disable=True)
85
+ pipe_d.set_progress_bar_config(disable=True)
86
+ logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
87
+
88
+ # load the video and split it into frames
89
+ cap = cv2.VideoCapture(input_video)
90
+ frames = []
91
+ while True:
92
+ ret, frame = cap.read()
93
+ if not ret:
94
+ break
95
+ frames.append(frame)
96
+ cap.release()
97
+ logging.info(f"There are {len(frames)} frames in the video.")
98
+
99
+ if seed is None:
100
+ generator = None
101
+ else:
102
+ generator = torch.Generator(device=device).manual_seed(seed)
103
+
104
+ task_emb = torch.tensor([1, 0]).float().unsqueeze(0).repeat(1, 1).to(device)
105
+ task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1).repeat(1, 1)
106
+
107
+ output_g = []
108
+ output_d = []
109
+ for frame in frames:
110
+ if torch.backends.mps.is_available():
111
+ autocast_ctx = nullcontext()
112
+ else:
113
+ autocast_ctx = torch.autocast(pipe_g.device.type)
114
+ with autocast_ctx:
115
+ test_image = frame
116
+ test_image = np.array(test_image).astype(np.float16)
117
+ test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
118
+ test_image = test_image / 127.5 - 1.0
119
+ test_image = test_image.to(device)
120
+
121
+ # Run
122
+ pred_g = pipe_g(
123
+ rgb_in=test_image,
124
+ prompt='',
125
+ num_inference_steps=1,
126
+ generator=generator,
127
+ # guidance_scale=0,
128
+ output_type='np',
129
+ timesteps=[999],
130
+ task_emb=task_emb,
131
+ ).images[0]
132
+ pred_d = pipe_d(
133
+ rgb_in=test_image,
134
+ prompt='',
135
+ num_inference_steps=1,
136
+ generator=generator,
137
+ # guidance_scale=0,
138
+ output_type='np',
139
+ timesteps=[999],
140
+ task_emb=task_emb,
141
+ ).images[0]
142
+
143
+ # Post-process the prediction
144
+ if task_name == 'depth':
145
+ output_npy_g = pred_g.mean(axis=-1)
146
+ output_color_g = colorize_depth_map(output_npy_g)
147
+ output_npy_d = pred_d.mean(axis=-1)
148
+ output_color_d = colorize_depth_map(output_npy_d)
149
+ else:
150
+ output_npy_g = pred_g
151
+ output_color_g = Image.fromarray((output_npy_g * 255).astype(np.uint8))
152
+ output_npy_d = pred_d
153
+ output_color_d = Image.fromarray((output_npy_d * 255).astype(np.uint8))
154
+
155
+ output_g.append(output_color_g)
156
+ output_d.append(output_color_d)
157
+
158
+ return output_g, output_d
159
 
160
  def lotus(image_input, task_name, seed, device):
161
  if task_name == 'depth':
 
165
  model_g = 'jingheya/lotus-normal-g-v1-0'
166
  model_d = 'jingheya/lotus-normal-d-v1-0'
167
 
168
+ dtype = torch.float16
169
  pipe_g = LotusGPipeline.from_pretrained(
170
  model_g,
171
  torch_dtype=dtype,
 
176
  )
177
  pipe_g.to(device)
178
  pipe_d.to(device)
179
+ pipe_g.set_progress_bar_config(disable=True)
180
+ pipe_d.set_progress_bar_config(disable=True)
181
  logging.info(f"Successfully loading pipeline from {model_g} and {model_d}.")
182
  output_g = infer_pipe(pipe_g, image_input, task_name, seed, device)
183
  output_d = infer_pipe(pipe_d, image_input, task_name, seed, device)
 
264
  dtype = torch.float16
265
  logging.info(f"Running with half precision ({dtype}).")
266
  else:
267
+ dtype = torch.float16
268
 
269
  # -------------------- Device --------------------
270
  if torch.cuda.is_available():
 
312
  for i in tqdm(range(len(test_images))):
313
  # Preprocess validation image
314
  test_image = Image.open(test_images[i]).convert('RGB')
315
+ test_image = np.array(test_image).astype(np.float16)
316
  test_image = torch.tensor(test_image).permute(2,0,1).unsqueeze(0)
317
  test_image = test_image / 127.5 - 1.0
318
  test_image = test_image.to(device)
pipeline.py CHANGED
@@ -1197,7 +1197,6 @@ class LotusGPipeline(DirectDiffusionPipeline):
1197
  # 2. Define call parameters
1198
  batch_size = rgb_in.shape[0]
1199
  device = self._execution_device
1200
- print("Device: ", device)
1201
 
1202
  # 3. Encode input prompt
1203
  prompt_embeds, _ = self.encode_prompt(
 
1197
  # 2. Define call parameters
1198
  batch_size = rgb_in.shape[0]
1199
  device = self._execution_device
 
1200
 
1201
  # 3. Encode input prompt
1202
  prompt_embeds, _ = self.encode_prompt(
requirements.txt CHANGED
@@ -17,7 +17,8 @@ h5py==3.11.0
17
  omegaconf==2.3.0
18
  tabulate==0.9.0
19
  imageio==2.35.1
 
20
  spaces==0.28.3
21
- gradio==4.21.0
22
  gradio-imageslider==0.0.16
23
- gradio_client==0.12.0
 
17
  omegaconf==2.3.0
18
  tabulate==0.9.0
19
  imageio==2.35.1
20
+ imageio-ffmpeg==0.5.1
21
  spaces==0.28.3
22
+ gradio==4.44.0
23
  gradio-imageslider==0.0.16
24
+ gradio-client==1.3.0