| import torch | |
def fasterrcnn_reshape_transform(x):
    """Merge a dict of feature maps into one tensor at a common resolution.

    Every tensor in ``x`` is replaced by its absolute value, resized with
    bilinear interpolation to the spatial size (last two dimensions) of
    ``x['pool']``, and the results are concatenated along dimension 1 in
    the mapping's iteration order.

    Args:
        x: Mapping of tensors that must contain the key ``'pool'``.
            Bilinear interpolation requires 4-D inputs; presumably these
            are (batch, channels, height, width) feature maps from a
            Faster R-CNN backbone -- confirm against the caller.

    Returns:
        A single tensor holding all resized maps stacked along dim 1.

    Raises:
        KeyError: If ``x`` has no ``'pool'`` entry.
    """
    # The 'pool' entry defines the resolution every map is resized to.
    target_size = x['pool'].size()[-2:]
    # Only the values are needed; the keys play no role in the result.
    resized = [
        torch.nn.functional.interpolate(
            torch.abs(feature),
            target_size,
            mode='bilinear')
        for feature in x.values()
    ]
    return torch.cat(resized, dim=1)
def swinT_reshape_transform(tensor, height=7, width=7):
    """Fold a flat token axis into a 2-D grid and put channels first.

    The input is reshaped to ``(tensor.size(0), height, width,
    tensor.size(2))`` and its last axis is then moved to position 1,
    giving the CNN-style layout ``(size(0), size(2), height, width)``.

    Args:
        tensor: Tensor with at least three dimensions whose element count
            matches the reshape above; presumably (batch, tokens,
            channels) with ``tokens == height * width`` -- confirm
            against the caller.
        height: Number of rows in the token grid.
        width: Number of columns in the token grid.

    Returns:
        The reshaped tensor with the former last axis at position 1.
    """
    leading = tensor.size(0)
    trailing = tensor.size(2)
    grid = tensor.reshape(leading, height, width, trailing)
    # Move the last axis to the front of the spatial axes, as in CNNs.
    return grid.permute(0, 3, 1, 2)
def vit_reshape_transform(tensor, height=14, width=14):
    """Drop the first token, fold the rest into a grid, channels first.

    Index 0 along dimension 1 is discarded (presumably the class token --
    confirm against the model). The remaining entries are reshaped to
    ``(tensor.size(0), height, width, tensor.size(2))`` and the last axis
    is moved to position 1, giving ``(size(0), size(2), height, width)``.

    Args:
        tensor: Tensor with at least three dimensions; after removing the
            first entry of dimension 1 its element count must match the
            reshape above.
        height: Number of rows in the token grid.
        width: Number of columns in the token grid.

    Returns:
        The reshaped tensor with the former last axis at position 1.
    """
    # Skip the leading entry on the token axis before building the grid.
    patch_tokens = tensor[:, 1:, :]
    grid = patch_tokens.reshape(
        tensor.size(0), height, width, tensor.size(2))
    # Move the last axis to the front of the spatial axes, as in CNNs.
    return grid.permute(0, 3, 1, 2)