import torch.nn as nn
# for routing arguments into the functions of the reversible layer
def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        # `routes` is a (route_to_f, route_to_g) boolean pair for each layer
        for ind, ((f_args, g_args), routes) in enumerate(
            zip(routed_args, router[key])
        ):
            new_f_args, new_g_args = map(
                lambda route: ({key: val} if route else {}), routes
            )
            routed_args[ind] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args
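
# Worked example (illustrative): with depth 2 and a `mask` kwarg routed to
# each layer's f (e.g. attention) but not its g (e.g. feedforward):
#   route_args({'mask': [(True, False)] * 2}, {'mask': m}, 2)
#   -> [({'mask': m}, {}), ({'mask': m}, {})]
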
class SequentialSequence(nn.Module):
    def __init__(self, layers, args_route={}, layer_dropout=0.0):
        super().__init__()
        assert all(
            len(route) == len(layers) for route in args_route.values()
        ), "each argument route map must have the same depth as the number of sequential layers"
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout  # stored but not applied in this forward pass

    def forward(self, x, **kwargs):
        # distribute the incoming kwargs to the (f, g) pair of each layer
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        for (f, g), (f_args, g_args) in layers_and_args:
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x
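

# Minimal usage sketch with hypothetical stand-in blocks (in the surrounding
# repo, each pair would be an attention / feedforward module). Via args_route,
# `mask` reaches each f but never each g.
if __name__ == "__main__":
    import torch

    class Block(nn.Module):
        """Toy residual branch that accepts an optional mask kwarg."""

        def __init__(self, dim):
            super().__init__()
            self.net = nn.Linear(dim, dim)

        def forward(self, x, mask=None):
            return self.net(x)

    depth, dim = 2, 64
    layers = nn.ModuleList(
        [nn.ModuleList([Block(dim), nn.Linear(dim, dim)]) for _ in range(depth)]
    )
    net = SequentialSequence(layers, args_route={"mask": [(True, False)] * depth})
    out = net(torch.randn(1, 16, dim), mask=torch.ones(1, 16, dtype=torch.bool))
    assert out.shape == (1, 16, dim)  # residual additions preserve the shape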