debugging
Browse files- interpretable.py +2 -0
interpretable.py
CHANGED
@@ -33,6 +33,8 @@ class InterpretablePolicyExtractor:
|
|
33 |
dataset["train_label"] = dataset["train_label"][:, None]
|
34 |
if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
|
35 |
dataset["test_label"] = dataset["test_label"][:, None]
|
|
|
|
|
36 |
for k,v in dataset.items():
|
37 |
print(k, v.shape, v.dtype)
|
38 |
return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
|
|
|
33 |
dataset["train_label"] = dataset["train_label"][:, None]
|
34 |
if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
|
35 |
dataset["test_label"] = dataset["test_label"][:, None]
|
36 |
+
dataset["train_input"] = dataset["train_input"].float()
|
37 |
+
dataset["test_input"] = dataset["test_input"].float()
|
38 |
for k,v in dataset.items():
|
39 |
print(k, v.shape, v.dtype)
|
40 |
return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
|