razmars committed · verified
Commit 8eeb171 · Parent: 0d2aee9

Update modeling_super_linear.py

Files changed (1):
  1. modeling_super_linear.py +33 -7
modeling_super_linear.py CHANGED
@@ -194,19 +194,45 @@ class NLinear(nn.Module):
 class RLinear(nn.Module):
     def __init__(self, input_len, output_len):
         super(RLinear, self).__init__()
-        self.Linear = nn.Linear(input_len, output_len)
-        self.revin_layer = RevIN(num_features=None, affine=False, norm_type=None, subtract_last=False)
+        self.Linear = nn.Linear(input_len, output_len)
+        self.seq_len = input_len
+        self.revin_layer = RevIN(num_features=None, affine=False, norm_type=None, subtract_last=False)
+        self.zero_shot_Linear = None
+
+    def transform_model(self, new_lookback):
+        W = self.Linear.weight.detach()
+        new_W = W[:, -new_lookback:]
+        original_norm = torch.norm(W, p=2)
+        new_norm = torch.norm(new_W, p=2)
+        final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
+        new_W = new_W * final_scaling
+
+        self.zero_shot_Linear = new_W
+
 
     def forward(self, x):
         # x: [Batch, Input length, Channel]
         x_shape = x.shape
         if len(x_shape) == 2:
             x = x.unsqueeze(-1)
-        x = x.clone()
-        x = self.revin_layer(x, 'norm')
-
-        x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1).clone()
-        x = self.revin_layer(x, 'denorm')
+
+        B, L, V = x.shape
+        if L < self.seq_len and self.zero_shot_Linear is None:
+            print(f"New lookback: {L}")
+            self.transform_model(L)
+
+        if L < self.seq_len:
+            x = x.clone()
+            x = self.revin_layer(x, 'norm')
+            x = F.linear(x, self.zero_shot_Linear).unsqueeze_(-1)
+            x = self.revin_layer(x, 'denorm')
+
+        else:
+            x = x.clone()
+            x = self.revin_layer(x, 'norm')
+            x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1).clone()
+            x = self.revin_layer(x, 'denorm')
+
         if len(x_shape) == 2:
             x = x.squeeze(-1)
         return x  # to [Batch, Output length, Channel]
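
The core of this commit is transform_model: when a batch arrives with a lookback L shorter than the trained seq_len, it slices off the last L columns of the trained weight matrix (the positions attending to the most recent timesteps) and rescales them so the truncated matrix keeps the original L2 (Frobenius) norm, then caches the result in zero_shot_Linear for all later short batches. Below is a minimal standalone sketch of that weight surgery under assumed sizes (trained lookback 512, horizon 96, zero-shot lookback 96); RevIN is omitted, so only the truncate-and-rescale step is illustrated, not the full forward path.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def truncate_and_rescale(linear: nn.Linear, new_lookback: int) -> torch.Tensor:
        # Weight maps lookback -> horizon: shape [output_len, input_len].
        W = linear.weight.detach()
        # Keep only the columns for the most recent `new_lookback` timesteps.
        new_W = W[:, -new_lookback:]
        # Rescale so the truncated matrix preserves the original Frobenius norm
        # (torch.norm with p=2 and no dim flattens the matrix first).
        new_norm = torch.norm(new_W, p=2)
        scale = torch.norm(W, p=2) / new_norm if new_norm.item() != 0 else 1.0
        return new_W * scale

    linear = nn.Linear(512, 96)                 # hypothetical: trained lookback 512, horizon 96
    W_short = truncate_and_rescale(linear, 96)  # zero-shot lookback 96
    x = torch.randn(8, 96)                      # [Batch, short lookback]
    y = F.linear(x, W_short)                    # [Batch, output_len]
    print(y.shape)                              # torch.Size([8, 96])

Rescaling by the norm ratio keeps the overall gain of the truncated map close to that of the trained one, which is presumably why the commit computes the matrix once on the first short batch and reuses the cached zero_shot_Linear on every call after that.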