File size: 1,181 Bytes
0dfe33d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from . import auditory_stream, visual_stream
import chainer


class ResNet18(chainer.Chain):
    """Fuse an auditory and a visual ResNet-18 stream into a 5-way score.

    The audio input is passed through ``aud``. The visual frames are passed
    through ``vis`` in chunks of at most ``VISUAL_CHUNK_SIZE`` frames, and the
    per-frame outputs are averaged over all frames. The two feature vectors are
    concatenated, projected by ``fc`` to 5 values, and each value is squashed
    into the open interval (0, 1).
    """

    # Maximum number of visual frames passed to ``vis`` in a single call.
    # NOTE(review): presumably chosen to bound peak memory usage; confirm.
    VISUAL_CHUNK_SIZE = 256

    def __init__(self):
        super(ResNet18, self).__init__(
            aud=auditory_stream.ResNet18(),
            vis=visual_stream.ResNet18(),
            # in_size 512 must equal the concatenated width of the audio and
            # visual features produced in __call__.
            fc=chainer.links.Linear(512, 5, initialW=chainer.initializers.HeNormal()),
        )

    def _summed_visual_features(self, frames):
        """Run ``vis`` on ``frames`` and return their sum over axis 0, with a new leading axis of size 1."""
        features = self.vis(chainer.Variable(chainer.cuda.to_cpu(frames)))
        return chainer.functions.expand_dims(chainer.functions.sum(features, 0), 0)

    def __call__(self, x):
        """Score one audio/visual sample.

        Args:
            x: Indexable pair ``(audio, frames)``. ``audio`` is fed to ``aud``
                unchanged. ``frames`` is indexed along axis 0, one frame per
                row. Both are copied to host memory before being wrapped in a
                ``chainer.Variable``.
                NOTE(review): exact array shapes depend on the stream modules;
                confirm against ``auditory_stream`` / ``visual_stream``.
                ``aud`` must return a leading axis of size 1 so its output can
                be concatenated with the averaged visual feature.

        Returns:
            Host (CPU) array holding the first row of the fused output: 5
            values, each in (0, 1).

        Raises:
            ValueError: If ``frames`` contains no frames (the average over
                frames would otherwise divide by zero).
        """
        audio, frames = x[0], x[1]
        num_frames = frames.shape[0]
        if num_frames == 0:
            raise ValueError("ResNet18 requires at least one visual frame")
        step = self.VISUAL_CHUNK_SIZE

        audio_feature = self.aud(chainer.Variable(chainer.cuda.to_cpu(audio)))

        # Accumulate per-chunk sums and divide once at the end; the result is
        # the mean of the visual features over all frames.
        visual_feature = self._summed_visual_features(frames[:step])
        for start in range(step, num_frames, step):
            visual_feature += self._summed_visual_features(frames[start:start + step])
        visual_feature /= num_frames

        fused = self.fc(chainer.functions.concat([audio_feature, visual_feature]))
        # tanh lies in (-1, 1); shifting by 1 and halving maps it into (0, 1).
        scores = (chainer.functions.tanh(fused) + 1) / 2
        return chainer.cuda.to_cpu(scores.data[0])