riiswa commited on
Commit
b471ab8
1 Parent(s): 3735566
Files changed (1) hide show
  1. interpretable.py +2 -0
interpretable.py CHANGED
@@ -33,6 +33,8 @@ class InterpretablePolicyExtractor:
33
  dataset["train_label"] = dataset["train_label"][:, None]
34
  if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
35
  dataset["test_label"] = dataset["test_label"][:, None]
 
 
36
  for k,v in dataset.items():
37
  print(k, v.shape, v.dtype)
38
  return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
 
33
  dataset["train_label"] = dataset["train_label"][:, None]
34
  if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
35
  dataset["test_label"] = dataset["test_label"][:, None]
36
+ dataset["train_input"] = dataset["train_input"].float()
37
+ dataset["test_input"] = dataset["test_input"].float()
38
  for k,v in dataset.items():
39
  print(k, v.shape, v.dtype)
40
  return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)