riiswa commited on
Commit
3735566
1 Parent(s): a364ed7
Files changed (1) hide show
  1. interpretable.py +1 -1
interpretable.py CHANGED
@@ -34,7 +34,7 @@ class InterpretablePolicyExtractor:
34
  if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
35
  dataset["test_label"] = dataset["test_label"][:, None]
36
  for k,v in dataset.items():
37
- print(k, v.shape)
38
  return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
39
 
40
  def forward(self, observation):
 
34
  if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
35
  dataset["test_label"] = dataset["test_label"][:, None]
36
  for k,v in dataset.items():
37
+ print(k, v.shape, v.dtype)
38
  return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
39
 
40
  def forward(self, observation):