traits-prediction / src /audiovisual_stream.py
jvcanavarro's picture
Add source code
0dfe33d
from . import auditory_stream, visual_stream
import chainer
class ResNet18(chainer.Chain):
    """Audiovisual model fusing an auditory and a visual ResNet-18 stream.

    The auditory-stream output and the frame-averaged visual-stream output
    are concatenated (along axis 1) and fed to a 512 -> 5 linear layer whose
    output is mapped into [0, 1] with ``(tanh(.) + 1) / 2``.
    """

    def __init__(self):
        super(ResNet18, self).__init__(
            aud=auditory_stream.ResNet18(),
            vis=visual_stream.ResNet18(),
            fc=chainer.links.Linear(512, 5, initialW=chainer.initializers.HeNormal()),
        )

    def _visual_mean(self, frames, chunk_size):
        """Return the visual-stream output averaged over all frames.

        Frames are pushed through ``self.vis`` in slices of ``chunk_size``
        along axis 0; per-slice outputs are summed over axis 0, kept with a
        leading axis of length 1, accumulated, and divided by the frame count.

        NOTE(review): assumes ``self.vis`` returns (batch, D) features, making
        the result (1, D) — confirm against visual_stream.
        """
        n_frames = frames.shape[0]
        total = None
        for start in range(0, n_frames, chunk_size):
            chunk = chainer.Variable(
                chainer.cuda.to_cpu(frames[start:start + chunk_size])
            )
            partial = chainer.functions.expand_dims(
                chainer.functions.sum(self.vis(chunk), 0), 0
            )
            # Accumulate in chunk order so floating-point summation order
            # matches a straightforward left-to-right sum over chunks.
            total = partial if total is None else total + partial
        return total / n_frames

    def __call__(self, x, chunk_size=256):
        """Predict five scores in [0, 1] for one audiovisual sample.

        Args:
            x: a pair ``(audio, frames)``. ``audio`` goes straight into the
                auditory stream. ``frames`` is sliced along its first axis,
                which is treated as the frame count.
                NOTE(review): the five outputs are presumably the trait
                scores — confirm with the training code.
            chunk_size: number of frames fed to the visual stream per forward
                pass (bounds peak memory). Must be a positive integer.

        Returns:
            The first row of the sigmoid-like ``(tanh + 1) / 2`` output as a
            host (CPU) array of length 5.

        Raises:
            ValueError: if ``chunk_size`` < 1 or ``frames`` has no frames.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        frames = x[1]
        if frames.shape[0] == 0:
            # Averaging over zero frames would divide by zero.
            raise ValueError("x[1] must contain at least one frame")

        # NOTE(review): inputs are converted to host arrays before entering
        # the streams, so this forward pass presumably expects the model's
        # parameters on the CPU — confirm.
        audio = self.aud(chainer.Variable(chainer.cuda.to_cpu(x[0])))
        visual = self._visual_mean(frames, chunk_size)

        # The visual feature has batch size 1, so concat requires the audio
        # stream's output to have batch size 1 as well.
        fused = chainer.functions.concat([audio, visual])
        scores = (chainer.functions.tanh(self.fc(fused)) + 1) / 2
        # NOTE(review): the result is returned as a plain array (no gradient
        # path), yet the graph over every visual chunk stays alive until
        # return; wrapping the call in chainer.no_backprop_mode() (Chainer
        # v2+) would bound memory — confirm the Chainer version first.
        return chainer.cuda.to_cpu(scores.data[0])