Abdullah-Nazhat commited on
Commit
e59cbdd
·
verified ·
1 Parent(s): 2cf79e2

Update interactor.py

Browse files
Files changed (1) hide show
  1. interactor.py +13 -6
interactor.py CHANGED
@@ -9,15 +9,22 @@ class MemoryUnit(nn.Module):
9
  def __init__(self,dim):
10
  super().__init__()
11
 
12
- self.mem = nn.Linear(dim,dim)
13
- self.norm = nn.LayerNorm(dim)
14
-
 
 
15
 
16
  def forward(self, x):
17
 
18
- x = self.norm(x)
19
- x = self.mem(x)
20
- x = self.norm(x)
 
 
 
 
 
21
 
22
  return x
23
 
 
9
  def __init__(self,dim):
10
  super().__init__()
11
 
12
+
13
+ self.norm_token = nn.LayerNorm(dim)
14
+ self.proj_1 = nn.Linear(dim,dim)
15
+ self.proj_2 = nn.Linear(dim,dim)
16
+ self.proj_3 = nn.Linear(dim,dim)
17
 
18
  def forward(self, x):
19
 
20
+ x = self.norm_token(x)
21
+ u, v = x, x
22
+ u = self.proj_1(u)
23
+ u = self.norm_token(u)
24
+ v = self.proj_2(v)
25
+ g = u * v
26
+ x = self.proj_3(g)
27
+ x = self.norm_token(x)
28
 
29
  return x
30