riiswa commited on
Commit
a364ed7
1 Parent(s): b8e09b3
Files changed (1) hide show
  1. interpretable.py +2 -0
interpretable.py CHANGED
@@ -33,6 +33,8 @@ class InterpretablePolicyExtractor:
33
  dataset["train_label"] = dataset["train_label"][:, None]
34
  if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
35
  dataset["test_label"] = dataset["test_label"][:, None]
 
 
36
  return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
37
 
38
  def forward(self, observation):
 
33
  dataset["train_label"] = dataset["train_label"][:, None]
34
  if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
35
  dataset["test_label"] = dataset["test_label"][:, None]
36
+ for k,v in dataset.items():
37
+ print(k, v.shape)
38
  return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
39
 
40
  def forward(self, observation):