File size: 1,047 Bytes
749745d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import numbers

import torch
from torch import nn


class MixedOperationRandom(nn.Module):
    """Container of candidate operations that can run all of them or a chosen one.

    Args:
        search_ops: iterable of ``nn.Module`` candidates. They are registered in
            an ``nn.ModuleList`` so their parameters belong to this module.
    """

    def __init__(self, search_ops):
        super().__init__()
        self.ops = nn.ModuleList(search_ops)
        # Measure the ModuleList rather than the argument so any iterable
        # (including a generator) accepted by nn.ModuleList also works here.
        self.num_ops = len(self.ops)

    def forward(self, x, x_path=None):
        """Apply the candidate operation(s) to ``x``.

        Args:
            x: input tensor. Dim 0 is treated as the batch dimension when
                ``x_path`` is a per-sample tensor.
            x_path: selects which operation(s) to run:

                * ``None``: run every op on ``x`` and return the element-wise
                  mean of their outputs.
                * a real number (int, float, NumPy scalar) or a 0-dim tensor:
                  run only ``self.ops[int(x_path)]`` on the whole of ``x``.
                * a tensor holding one index per sample
                  (``x_path.size(0) == x.size(0)``): run
                  ``self.ops[x_path[i]]`` on sample ``i`` and concatenate the
                  results along dim 0.

        Returns:
            The output tensor of the selected operation(s).

        Raises:
            TypeError: if ``x_path`` is not ``None``, a real number, or a tensor.
            ValueError: if an op index is outside ``[0, num_ops)`` or a
                per-sample ``x_path`` does not provide one index per sample.
        """
        if x_path is None:
            return sum(op(x) for op in self.ops) / self.num_ops

        if isinstance(x_path, torch.Tensor):
            if x_path.dim() == 0:
                # A 0-dim tensor is a single index, not a per-sample path.
                return self.ops[self._check_index(x_path.item())](x)
            return self._forward_per_sample(x, x_path)

        if isinstance(x_path, numbers.Real):
            return self.ops[self._check_index(x_path)](x)

        raise TypeError(
            "x_path must be None, a real number, or a torch.Tensor, "
            f"got {type(x_path).__name__}"
        )

    def _check_index(self, index):
        """Validate a single op index and return it as an ``int``."""
        # Check the range before truncating, so that e.g. -0.5 is rejected
        # instead of silently becoming op 0. NaN also fails this comparison.
        if not 0 <= index < self.num_ops:
            raise ValueError(
                f"op index {index!r} is out of range [0, {self.num_ops})"
            )
        return int(index)

    def _forward_per_sample(self, x, x_path):
        """Route sample ``i`` of ``x`` through ``self.ops[x_path[i]]``."""
        batch_size = x.size(0)
        if x_path.size(0) != batch_size or x_path.numel() != batch_size:
            raise ValueError(
                "x_path must hold exactly one op index per sample: expected "
                f"{batch_size} indices, got shape {tuple(x_path.shape)}"
            )
        # One host transfer for all indices instead of a .item() sync per sample.
        indices = [self._check_index(v) for v in x_path.reshape(-1).tolist()]
        return torch.cat(
            [self.ops[idx](x.narrow(0, i, 1)) for i, idx in enumerate(indices)],
            dim=0,
        )