Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,287 Bytes
135075d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
from basicsr.archs.gmflow.gmflow.gmflow import GMFlow
class FlowGenerator(nn.Module):
    """Optical-flow estimator wrapping a GMFlow model.

    Args:
        path (str, optional): Path to a pre-trained GMFlow checkpoint. The
            checkpoint must store the model state dict under the ``'model'``
            key. Default: None (no checkpoint is loaded).
        requires_grad (bool): If true, the GMFlow parameters are trainable
            and the model is put in train mode; otherwise the parameters are
            frozen and the model is put in eval mode. Default: False.
    """

    def __init__(self,
                 path=None,
                 requires_grad=False,):
        super().__init__()
        self.model = GMFlow()
        if path is not None:
            # map_location='cpu' keeps every tensor on the CPU regardless of
            # the device it was saved from (same as the identity lambda).
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # trusted checkpoints (consider weights_only=True on newer torch).
            checkpoint = torch.load(path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model'], strict=True)

        trainable = bool(requires_grad)
        # Module.train(False) is exactly what Module.eval() does.
        self.model.train(trainable)
        for param in self.parameters():
            param.requires_grad = trainable

    def forward(self, im1, im2,
                attn_splits_list=None,
                corr_radius_list=None,
                prop_radius_list=None):
        """Run GMFlow on the frame pair (im1, im2) and return its last flow.

        Args:
            im1 (Tensor): First frame, shape (n, c, h, w).
            im2 (Tensor): Second frame, same shape as ``im1``.
            attn_splits_list (list[int], optional): Forwarded to GMFlow.
                Default: None, meaning ``[2]``.
            corr_radius_list (list[int], optional): Forwarded to GMFlow.
                Default: None, meaning ``[-1]``.
            prop_radius_list (list[int], optional): Forwarded to GMFlow.
                Default: None, meaning ``[-1]``.

        Returns:
            Tensor: The last entry of GMFlow's ``'flow_preds'`` output
            (presumably shape (n, 2, h, w) -- TODO confirm against GMFlow).

        Raises:
            AssertionError: If ``im1`` and ``im2`` differ in shape.
            ValueError: If the inputs are not 4-D.
        """
        assert im1.shape == im2.shape, (
            f'im1 and im2 must have the same shape, got '
            f'{tuple(im1.shape)} and {tuple(im2.shape)}')
        if im1.dim() != 4:
            raise ValueError(
                f'expected 4-D (n, c, h, w) inputs, got shape '
                f'{tuple(im1.shape)}')

        # Fresh lists on every call so a callee that mutates them cannot
        # corrupt the defaults of later calls.
        if attn_splits_list is None:
            attn_splits_list = [2]
        if corr_radius_list is None:
            corr_radius_list = [-1]
        if prop_radius_list is None:
            prop_radius_list = [-1]

        # Map values from [-1, 1] to [0, 255]; presumably the pixel range
        # GMFlow expects -- TODO confirm.
        im1 = (im1 + 1) / 2 * 255
        im2 = (im2 + 1) / 2 * 255
        results = self.model(im1, im2,
                             attn_splits_list=attn_splits_list,
                             corr_radius_list=corr_radius_list,
                             prop_radius_list=prop_radius_list,
                             pred_bidir_flow=False)
        # Keep only the last (presumably most refined) prediction.
        return results['flow_preds'][-1]
if __name__ == '__main__':
    # Smoke test: run the flow estimator on a pair of random frames.
    h, w = 512, 512
    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # The constructor parameter is `path`; the checkpoint path is resolved
    # relative to the current working directory.
    model = FlowGenerator(
        path='../../weights/GMFlow/gmflow_sintel-0c07dcb3.pth').to(device)
    model.eval()
    print(model)
    x = torch.randn((1, 3, h, w), device=device)
    y = torch.randn((1, 3, h, w), device=device)
    with torch.no_grad():
        out = model(x, y)
    print(out.shape)
|