File size: 632 Bytes
c9baa67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from collections import OrderedDict

import torch
import torch.nn as nn


class OnesLayer(nn.Module):
    """Produce an all-ones, single-channel tensor shaped like the input.

    The output keeps the input's dim 0 and any trailing dims, but dim 1 is
    always forced to 1. Optionally, the leading dims after dim 1 (dims
    ``2 .. 2 + len(size) - 1``) are replaced by a fixed ``size``. The input's
    values are never read; only its shape and device are used.

    Args:
        size: Optional sequence of ints that overwrites dims
            ``2 .. 2 + len(size) - 1`` of the output shape (e.g. ``(H, W)``
            for 4-D input). Dims beyond those are copied from the input.
            ``None`` keeps all of the input's dims after dim 1.
        dtype: dtype of the returned tensor. Defaults to ``torch.float32``.
    """

    def __init__(self, size=None, dtype=torch.float32):
        super().__init__()
        self.size = size
        self.dtype = dtype

    def extra_repr(self):
        # Shown by nn.Module.__repr__, so printed models reveal the config.
        return f'size={self.size}, dtype={self.dtype}'

    def forward(self, tensor):
        """Return a ones tensor on ``tensor.device`` with dim 1 set to 1.

        Args:
            tensor: Input with at least 2 dims; only its shape/device are used.

        Returns:
            ``torch.ones`` of the computed shape, with ``self.dtype``.

        Raises:
            ValueError: if ``tensor`` has fewer than 2 dims, or ``size`` has
                more entries than ``tensor`` has dims after dim 1.
        """
        if tensor.dim() < 2:
            raise ValueError(
                'OnesLayer expects an input with at least 2 dims, '
                f'got shape {tuple(tensor.shape)}'
            )

        shape = list(tensor.shape)
        shape[1] = 1  # return only one channel

        if self.size is not None:
            size = tuple(self.size)
            trailing = len(shape) - 2
            if len(size) > trailing:
                raise ValueError(
                    f'OnesLayer size {size} has {len(size)} entries but the '
                    f'input of shape {tuple(tensor.shape)} has only '
                    f'{trailing} dims after dim 1'
                )
            # Overwrite only the leading dims after dim 1; any remaining
            # dims are kept from the input (matches the original 4-D logic).
            shape[2:2 + len(size)] = size

        return torch.ones(shape, dtype=self.dtype, device=tensor.device)


class UninformativeFeatures(torch.nn.Sequential):
    """Sequential "feature extractor" whose output carries no input information.

    It contains exactly one submodule, registered under the name ``ones``:
    a ``OnesLayer`` with ``size=(1, 1)``. For a 4-D input this yields an
    all-ones tensor with a single channel and 1x1 spatial extent on the
    input's device. NOTE(review): presumably intended as a baseline /
    ablation feature branch — confirm against callers.
    """

    def __init__(self):
        modules = OrderedDict()
        modules['ones'] = OnesLayer(size=(1, 1))
        super().__init__(modules)