zzl committed
Commit 2bbc3ee • 1 Parent(s): 9fd9429

release

Changed files:
- app.py      +6 -7
- demo_img.py +34 -10
- demo_vid.py +20 -2
- utils.py    +18 -0
app.py CHANGED
@@ -14,19 +14,18 @@ with gr.Blocks(css='style.css') as demo:
         <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
             <a href="https://paper99.github.io" style="color:blue;">Zhen Li</a><sup>1*</sup>,
             <a href="https://github.com/NK-CS-ZZL" style="color:blue;">Zuo-Liang Zhu</a><sup>1*</sup>,
-            <a href="https://github.com/hlh981029" style="color:blue;">Ling-Hao Han</a><sup>1
-            <a href="https://houqb.github.io" style="color:blue;">Qibin Hou</a><sup>1
-            <a href="https://github.com" style="color:blue;">Chun-Le Guo</a><sup>1
-            <a href="https://mmcheng.net" style="color:blue;">Ming-Ming Cheng</a><sup>1
+            <a href="https://github.com/hlh981029" style="color:blue;">Ling-Hao Han</a><sup>1</sup>,
+            <a href="https://houqb.github.io" style="color:blue;">Qibin Hou</a><sup>1</sup>,
+            <a href="https://github.com" style="color:blue;">Chun-Le Guo</a><sup>1</sup>,
+            <a href="https://mmcheng.net" style="color:blue;">Ming-Ming Cheng</a><sup>1</sup>,
         </h2>
         <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
-            <sup>1</sup>Nankai University <sup>*</sup> represents the equal contribution
+            <sup>1</sup>Nankai University <sup>*</sup> represents the equal contribution.
         </h2>
         <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
             [<a href="https://arxiv.org/abs/2303.13439" style="color:blue;">arXiv</a>]
             [<a href="https://github.com/MCG-NKU/AMT" style="color:blue;">GitHub</a>]
-            [<a href="https://
-            [<a href="https://github.com/MCG-NKU/AMT" style="color:blue;">Replicate</a>]
+            [<a href="https://colab.research.google.com/drive/1IeVO5BmLouhRh6fL2z_y18kgubotoaBq?usp=sharing" style="color:blue;">Colab</a>]
         </h2>
         <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
     """
demo_img.py CHANGED
@@ -7,8 +7,11 @@ from huggingface_hub import hf_hub_download
 from networks.amts import Model as AMTS
 from networks.amtl import Model as AMTL
 from networks.amtg import Model as AMTG
-from utils import
-
+from utils import (
+    img2tensor, tensor2img,
+    InputPadder,
+    check_dim_and_resize
+)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model_dict = {
     'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
@@ -23,22 +26,43 @@ def img2vid(model_type, img0, img1, frame_ratio, iters):
     model.eval()
     img0_t = img2tensor(img0).to(device)
     img1_t = img2tensor(img1).to(device)
-    padder = InputPadder(img0_t.shape, 16)
-    img0_t, img1_t = padder.pad(img0_t, img1_t)
     inputs = [img0_t, img1_t]
+
+    if device == 'cpu':
+        # Do not resize in cpu mode
+        anchor_resolution = 8192*8192
+        anchor_memory = 1
+        anchor_memory_bias = 0
+        vram_avail = 1
+    elif device == 'cuda':
+        anchor_resolution = 1024 * 512
+        anchor_memory = 1500 * 1024**2
+        anchor_memory_bias = 2500 * 1024**2
+        vram_avail = torch.cuda.get_device_properties(device).total_memory
     embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
 
+    inputs = check_dim_and_resize(inputs)
+    h, w = inputs[0].shape[-2:]
+    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
+    scale = 1 if scale > 1 else scale
+    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
+    if scale < 1:
+        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
+    padding = int(16 / scale)
+    padder = InputPadder(inputs[0].shape, padding)
+    inputs = padder.pad(*inputs)
+
     for i in range(iters):
         print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
-        outputs = [
+        outputs = [inputs[0]]
         for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
+            in_0 = in_0.to(device)
+            in_1 = in_1.to(device)
             with torch.no_grad():
-                imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
-
-            in_1 = padder.unpad(in_1)
-            outputs += [imgt_pred, in_1]
+                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
+            outputs += [imgt_pred.cpu(), in_1.cpu()]
         inputs = outputs
-
+    outputs = padder.unpad(*outputs)
     out_path = 'results'
     size = outputs[0].shape[2:][::-1]
     writer = cv2.VideoWriter(f'{out_path}/demo.mp4', cv2.VideoWriter_fourcc(*'mp4v'), frame_ratio, size)
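Note on the new block in `demo_img.py`: it estimates how far the input frames must be downscaled so inference fits in the available VRAM, then snaps the scale so that `16 / scale` is an integer padding for `InputPadder`. Below is a minimal, self-contained sketch of that heuristic; the helper name `estimate_scale` and the example resolution/VRAM figures are illustrative and not part of the commit.

```python
import numpy as np

def estimate_scale(h, w, vram_avail,
                   anchor_resolution=1024 * 512,
                   anchor_memory=1500 * 1024**2,
                   anchor_memory_bias=2500 * 1024**2):
    """Rough VRAM model: memory use grows with pixel count plus a fixed bias.

    The anchor constants mirror the values added in this commit.
    """
    scale = anchor_resolution / (h * w) * np.sqrt(
        (vram_avail - anchor_memory_bias) / anchor_memory)
    scale = min(scale, 1.0)                              # never upscale
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16   # make 16 / scale an integer padding
    return scale

# Illustrative call: a 2160x3840 frame on a GPU reporting 16 GiB of VRAM.
print(estimate_scale(2160, 3840, vram_avail=16 * 1024**3))  # ~0.44
```

On CPU the diff effectively disables this path: with `anchor_resolution = 8192*8192` and unit memory figures, the raw scale exceeds 1 for any realistic frame size and is clamped back to 1, so no resizing happens.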
demo_vid.py CHANGED
@@ -27,7 +27,25 @@ def vid2vid(model_type, video, iters):
     inputs = []
     h = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
     w = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
+    if device == 'cpu':
+        # Do not resize in cpu mode
+        anchor_resolution = 8192*8192
+        anchor_memory = 1
+        anchor_memory_bias = 0
+        vram_avail = 1
+    elif device == 'cuda':
+        anchor_resolution = 1024 * 512
+        anchor_memory = 1500 * 1024**2
+        anchor_memory_bias = 2500 * 1024**2
+        vram_avail = torch.cuda.get_device_properties(device).total_memory
+
+    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
+    scale = 1 if scale > 1 else scale
+    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
+    if scale < 1:
+        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
+    padding = int(16 / scale)
+    padder = InputPadder(inputs[0].shape, padding)
     while True:
         ret, frame = vcap.read()
         if ret is False:
@@ -43,7 +61,7 @@ def vid2vid(model_type, video, iters):
     outputs = [inputs[0]]
     for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
         with torch.no_grad():
-            imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
+            imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
         imgt_pred = padder.unpad(imgt_pred)
         in_1 = padder.unpad(in_1)
         outputs += [imgt_pred, in_1]
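The interpolation loop seen in both demos keeps doubling the frame count: each pass inserts one predicted frame between every adjacent pair, so N frames become 2N-1, matching the `output_frames={2*len(inputs)-1}` message in `demo_img.py`. Here is a toy sketch of just that recursion, with a placeholder `interpolate` standing in for the `model(in_0, in_1, embt, ...)['imgt_pred']` call (hypothetical, not part of the commit):

```python
def interpolate(a, b):
    return (a + b) / 2  # placeholder midpoint "frame"

def double_frames(frames, iters):
    for _ in range(iters):
        outputs = [frames[0]]
        for f0, f1 in zip(frames[:-1], frames[1:]):
            outputs += [interpolate(f0, f1), f1]  # keep f1 so pairs chain
        frames = outputs
    return frames

# 2 frames -> 3 -> 5 -> 9 after three passes
print(double_frames([0.0, 1.0], iters=3))
```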
utils.py CHANGED
@@ -227,3 +227,21 @@ def warp(img, flow):
     grid_ = (grid + flow_).permute(0, 2, 3, 1)
     output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)
     return output
+
+def check_dim_and_resize(tensor_list):
+    shape_list = []
+    for t in tensor_list:
+        shape_list.append(t.shape[2:])
+
+    if len(set(shape_list)) > 1:
+        desired_shape = shape_list[0]
+        print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')
+
+        resize_tensor_list = []
+        for t in tensor_list:
+            resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear'))
+
+        tensor_list = resize_tensor_list
+
+    return tensor_list
+
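A short usage sketch for the new `check_dim_and_resize` helper, assuming it is run from the repo root so `utils.py` is importable, and that the inputs are 4-D (N, C, H, W) tensors like those produced by `img2tensor`; the example shapes are made up:

```python
import torch
from utils import check_dim_and_resize  # helper added in this commit

img0_t = torch.rand(1, 3, 480, 640)
img1_t = torch.rand(1, 3, 472, 640)   # deliberately mismatched height

# Mismatched frames are resized to the first frame's spatial size.
frames = check_dim_and_resize([img0_t, img1_t])
print([tuple(t.shape[-2:]) for t in frames])  # [(480, 640), (480, 640)]
```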