File size: 434 Bytes
d6ec83b
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from lambda_networks import LambdaLayer
from lambda_networks import RLambdaLayer

# Both layers are built at import time with the same hyper-parameters so their
# outputs can be compared side by side; RLambdaLayer additionally takes
# `recurrence`. NOTE(review): the meaning of r / dim_k / dim_u / recurrence is
# defined by lambda_networks — confirm there before changing these values.
layer = LambdaLayer(dim=8, dim_out=8, r=23, dim_k=16, heads=4, dim_u=4)
rlayer = RLambdaLayer(
    dim=8, dim_out=8, r=23, dim_k=16, heads=4, dim_u=4, recurrence=3
)


def main() -> None:
    """Smoke-test a forward and backward pass through both lambda layers.

    Feeds one random input through ``layer`` and ``rlayer``, prints the two
    output shapes, then back-propagates through both outputs and checks that
    a finite gradient reached the input.

    Raises:
        AssertionError: if no gradient, or a non-finite gradient, reaches
            the input tensor.
    """
    # Assumes channels-first (batch, channels, height, width) input whose
    # channel count matches dim=8 — TODO confirm against lambda_networks.
    x = torch.randn(1, 8, 64, 64, requires_grad=True)
    y = layer(x)
    z = rlayer(x)
    print(y.shape, z.shape)
    # Summing both losses lets a single backward() cover the gradient path of
    # each layer, not just rlayer's.
    (y.sum() + z.sum()).backward()
    # requires_grad=True on x is only meaningful if a gradient actually lands.
    assert x.grad is not None, "no gradient reached the input"
    assert torch.isfinite(x.grad).all(), "non-finite gradient at the input"


if __name__ == "__main__":
    main()