# Spaces: Runtime error
import chainer

from . import auditory_stream, visual_stream
class ResNet18(chainer.Chain):
    """Fuse an auditory and a visual ResNet-18 into a 5-output regressor.

    The auditory stream processes ``x[0]``. The visual stream is run over
    ``x[1]`` in chunks along axis 0; its per-row outputs are summed and
    divided by the number of rows (i.e. averaged). The two feature vectors
    are concatenated, mapped by a 512 -> 5 linear layer, and squashed into
    the open interval (0, 1) with ``(tanh(.) + 1) / 2``.
    """

    # Class-level default so instances created without going through
    # __init__ (e.g. unpickled older objects) still have a chunk size.
    visual_chunk_size = 256

    def __init__(self, visual_chunk_size=256):
        """Build both streams and the fusion layer.

        Args:
            visual_chunk_size: Number of rows of ``x[1]`` passed to the
                visual stream per call. Defaults to 256, the value that was
                previously hard-coded.
                NOTE(review): if the visual stream normalizes with batch
                statistics (e.g. BatchNorm in train mode), outputs depend on
                this value.

        Raises:
            ValueError: If ``visual_chunk_size`` is smaller than 1 (a zero or
                negative step would make the chunk loop invalid).
        """
        # Validate before Chain.__init__ so a bad value fails fast without
        # building the (potentially large) sub-networks.
        if visual_chunk_size < 1:
            raise ValueError(
                "visual_chunk_size must be a positive integer, got %r"
                % (visual_chunk_size,)
            )
        super(ResNet18, self).__init__(
            aud=auditory_stream.ResNet18(),
            vis=visual_stream.ResNet18(),
            # Input width 512 must equal the concatenated width of the
            # auditory and visual features produced in __call__.
            fc=chainer.links.Linear(
                512, 5, initialW=chainer.initializers.HeNormal()
            ),
        )
        # Plain attribute (not a link/parameter), so it is not serialized.
        self.visual_chunk_size = visual_chunk_size

    def _visual_mean(self, frames):
        """Return the visual stream's output averaged over axis 0 of ``frames``.

        ``frames`` is fed to ``self.vis`` in slices of ``visual_chunk_size``
        rows. Each slice's outputs are summed over axis 0 and given a leading
        axis of size 1. The partial sums are accumulated in slice order, and
        the total is divided by the row count. The caller must ensure
        ``frames`` has at least one row.
        """
        n_rows = frames.shape[0]
        step = self.visual_chunk_size
        total = None
        for start in range(0, n_rows, step):
            chunk = chainer.Variable(
                chainer.cuda.to_cpu(frames[start:start + step])
            )
            partial = chainer.functions.expand_dims(
                chainer.functions.sum(self.vis(chunk), 0), 0
            )
            total = partial if total is None else total + partial
        return total / n_rows

    def __call__(self, x):
        """Predict five scores in (0, 1) from an (audio, visual) input pair.

        Args:
            x: Indexable pair. ``x[0]`` is the auditory stream's input and
                ``x[1]`` is an array whose axis 0 is iterated by the visual
                stream. Both are moved to host memory via
                ``chainer.cuda.to_cpu`` before use.
                NOTE(review): exact shapes/dtypes depend on the
                auditory_stream and visual_stream modules; confirm there.

        Returns:
            The first row (``[0]``) of the 5-wide output, moved to host
            memory, with every value in (0, 1).

        Raises:
            ValueError: If ``x[1]`` has no rows, because the visual average
                would be undefined (the original code divided by zero here).
        """
        audio, frames = x[0], x[1]
        if frames.shape[0] == 0:
            raise ValueError("x[1] must contain at least one row")

        audio_feat = self.aud(chainer.Variable(chainer.cuda.to_cpu(audio)))
        # Shape (1, D_v): leading axis of size 1 added in _visual_mean.
        visual_feat = self._visual_mean(frames)

        # concat joins along axis 1 by default. The auditory output therefore
        # needs a leading (batch) axis of size 1 to match visual_feat.
        fused = chainer.functions.concat([audio_feat, visual_feat])
        scores = (chainer.functions.tanh(self.fc(fused)) + 1) / 2
        return chainer.cuda.to_cpu(scores.data[0])