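# Provenance (residue from the page this file was copied from):
# author: Vincentqyw
# commit: 8b973ee "fix: roma"
# page labels: raw / history blame / No virus / 23.4 kB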
import torch
import torch.nn.functional as F
try:
import cupy
except:
print("Cupy not found, local correlation will not work")
import re
from ..dkm import ConvRefiner
class Stream:
    """CUDA stream handle passed as ``stream=`` to the CuPy kernel launches.

    CuPy's ``Function.__call__`` reads ``stream.ptr`` from any non-``None``
    stream argument, so the raw stream pointer is exposed as ``ptr``;
    ``stream`` keeps the same value for existing users. Both are captured once,
    when the module is imported, from torch's current CUDA stream. Without
    CUDA, ``stream`` is ``None`` and ``ptr`` is 0 (the default stream).
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A torch.device is not equal to the plain string "cuda", so compare the
    # device *type* string instead.
    if device.type == "cuda":
        stream = torch.cuda.current_stream(device=device).cuda_stream
    else:
        stream = None
    ptr = stream if stream is not None else 0
# CUDA source for the "rearrange" kernel, compiled through cupy_kernel /
# cupy_launch. SIZE_k(name) placeholders are substituted with concrete tensor
# sizes before compilation, so the string text itself is runtime behavior.
#
# Each thread copies one element of ``input``, read at
# ((sample * SIZE_1 + channel) * SIZE_2 * SIZE_3) + spatial index, i.e. as a
# dense (N, C, H, W) array. It writes the element into ``output`` at the same
# pixel shifted by +4 in both spatial directions, with the channel as the
# innermost index. The launch sites allocate ``output`` as zeros of shape
# (N, H + 8, W + 8, C), so the result is a channels-last copy with a 4-pixel
# zero border.
# Launch mapping: blockIdx.z = sample, blockIdx.y = channel, and the x
# dimension covers the H * W pixels (threads past ``n`` return early).
# NOTE(review): the __syncthreads() comes after an early `return` for
# out-of-range threads and has no shared memory to protect; it appears to be
# unnecessary. Confirm before relying on it.
kernel_Correlation_rearrange = """
extern "C" __global__ void kernel_Correlation_rearrange(
const int n,
const float* input,
float* output
) {
int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
if (intIndex >= n) {
return;
}
int intSample = blockIdx.z;
int intChannel = blockIdx.y;
float dblValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
__syncthreads();
int intPaddedY = (intIndex / SIZE_3(input)) + 4;
int intPaddedX = (intIndex % SIZE_3(input)) + 4;
int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = dblValue;
}
"""
# CUDA source for the forward correlation kernel. SIZE_k(name) placeholders
# are substituted by cupy_kernel before compilation.
#
# Launch layout (see the callers): one block per output pixel
# (blockIdx.x = x, blockIdx.y = y, blockIdx.z = sample) and 32 threads that
# stride over the channels. ``rbot0``/``rbot1`` are the padded channels-last
# buffers produced by kernel_Correlation_rearrange, hence the +4 on x1/y1.
#
# The channel vector of rbot0 at the pixel is first staged in the dynamically
# sized shared buffer ``patch_data``; the launch sizes it as C * 4 bytes.
# Then, for every output channel tc in [0, SIZE_1(top)), the displacement is
# (dx, dy) = (tc % 9 - 4, tc / 9 - 4), so 81 channels cover a 9x9 window of
# offsets in [-4, 4]. Each thread accumulates its share of the dot product
# with rbot1 at the displaced pixel into ``sum[threadIdx.x]``. Thread 0 then
# adds the 32 partial sums and writes total / C into ``top``, indexed as
# (sample, channel, y, x).
# The j/i loops execute exactly once (a 1x1 patch).
# NOTE(review): after thread 0 reads sum[] there is no barrier before the
# next top_channel iteration resets sum[ch_off]. Correctness therefore relies
# on the 32 threads behaving as one lock-step warp. Confirm on GPUs with
# independent thread scheduling (Volta and later).
kernel_Correlation_updateOutput = """
extern "C" __global__ void kernel_Correlation_updateOutput(
const int n,
const float* rbot0,
const float* rbot1,
float* top
) {
extern __shared__ char patch_data_char[];
float *patch_data = (float *)patch_data_char;
// First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
int x1 = blockIdx.x + 4;
int y1 = blockIdx.y + 4;
int item = blockIdx.z;
int ch_off = threadIdx.x;
// Load 3D patch into shared shared memory
for (int j = 0; j < 1; j++) { // HEIGHT
for (int i = 0; i < 1; i++) { // WIDTH
int ji_off = (j + i) * SIZE_3(rbot0);
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
int idxPatchData = ji_off + ch;
patch_data[idxPatchData] = rbot0[idx1];
}
}
}
__syncthreads();
__shared__ float sum[32];
// Compute correlation
for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
sum[ch_off] = 0;
int s2o = top_channel % 9 - 4;
int s2p = top_channel / 9 - 4;
for (int j = 0; j < 1; j++) { // HEIGHT
for (int i = 0; i < 1; i++) { // WIDTH
int ji_off = (j + i) * SIZE_3(rbot0);
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
int x2 = x1 + s2o;
int y2 = y1 + s2p;
int idxPatchData = ji_off + ch;
int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
}
}
}
__syncthreads();
if (ch_off == 0) {
float total_sum = 0;
for (int idx = 0; idx < 32; idx++) {
total_sum += sum[idx];
}
const int sumelems = SIZE_3(rbot0);
const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
}
}
}
"""
# CUDA source for the gradient of the correlation w.r.t. its first input.
# SIZE_k(name) placeholders are substituted by cupy_kernel before compilation.
#
# Launched once per sample (``intSample``). A grid-stride loop covers the
# n = C * H * W elements of ``gradFirst``. Each element index decomposes into
# channel n, column l - 4 and row m - 4 (l and m are padded coordinates).
# For every displacement (o, p) in [-4, 4]^2 it accumulates
#   gradOutput[sample, (p + 4) * 9 + (o + 4), y, x] * rbot1[sample, m + p, l + o, n]
# where rbot1 is the padded channels-last copy of the second input. The sum is
# then divided by the channel count, matching the forward normalisation.
#
# The ROUND_OFF arithmetic is vestigial: with no stride division,
# xmin == xmax == l - 4 and ymin == ymax == m - 4, so each (o, p) contributes
# exactly one gradOutput term, taken at the element's own pixel.
# The body's `int n` (channel) shadows the kernel parameter `n`; the loop
# condition still uses the parameter.
# ``gradSecond`` is in the signature but never used here (the launch sites in
# this file pass None for it).
kernel_Correlation_updateGradFirst = """
#define ROUND_OFF 50000
extern "C" __global__ void kernel_Correlation_updateGradFirst(
const int n,
const int intSample,
const float* rbot0,
const float* rbot1,
const float* gradOutput,
float* gradFirst,
float* gradSecond
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
int n = intIndex % SIZE_1(gradFirst); // channels
int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
// round_off is a trick to enable integer division with ceil, even for negative numbers
// We use a large offset, for the inner part not to become negative.
const int round_off = ROUND_OFF;
const int round_off_s1 = round_off;
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
// Same here:
int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
float sum = 0;
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
xmin = max(0,xmin);
xmax = min(SIZE_3(gradOutput)-1,xmax);
ymin = max(0,ymin);
ymax = min(SIZE_2(gradOutput)-1,ymax);
for (int p = -4; p <= 4; p++) {
for (int o = -4; o <= 4; o++) {
// Get rbot1 data:
int s2o = o;
int s2p = p;
int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
// Index offset for gradOutput in following loops:
int op = (p+4) * 9 + (o+4); // index[o,p]
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
for (int y = ymin; y <= ymax; y++) {
for (int x = xmin; x <= xmax; x++) {
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
sum += gradOutput[idxgradOutput] * bot1tmp;
}
}
}
}
}
const int sumelems = SIZE_1(gradFirst);
const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
} }
"""
# CUDA source for the gradient of the correlation w.r.t. its second input.
# It uses the same launch layout, grid-stride loop and element decomposition
# (channel n, padded column l, padded row m) as the first-input kernel.
#
# For each displacement (o, p) in [-4, 4]^2, the gradOutput pixel that paired
# this element with rbot0 is (row m - 4 - p, column l - 4 - o). Only when that
# pixel lies inside gradOutput does the kernel accumulate
#   gradOutput[sample, (p + 4) * 9 + (o + 4), m - 4 - p, l - 4 - o]
#     * rbot0[sample, m - p, l - o, n]
# (rbot0 is the padded channels-last copy of the first input). The result is
# divided by the channel count.
# As in the first-input kernel, the ROUND_OFF terms cancel (no stride), so
# xmin == xmax and ymin == ymax. ``gradFirst`` is unused here.
# NOTE(review): the inline comment on ``bot0tmp`` says rbot1[l+s2o,m+s2p,n],
# but the value actually read is rbot0 at (l - s2o, m - s2p).
kernel_Correlation_updateGradSecond = """
#define ROUND_OFF 50000
extern "C" __global__ void kernel_Correlation_updateGradSecond(
const int n,
const int intSample,
const float* rbot0,
const float* rbot1,
const float* gradOutput,
float* gradFirst,
float* gradSecond
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
int n = intIndex % SIZE_1(gradSecond); // channels
int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
// round_off is a trick to enable integer division with ceil, even for negative numbers
// We use a large offset, for the inner part not to become negative.
const int round_off = ROUND_OFF;
const int round_off_s1 = round_off;
float sum = 0;
for (int p = -4; p <= 4; p++) {
for (int o = -4; o <= 4; o++) {
int s2o = o;
int s2p = p;
//Get X,Y ranges and clamp
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
// Same here:
int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
xmin = max(0,xmin);
xmax = min(SIZE_3(gradOutput)-1,xmax);
ymin = max(0,ymin);
ymax = min(SIZE_2(gradOutput)-1,ymax);
// Get rbot0 data:
int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
// Index offset for gradOutput in following loops:
int op = (p+4) * 9 + (o+4); // index[o,p]
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
for (int y = ymin; y <= ymax; y++) {
for (int x = xmin; x <= xmax; x++) {
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
sum += gradOutput[idxgradOutput] * bot0tmp;
}
}
}
}
}
const int sumelems = SIZE_1(gradSecond);
const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
} }
"""
def cupy_kernel(strFunction, objectVariables):
strKernel = globals()[strFunction]
while True:
objectMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)
if objectMatch is None:
break
intArg = int(objectMatch.group(2))
strTensor = objectMatch.group(4)
intSizes = objectVariables[strTensor].size()
strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
while True:
objectMatch = re.search(r"(VALUE_)([0-4])(\()([^\)]+)(\))", strKernel)
if objectMatch is None:
break
intArgs = int(objectMatch.group(2))
strArgs = objectMatch.group(4).split(",")
strTensor = strArgs[0]
intStrides = objectVariables[strTensor].stride()
strIndex = [
"(("
+ strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip()
+ ")*"
+ str(intStrides[intArg])
+ ")"
for intArg in range(intArgs)
]
strKernel = strKernel.replace(
objectMatch.group(0), strTensor + "[" + str.join("+", strIndex) + "]"
)
return strKernel
try:
    @cupy.memoize(for_each_device=True)
    def cupy_launch(strFunction, strKernel):
        """Compile ``strKernel`` with CuPy and return its ``strFunction`` kernel.

        ``cupy.memoize(for_each_device=True)`` caches the result per argument
        tuple and per current device, so each expanded source is compiled once
        per GPU.
        """
        return cupy.RawModule(code=strKernel).get_function(strFunction)
except (NameError, AttributeError):
    # ``cupy`` is undefined when its import at the top of the module failed
    # (NameError); a cupy build without ``memoize`` raises AttributeError.
    # Keep the module importable, but make any kernel launch fail with a clear
    # message instead of leaving ``cupy_launch`` undefined.
    def cupy_launch(strFunction, strKernel):
        """Stand-in used when CuPy is unavailable; always raises RuntimeError."""
        raise RuntimeError(
            "cupy is required to launch "
            + repr(strFunction)
            + " but it could not be imported; local correlation is unavailable"
        )
class _FunctionCorrelation(torch.autograd.Function):
    """Autograd function computing a 9x9 local correlation with CuPy kernels.

    ``forward(first, second)`` expects two float32 CUDA tensors of the same
    shape, indexed as (N, C, H, W), and returns an (N, 81, H, W) tensor.
    Output channel ``k`` holds, per pixel, the channel-wise dot product of
    ``first`` at that pixel with ``second`` at the pixel shifted by
    (dx, dy) = (k % 9 - 4, k // 9 - 4), divided by C. Shifts that leave the
    image read zeros, because both inputs are first copied into zero-padded
    channels-last buffers (``rbot0``/``rbot1``) with a 4-pixel border.

    CPU tensors raise ``NotImplementedError``. Because the kernels read raw
    ``float*`` buffers sized from ``first``, other dtypes raise ``TypeError``
    and mismatched shapes or devices raise ``ValueError``.
    """

    @staticmethod
    def forward(self, first, second):
        if not first.is_cuda:
            raise NotImplementedError()
        if first.dtype != torch.float32 or second.dtype != torch.float32:
            raise TypeError(
                "correlation kernels require float32 tensors, got "
                f"{first.dtype} and {second.dtype}"
            )
        if second.device != first.device or first.shape != second.shape:
            raise ValueError(
                "first and second must be on the same device with the same "
                f"shape, got {tuple(first.shape)} on {first.device} and "
                f"{tuple(second.shape)} on {second.device}"
            )

        batch, channels, height, width = first.shape
        padded_shape = [batch, height + 8, width + 8, channels]
        rbot0 = first.new_zeros(padded_shape)
        rbot1 = first.new_zeros(padded_shape)
        # The padded buffers are filled in place below and reused by backward.
        self.save_for_backward(first, second, rbot0, rbot1)

        # The kernels index the inputs as dense (N, C, H, W) arrays.
        first = first.contiguous()
        second = second.contiguous()
        output = first.new_zeros([batch, 81, height, width])

        # Copy both inputs into their padded buffers: 16 pixels per block,
        # one grid row per channel, one grid slice per sample.
        n = height * width
        for source, padded in ((first, rbot0), (second, rbot1)):
            cupy_launch(
                "kernel_Correlation_rearrange",
                cupy_kernel(
                    "kernel_Correlation_rearrange",
                    {"input": source, "output": padded},
                ),
            )(
                grid=tuple([(n + 16 - 1) // 16, channels, batch]),
                block=tuple([16, 1, 1]),
                args=[n, source.data_ptr(), padded.data_ptr()],
                stream=Stream,
            )

        # One 32-thread block per output pixel; the dynamic shared memory
        # holds the C float32 values of rbot0 at that pixel.
        n = output.size(1) * output.size(2) * output.size(3)
        cupy_launch(
            "kernel_Correlation_updateOutput",
            cupy_kernel(
                "kernel_Correlation_updateOutput",
                {"rbot0": rbot0, "rbot1": rbot1, "top": output},
            ),
        )(
            grid=tuple([width, height, batch]),
            block=tuple([32, 1, 1]),
            shared_mem=channels * 4,
            args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()],
            stream=Stream,
        )
        return output

    @staticmethod
    def backward(self, gradOutput):
        first, second, rbot0, rbot1 = self.saved_tensors
        gradOutput = gradOutput.contiguous()
        gradFirst = first.new_zeros(first.shape) if self.needs_input_grad[0] else None
        gradSecond = first.new_zeros(first.shape) if self.needs_input_grad[1] else None
        if not first.is_cuda:
            raise NotImplementedError()

        # Both gradient kernels run once per sample and cover the C * H * W
        # elements of that sample with 512-thread blocks.
        n = first.size(1) * first.size(2) * first.size(3)
        grid = tuple([(n + 512 - 1) // 512, 1, 1])
        block = tuple([512, 1, 1])

        if gradFirst is not None:
            # The expanded source does not depend on the sample index, so it
            # is generated and fetched once rather than per sample.
            kernel = cupy_launch(
                "kernel_Correlation_updateGradFirst",
                cupy_kernel(
                    "kernel_Correlation_updateGradFirst",
                    {
                        "rbot0": rbot0,
                        "rbot1": rbot1,
                        "gradOutput": gradOutput,
                        "gradFirst": gradFirst,
                        "gradSecond": None,
                    },
                ),
            )
            for intSample in range(first.size(0)):
                kernel(
                    grid=grid,
                    block=block,
                    args=[
                        n,
                        intSample,
                        rbot0.data_ptr(),
                        rbot1.data_ptr(),
                        gradOutput.data_ptr(),
                        gradFirst.data_ptr(),
                        None,
                    ],
                    stream=Stream,
                )

        if gradSecond is not None:
            kernel = cupy_launch(
                "kernel_Correlation_updateGradSecond",
                cupy_kernel(
                    "kernel_Correlation_updateGradSecond",
                    {
                        "rbot0": rbot0,
                        "rbot1": rbot1,
                        "gradOutput": gradOutput,
                        "gradFirst": None,
                        "gradSecond": gradSecond,
                    },
                ),
            )
            for intSample in range(first.size(0)):
                kernel(
                    grid=grid,
                    block=block,
                    args=[
                        n,
                        intSample,
                        rbot0.data_ptr(),
                        rbot1.data_ptr(),
                        gradOutput.data_ptr(),
                        None,
                        gradSecond.data_ptr(),
                    ],
                    stream=Stream,
                )

        return gradFirst, gradSecond
class _FunctionCorrelationTranspose(torch.autograd.Function):
    """Autograd function for the transpose (adjoint) of the local correlation.

    ``forward(input, second)`` takes ``input`` shaped like a correlation
    output, (N, 81, H, W), and ``second`` indexed as (N, C, H, W). It runs the
    kernel that ``_FunctionCorrelation`` uses for the gradient w.r.t. its first
    argument, with ``input`` playing the incoming gradient, and returns an
    (N, C, H, W) tensor.

    Only float32 CUDA tensors are supported. CPU tensors raise
    ``NotImplementedError``, other dtypes raise ``TypeError``, and mismatched
    shapes or devices raise ``ValueError``.
    """

    @staticmethod
    def forward(self, input, second):
        if not second.is_cuda:
            raise NotImplementedError()
        if input.dtype != torch.float32 or second.dtype != torch.float32:
            raise TypeError(
                "correlation kernels require float32 tensors, got "
                f"{input.dtype} and {second.dtype}"
            )
        batch, channels, height, width = second.shape
        if input.device != second.device or tuple(input.shape) != (batch, 81, height, width):
            raise ValueError(
                f"input must be shaped {(batch, 81, height, width)} on "
                f"{second.device}, got {tuple(input.shape)} on {input.device}"
            )

        padded_shape = [batch, height + 8, width + 8, channels]
        rbot0 = second.new_zeros(padded_shape)
        rbot1 = second.new_zeros(padded_shape)
        # rbot0 stays zero in forward (the kernel only reads rbot1's values);
        # backward reuses it for the padded incoming gradient.
        self.save_for_backward(input, second, rbot0, rbot1)
        input = input.contiguous()
        second = second.contiguous()
        output = second.new_zeros([batch, channels, height, width])

        # Padded channels-last copy of ``second``.
        n = height * width
        cupy_launch(
            "kernel_Correlation_rearrange",
            cupy_kernel(
                "kernel_Correlation_rearrange", {"input": second, "output": rbot1}
            ),
        )(
            grid=tuple([(n + 16 - 1) // 16, channels, batch]),
            block=tuple([16, 1, 1]),
            args=[n, second.data_ptr(), rbot1.data_ptr()],
            stream=Stream,
        )

        # The expanded source does not depend on the sample index, so it is
        # generated and fetched once rather than per sample.
        n = channels * height * width
        kernel = cupy_launch(
            "kernel_Correlation_updateGradFirst",
            cupy_kernel(
                "kernel_Correlation_updateGradFirst",
                {
                    "rbot0": rbot0,
                    "rbot1": rbot1,
                    "gradOutput": input,
                    "gradFirst": output,
                    "gradSecond": None,
                },
            ),
        )
        grid = tuple([(n + 512 - 1) // 512, 1, 1])
        for intSample in range(batch):
            kernel(
                grid=grid,
                block=tuple([512, 1, 1]),
                args=[
                    n,
                    intSample,
                    rbot0.data_ptr(),
                    rbot1.data_ptr(),
                    input.data_ptr(),
                    output.data_ptr(),
                    None,
                ],
                stream=Stream,
            )
        return output

    @staticmethod
    def backward(self, gradOutput):
        input, second, rbot0, rbot1 = self.saved_tensors
        # ``input`` was saved before forward made it contiguous, but the
        # gradSecond kernel reads it through a raw pointer as a dense array.
        input = input.contiguous()
        gradOutput = gradOutput.contiguous()
        gradInput = input.new_zeros(input.shape) if self.needs_input_grad[0] else None
        gradSecond = second.new_zeros(second.shape) if self.needs_input_grad[1] else None
        if not second.is_cuda:
            raise NotImplementedError()
        if gradInput is None and gradSecond is None:
            return gradInput, gradSecond

        # Write gradOutput into the saved padded buffer rbot0 (channels-last,
        # 4-pixel zero border); rbot1 still holds the padded ``second``.
        n = second.size(2) * second.size(3)
        cupy_launch(
            "kernel_Correlation_rearrange",
            cupy_kernel(
                "kernel_Correlation_rearrange",
                {"input": gradOutput, "output": rbot0},
            ),
        )(
            grid=tuple([(n + 16 - 1) // 16, gradOutput.size(1), gradOutput.size(0)]),
            block=tuple([16, 1, 1]),
            args=[n, gradOutput.data_ptr(), rbot0.data_ptr()],
            stream=Stream,
        )

        if gradInput is not None:
            # gradInput is the forward correlation of gradOutput with second.
            n = gradInput.size(1) * gradInput.size(2) * gradInput.size(3)
            cupy_launch(
                "kernel_Correlation_updateOutput",
                cupy_kernel(
                    "kernel_Correlation_updateOutput",
                    {"rbot0": rbot0, "rbot1": rbot1, "top": gradInput},
                ),
            )(
                grid=tuple([gradInput.size(3), gradInput.size(2), gradInput.size(0)]),
                block=tuple([32, 1, 1]),
                shared_mem=gradOutput.size(1) * 4,
                args=[n, rbot0.data_ptr(), rbot1.data_ptr(), gradInput.data_ptr()],
                stream=Stream,
            )

        if gradSecond is not None:
            n = second.size(1) * second.size(2) * second.size(3)
            kernel = cupy_launch(
                "kernel_Correlation_updateGradSecond",
                cupy_kernel(
                    "kernel_Correlation_updateGradSecond",
                    {
                        "rbot0": rbot0,
                        "rbot1": rbot1,
                        "gradOutput": input,
                        "gradFirst": None,
                        "gradSecond": gradSecond,
                    },
                ),
            )
            grid = tuple([(n + 512 - 1) // 512, 1, 1])
            for intSample in range(second.size(0)):
                kernel(
                    grid=grid,
                    block=tuple([512, 1, 1]),
                    args=[
                        n,
                        intSample,
                        rbot0.data_ptr(),
                        rbot1.data_ptr(),
                        input.data_ptr(),
                        None,
                        gradSecond.data_ptr(),
                    ],
                    stream=Stream,
                )

        return gradInput, gradSecond
def FunctionCorrelation(reference_features, query_features):
    """Return the 81-channel local correlation of two feature maps.

    Thin functional wrapper over ``_FunctionCorrelation.apply``. The
    computation runs only on CUDA tensors; CPU tensors raise
    ``NotImplementedError``.
    """
    correlate = _FunctionCorrelation.apply
    return correlate(reference_features, query_features)
class ModuleCorrelation(torch.nn.Module):
    """Parameter-free ``nn.Module`` wrapper around the local correlation."""

    def __init__(self):
        super().__init__()

    def forward(self, tensorFirst, tensorSecond):
        """Correlate ``tensorFirst`` with ``tensorSecond`` (CUDA tensors only)."""
        correlate = _FunctionCorrelation.apply
        return correlate(tensorFirst, tensorSecond)
def FunctionCorrelationTranspose(reference_features, query_features):
    """Apply the transpose (adjoint) of the local correlation.

    Thin functional wrapper over ``_FunctionCorrelationTranspose.apply``. The
    computation runs only on CUDA tensors; CPU tensors raise
    ``NotImplementedError``.
    """
    transpose = _FunctionCorrelationTranspose.apply
    return transpose(reference_features, query_features)
class ModuleCorrelationTranspose(torch.nn.Module):
    """Parameter-free ``nn.Module`` wrapper around the transposed correlation."""

    def __init__(self):
        super().__init__()

    def forward(self, tensorFirst, tensorSecond):
        """Apply the transposed correlation to the two tensors (CUDA only)."""
        transpose = _FunctionCorrelationTranspose.apply
        return transpose(tensorFirst, tensorSecond)
class LocalCorr(ConvRefiner):
    """ConvRefiner head fed with a local correlation volume.

    The conv stack inherited from ``ConvRefiner`` (``block1`` ->
    ``hidden_blocks`` -> ``out_conv``) is applied to the 81-channel local
    correlation between ``x`` and ``y`` resampled at ``flow``.
    """

    def forward(self, x, y, flow):
        """Compute certainty and a relative refining displacement.

        Args:
            x: reference feature map. It is correlated with the warped ``y``,
                so presumably (B, C, H, W) float32 on CUDA; TODO confirm
                against callers.
            y: feature map resampled onto ``x``'s grid via ``flow``.
            flow: sampling locations stored channels-first. They are permuted
                to channels-last and used as the ``grid`` of
                ``F.grid_sample`` (normalized coordinates,
                ``align_corners=False``), so presumably (B, 2, H, W).

        Returns:
            ``(certainty, displacement)``: all but the last two channels of the
            head's output, and its last two channels.
        """
        with torch.no_grad():
            # Neither the warp nor the correlation is recorded by autograd, so
            # no gradient reaches x, y or flow through this path.
            sampling_grid = flow.permute(0, 2, 3, 1)
            warped_y = F.grid_sample(y, sampling_grid, align_corners=False)
            volume = FunctionCorrelation(x, warped_y)
        head = self.out_conv(self.hidden_blocks(self.block1(volume)))
        return head[:, :-2], head[:, -2:]
if __name__ == "__main__":
    # Smoke test: run LocalCorr on random features under an identity flow.
    # The correlation kernels are CUDA-only, so on a CPU-only machine this
    # raises NotImplementedError.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(2, 128, 32, 32).to(device)
    y = torch.randn(2, 128, 32, 32).to(device)
    batch, _, height, width = x.shape
    # Identity sampling grid: pixel centres in normalized [-1, 1] coordinates
    # (the align_corners=False convention used by LocalCorr.forward), stored
    # channels-first as (B, 2, H, W) with x in channel 0 and y in channel 1.
    xs = torch.linspace(-1 + 1 / width, 1 - 1 / width, width, device=device)
    ys = torch.linspace(-1 + 1 / height, 1 - 1 / height, height, device=device)
    grid_x = xs.view(1, width).expand(height, width)
    grid_y = ys.view(height, 1).expand(height, width)
    flow = torch.stack((grid_x, grid_y), dim=0).unsqueeze(0).repeat(batch, 1, 1, 1)
    # The conv head must live on the same device as its inputs.
    local_corr = LocalCorr(in_dim=81, hidden_dim=81 * 4).to(device)
    z = local_corr(x, y, flow)
    print("hej")