Update src/loss.py
Browse files- src/loss.py +1 -1
src/loss.py
CHANGED
@@ -5,7 +5,7 @@ import output
|
|
5 |
class LossSchedulerModel(torch.nn.Module):
|
6 |
def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
|
7 |
def forward(A,t,xT,e_prev):
|
8 |
-
B=e_prev;assert t-len(B)==0;C=xT*A.wx[t]
|
9 |
for(D,E)in zip(B,A.we[t]):C+=D*E
|
10 |
return C.to(xT.dtype)
|
11 |
class LossScheduler:
|
|
|
5 |
class LossSchedulerModel(torch.nn.Module):
|
6 |
def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
|
7 |
def forward(A,t,xT,e_prev):
|
8 |
+
B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
|
9 |
for(D,E)in zip(B,A.we[t]):C+=D*E
|
10 |
return C.to(xT.dtype)
|
11 |
class LossScheduler:
|