debugging
Browse files- interpretable.py +2 -0
interpretable.py
CHANGED
@@ -33,6 +33,8 @@ class InterpretablePolicyExtractor:
|
|
33 |
dataset["train_label"] = dataset["train_label"][:, None]
|
34 |
if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
|
35 |
dataset["test_label"] = dataset["test_label"][:, None]
|
|
|
|
|
36 |
return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
|
37 |
|
38 |
def forward(self, observation):
|
|
|
33 |
dataset["train_label"] = dataset["train_label"][:, None]
|
34 |
if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
|
35 |
dataset["test_label"] = dataset["test_label"][:, None]
|
36 |
+
for k,v in dataset.items():
|
37 |
+
print(k, v.shape)
|
38 |
return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
|
39 |
|
40 |
def forward(self, observation):
|