manbeast3b committed · verified
Commit ca25f61 · 1 Parent(s): ab6455c

Update src/loss.py

Files changed (1): src/loss.py (+1 / -1)
src/loss.py CHANGED
@@ -5,7 +5,7 @@ import output
 class LossSchedulerModel(torch.nn.Module):
     def __init__(A,wx,we):super(LossSchedulerModel,A).__init__();assert len(wx.shape)==1 and len(we.shape)==2;B=wx.shape[0];assert B==we.shape[0]and B==we.shape[1];A.register_parameter('wx',torch.nn.Parameter(wx));A.register_parameter('we',torch.nn.Parameter(we))
     def forward(A,t,xT,e_prev):
-        B=e_prev;assert t-len(B)==0;C=xT*A.wx[t]
+        B=e_prev;assert t-len(B)+1==0;C=xT*A.wx[t]
         for(D,E)in zip(B,A.we[t]):C+=D*E
         return C.to(xT.dtype)
 class LossScheduler:
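
For context, a de-minified sketch of the module this one-character change touches, written with the post-commit assertion: the old code required len(e_prev) == t, the new code requires len(e_prev) == t + 1. The class and parameter names (LossSchedulerModel, wx, we) come from the diff; the expanded variable names and the usage snippet at the end are illustrative assumptions, not repository code.

import torch

class LossSchedulerModel(torch.nn.Module):
    """Readable equivalent of the minified module in src/loss.py (sketch)."""

    def __init__(self, wx, we):
        super().__init__()
        # wx is a 1-D weight vector of length n, we an n x n weight matrix.
        assert wx.ndim == 1 and we.ndim == 2
        n = wx.shape[0]
        assert we.shape == (n, n)
        self.register_parameter('wx', torch.nn.Parameter(wx))
        self.register_parameter('we', torch.nn.Parameter(we))

    def forward(self, t, xT, e_prev):
        # After this commit the model expects t + 1 previous error terms;
        # the pre-commit assertion required exactly t of them.
        assert len(e_prev) == t + 1
        out = xT * self.wx[t]
        # Weighted sum of the previous error terms using row t of `we`
        # (zip truncates to the length of e_prev).
        for e_i, w_i in zip(e_prev, self.we[t]):
            out = out + e_i * w_i
        return out.to(xT.dtype)

# Illustrative usage with assumed shapes: at step t = 2 the model now
# expects t + 1 = 3 previous error tensors.
model = LossSchedulerModel(torch.ones(4), torch.eye(4))
xT = torch.randn(8)
e_prev = [torch.randn(8) for _ in range(3)]
out = model(2, xT, e_prev)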