import torch


def decode_infer(output, stride, gt_per_grid=None, numclass=None):
    """Decode a raw detection-head feature map into flat per-slot predictions.

    Each spatial cell of ``output`` holds ``gt_per_grid`` predictions of
    ``5 + numclass`` channels, laid out as
    ``[x1y1 (2), x2y2 (2), conf (1), class logits (numclass)]``.
    The two box pairs are log-distances from the cell centre, so the decoded
    corners are::

        x1y1 = (cell_xy + 0.5 - exp(x1y1)) * stride
        x2y2 = (cell_xy + 0.5 + exp(x2y2)) * stride

    The confidence and class logits are passed through a sigmoid.

    Args:
        output: 4-D tensor of shape
            ``(batch, gt_per_grid * (5 + numclass), height, width)``.
        stride: Scale from grid cells to output box units (presumably
            input-image pixels per cell -- confirm against the caller).
        gt_per_grid: Predictions per grid cell. May be omitted when
            ``numclass`` is given; it is then derived from the channel count.
        numclass: Number of classes. May be omitted when ``gt_per_grid`` is
            given; it is then derived from the channel count.

    Returns:
        Tensor of shape ``(batch, height * width * gt_per_grid, 5 + numclass)``.
        Its last dimension is ``[x1, y1, x2, y2, conf, class probs...]``.
        Rows are ordered by grid row (y), then column (x), then per-cell slot.

    Raises:
        ValueError: If neither ``gt_per_grid`` nor ``numclass`` is given, or
            if they are inconsistent with the channel count of ``output``.
    """
    bz, channels, height, width = output.shape

    if gt_per_grid is None and numclass is None:
        raise ValueError("decode_infer() needs gt_per_grid and/or numclass")
    if gt_per_grid is not None and gt_per_grid <= 0:
        raise ValueError(f"gt_per_grid must be positive, got {gt_per_grid}")
    if numclass is None:
        numclass = channels // gt_per_grid - 5
    elif gt_per_grid is None:
        gt_per_grid = channels // (5 + numclass)
    if numclass < 0 or gt_per_grid <= 0 or gt_per_grid * (5 + numclass) != channels:
        raise ValueError(
            f"channel count {channels} does not match "
            f"gt_per_grid={gt_per_grid} * (5 + numclass={numclass})"
        )

    # (B, C, H, W) -> (B, H, W, gt_per_grid, 5 + numclass): one row per slot.
    pred = output.permute(0, 2, 3, 1).reshape(
        bz, height, width, gt_per_grid, 5 + numclass
    )
    x1y1, x2y2, conf, prob = torch.split(pred, [2, 2, 1, numclass], dim=-1)

    # Cell coordinates on the input's device and dtype, so decoding works on
    # CPU or any GPU and the final cat never mixes dtypes.
    ys = torch.arange(height, dtype=torch.float32, device=output.device).to(output.dtype)
    xs = torch.arange(width, dtype=torch.float32, device=output.device).to(output.dtype)
    shift_x = xs.view(1, width).expand(height, width)
    shift_y = ys.view(height, 1).expand(height, width)
    # Shape (1, H, W, 1, 2) broadcasts over batch and per-cell slots; the last
    # dim is (x, y) to line up with the x1y1 / x2y2 pairs.
    xy_grid = torch.stack([shift_x, shift_y], dim=-1).view(1, height, width, 1, 2)

    x1y1 = (xy_grid + 0.5 - torch.exp(x1y1)) * stride
    x2y2 = (xy_grid + 0.5 + torch.exp(x2y2)) * stride

    decoded = torch.cat(
        (x1y1, x2y2, torch.sigmoid(conf), torch.sigmoid(prob)), dim=-1
    )
    return decoded.reshape(bz, -1, 5 + numclass)