razmars commited on
Commit
29025e3
·
verified ·
1 Parent(s): 82634ad

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +17 -17
modeling_super_linear.py CHANGED
@@ -379,9 +379,9 @@ class superLinear(nn.Module):
379
 
380
  self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
381
 
382
- path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
383
- dirs = os.listdir(path)
384
- checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in dirs]
385
 
386
  if self.freq_experts == "all":
387
  self.freq_experts = []
@@ -425,11 +425,11 @@ class superLinear(nn.Module):
425
 
426
  if configs.misc_moe>0:
427
  if configs.misc_moe == 1:
428
- #print("Creating misc expert")
429
  self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
430
  else:
431
  for i in range(configs.misc_moe):
432
- #print(f"Creating misc expert {i}")
433
  self.experts["misc_"+str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
434
 
435
 
@@ -437,18 +437,18 @@ class superLinear(nn.Module):
437
  self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
438
  self.dropout = nn.Dropout(configs.dropout)
439
 
440
- if configs.load_weights:
441
- print(f"Loading weights from {path}")
442
- path = configs.load_weights_path + "" + configs.load_weights_dir + "/" + "checkpoint.pth"
443
- if os.path.exists(path):
444
- checkpoint = torch.load(path)
445
- print(len(self.experts.keys()))
446
- print(self.experts.keys())
447
- print(self.state_dict().keys())
448
- print(checkpoint.keys())
449
- self.load_state_dict(checkpoint)
450
- else:
451
- print(f"Path {path} does not exist. Skipping loading weights.")
452
 
453
 
454
  def map_to_cycle(self, freq):
 
379
 
380
  self.layer_dict = {'DLinear': DLinear, 'Linear': Linear, 'NLinear': NLinear, 'RLinear': RLinear}
381
 
382
+ # path = configs.linear_checkpoints_path + configs.linear_checkpoints_dir
383
+ # dirs = os.listdir(path)
384
+ # checkpoints_paths = [path + "/" + d + "/" + "checkpoint.pth" for d in dirs]
385
 
386
  if self.freq_experts == "all":
387
  self.freq_experts = []
 
425
 
426
  if configs.misc_moe>0:
427
  if configs.misc_moe == 1:
428
+ print("Creating misc expert")
429
  self.experts["misc"] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
430
  else:
431
  for i in range(configs.misc_moe):
432
+ print(f"Creating misc expert {i}")
433
  self.experts["misc_"+str(i)] = self.layer_dict[self.layer_type](self.seq_len, self.pred_len)
434
 
435
 
 
437
  self.moe = SparseNoisyMoE(configs, experts=self.experts.values())
438
  self.dropout = nn.Dropout(configs.dropout)
439
 
440
+ # if configs.load_weights:
441
+ # print(f"Loading weights from {path}")
442
+ # path = configs.load_weights_path + "" + configs.load_weights_dir + "/" + "checkpoint.pth"
443
+ # if os.path.exists(path):
444
+ # checkpoint = torch.load(path)
445
+ # print(len(self.experts.keys()))
446
+ # print(self.experts.keys())
447
+ # print(self.state_dict().keys())
448
+ # print(checkpoint.keys())
449
+ # self.load_state_dict(checkpoint)
450
+ # else:
451
+ # print(f"Path {path} does not exist. Skipping loading weights.")
452
 
453
 
454
  def map_to_cycle(self, freq):