from . import auditory_stream, visual_stream
import chainer


class ResNet18(chainer.Chain):
    """Two-stream (auditory + visual) network with a shared linear head.

    The auditory stream consumes ``x[0]``. The visual stream consumes the
    rows of ``x[1]`` in chunks, and its outputs are averaged over all rows.
    The two features are concatenated and mapped by ``fc`` to 5 outputs,
    which are squashed into [0, 1].
    """

    # Number of visual rows fed to ``self.vis`` per call. This bounds the
    # batch size seen by the visual stream when ``x[1]`` holds many rows.
    vis_chunk_size = 256

    def __init__(self):
        super(ResNet18, self).__init__()
        # Register child links inside init_scope(). Passing them as keyword
        # arguments to Chain.__init__ goes through the deprecated add_link().
        # The attribute names are unchanged, so saved parameter keys still match.
        with self.init_scope():
            self.aud = auditory_stream.ResNet18()
            self.vis = visual_stream.ResNet18()
            # Assumes the auditory and visual feature widths sum to 512.
            # TODO confirm against the stream definitions.
            self.fc = chainer.links.Linear(
                512, 5, initialW=chainer.initializers.HeNormal()
            )

    def __call__(self, x):
        """Run both streams and return the 5 head outputs for one sample.

        Args:
            x: Indexable pair ``(audio, frames)``.
                ``audio`` is copied to host memory and passed to the
                auditory stream in one call. Its output must have a single
                row, because it is concatenated along axis 1 with the
                ``(1, D)`` pooled visual feature.
                ``frames`` is indexed along axis 0. It is copied to host
                memory and passed to the visual stream ``vis_chunk_size``
                rows at a time.

        Returns:
            Host array holding row 0 of ``(tanh(fc(concat)) + 1) / 2``.
            All values lie in [0, 1].

        Raises:
            ValueError: if ``frames`` has no rows (the mean is undefined).
        """
        audio, frames = x[0], x[1]
        n_frames = frames.shape[0]
        if n_frames == 0:
            raise ValueError("visual input x[1] must contain at least one frame")

        # NOTE(review): inputs are forced to host memory, so this forward
        # pass presumably expects the model to live on the CPU. Confirm
        # before calling to_gpu() on the model.
        h_aud = self.aud(chainer.Variable(chainer.cuda.to_cpu(audio)))

        # Sum the visual features chunk by chunk, then divide by the row
        # count. The result is the per-row mean regardless of chunking.
        h_vis = None
        for start in range(0, n_frames, self.vis_chunk_size):
            chunk = chainer.cuda.to_cpu(frames[start:start + self.vis_chunk_size])
            partial = chainer.functions.sum(self.vis(chainer.Variable(chunk)), 0)
            h_vis = partial if h_vis is None else h_vis + partial
        h_vis = chainer.functions.expand_dims(h_vis, 0) / n_frames

        y = self.fc(chainer.functions.concat([h_aud, h_vis]))
        # (tanh(z) + 1) / 2 maps the unbounded head output into [0, 1].
        y = (chainer.functions.tanh(y) + 1) / 2
        return chainer.cuda.to_cpu(y.data[0])