razmars commited on
Commit
e1188db
·
verified ·
1 Parent(s): 266c1c3

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +4 -18
modeling_super_linear.py CHANGED
@@ -200,7 +200,7 @@ class RLinear(nn.Module):
200
  self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
201
  self.zero_shot_Linear = None
202
 
203
- def transform_model(self,x,new_lookback,mode):
204
  if mode == 1:
205
  W = self.Linear.weight.detach()
206
  new_W = W[:, -new_lookback:]
@@ -210,7 +210,7 @@ class RLinear(nn.Module):
210
  new_W = new_W * final_scaling
211
 
212
  self.zero_shot_Linear = new_W
213
- elif mode ==2:
214
  W = self.Linear.weight.detach()
215
  W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
216
 
@@ -223,22 +223,8 @@ class RLinear(nn.Module):
223
  )[0, 0] # drop the two singleton dims
224
 
225
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
226
- else:
227
- W = self.Linear.weight.detach() # (out_features, seq_len)
228
-
229
- W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
230
-
231
- # resize H → self.horizon and W → new_lookback
232
- new_W = F.interpolate(
233
- W4d,
234
- size=(self.horizon, new_lookback), # (H_out, W_out)
235
- mode='bilinear',
236
- align_corners=False
237
- )[0, 0] # drop the two singleton dims
238
 
239
- x = F.linear(x, new_W)
240
- self.zero_shot_Linear = W
241
- return x
242
 
243
 
244
 
@@ -249,7 +235,7 @@ class RLinear(nn.Module):
249
  #if self.zero_shot_Linear is None:
250
  #print(F"new Lookkback : {x.shape[1]}")
251
 
252
- x = self.transform_model(x,x.shape[1],3)
253
 
254
  x = x.clone()
255
  #x = x * (x.shape[1]/512)
 
200
  self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
201
  self.zero_shot_Linear = None
202
 
203
+ def transform_model(self,new_lookback,mode):
204
  if mode == 1:
205
  W = self.Linear.weight.detach()
206
  new_W = W[:, -new_lookback:]
 
210
  new_W = new_W * final_scaling
211
 
212
  self.zero_shot_Linear = new_W
213
+ else:
214
  W = self.Linear.weight.detach()
215
  W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
216
 
 
223
  )[0, 0] # drop the two singleton dims
224
 
225
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
226
+
 
 
 
 
 
 
 
 
 
 
 
227
 
 
 
 
228
 
229
 
230
 
 
235
  #if self.zero_shot_Linear is None:
236
  #print(F"new Lookkback : {x.shape[1]}")
237
 
238
+ x = self.transform_model(x.shape[1],3)
239
 
240
  x = x.clone()
241
  #x = x * (x.shape[1]/512)