Spaces:
Runtime error
Runtime error
File size: 1,181 Bytes
0dfe33d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from . import auditory_stream, visual_stream
import chainer
class ResNet18(chainer.Chain):
    """Two-stream model: auditory and visual ResNet18s joined by a linear head.

    The outputs of the two streams are concatenated and passed through a
    512 -> 5 fully connected layer; the result is squashed with
    ``(tanh(.) + 1) / 2`` so every output value lies in ``[0, 1]``.
    """

    def __init__(self):
        """Create both stream networks and the final linear layer."""
        # Built in the same order as they are registered (aud, vis, fc).
        audio_net = auditory_stream.ResNet18()
        visual_net = visual_stream.ResNet18()
        head = chainer.links.Linear(
            512, 5, initialW=chainer.initializers.HeNormal()
        )
        super(ResNet18, self).__init__(aud=audio_net, vis=visual_net, fc=head)

    def __call__(self, x):
        """Run both streams on ``x`` and return the squashed 5-way output.

        Args:
            x: Indexable pair. ``x[0]`` is fed to the auditory stream as a
                whole; ``x[1]`` is sliced along its first axis into chunks
                of 256 that are fed to the visual stream one at a time.
                (Presumably ``x[1]`` is a stack of video frames -- confirm
                against the caller.)

        Returns:
            The first row of the output array, moved to host memory with
            ``chainer.cuda.to_cpu``.
        """
        F = chainer.functions
        chunk = 256

        def summed_visual(part):
            # Visual-stream output for one chunk, summed over axis 0 and
            # given back a leading axis of length 1.
            out = self.vis(chainer.Variable(chainer.cuda.to_cpu(part)))
            return F.expand_dims(F.sum(out, 0), 0)

        aud_out = self.aud(chainer.Variable(chainer.cuda.to_cpu(x[0])))

        # The first chunk is always processed, even when x[1] holds fewer
        # than 256 entries; the loop then accumulates the remaining chunks.
        vis_out = summed_visual(x[1][:chunk])
        for start in range(chunk, x[1].shape[0], chunk):
            vis_out += summed_visual(x[1][start : start + chunk])

        # Sum over all entries divided by their count: the mean visual output.
        vis_out /= x[1].shape[0]

        fused = F.concat([aud_out, vis_out])
        # tanh yields [-1, 1]; shifting by 1 and halving maps that to [0, 1].
        scores = (F.tanh(self.fc(fused)) + 1) / 2
        return chainer.cuda.to_cpu(scores.data[0])
|