razmars commited on
Commit
a623e9c
·
verified ·
1 Parent(s): 25f53ff

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +18 -2
modeling_super_linear.py CHANGED
@@ -210,8 +210,8 @@ class RLinear(nn.Module):
210
  new_W = new_W * final_scaling
211
 
212
  self.zero_shot_Linear = new_W
213
- else:
214
- W = self.Linear.weight.detach()
215
  W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
216
 
217
  # resize H → self.horizon and W → new_lookback
@@ -223,6 +223,22 @@ class RLinear(nn.Module):
223
  )[0, 0] # drop the two singleton dims
224
 
225
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
 
228
 
 
210
  new_W = new_W * final_scaling
211
 
212
  self.zero_shot_Linear = new_W
213
+ elif mode ==2:
214
+ W = self.Linear.weight.detach()
215
  W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
216
 
217
  # resize H → self.horizon and W → new_lookback
 
223
  )[0, 0] # drop the two singleton dims
224
 
225
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
226
+ else:
227
+ W = self.Linear.weight.detach()
228
+ W = W[:, -new_lookback:]
229
+ W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
230
+
231
+ # resize H → self.horizon and W → new_lookback
232
+ new_W = F.interpolate(
233
+ W4d,
234
+ size=(self.seq_len , new_lookback), # (H_out, W_out)
235
+ mode='bilinear',
236
+ align_corners=False
237
+ )[0, 0] # drop the two singleton dims
238
+
239
+ W_now = torch.cat((W, new_W), dim=1)
240
+ self.zero_shot_Linear = new_W
241
+
242
 
243
 
244