Spaces:
Runtime error
Runtime error
randommm
commited on
Commit
•
6e701f9
1
Parent(s):
3e23aec
update
Browse files
facility_location/env/obs_extractor.py
CHANGED
@@ -117,20 +117,21 @@ class ObsExtractor:
|
|
117 |
return obs
|
118 |
|
119 |
def _get_obs_graph(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
|
|
120 |
facility = self._flc.get_current_solution().astype(np.float32)
|
121 |
distance = self._flc.get_current_distance().astype(np.float32)
|
122 |
distance = distance / np.max(distance)
|
123 |
cost = self._flc.get_current_cost().astype(np.float32)
|
124 |
cost = cost / np.max(cost)
|
125 |
gain, loss = self._flc.get_gain_and_loss()
|
126 |
-
gain = gain / np.max(gain)
|
127 |
-
loss = loss / np.max(loss)
|
128 |
dynamic_node_features = np.stack([
|
129 |
facility,
|
130 |
distance[:,0],
|
131 |
distance[:,1],
|
132 |
cost[:,0],
|
133 |
-
cost[:,1],
|
134 |
gain,
|
135 |
loss,
|
136 |
], axis=-1)
|
|
|
117 |
return obs
|
118 |
|
119 |
def _get_obs_graph(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
120 |
+
EPS = 1e-8
|
121 |
facility = self._flc.get_current_solution().astype(np.float32)
|
122 |
distance = self._flc.get_current_distance().astype(np.float32)
|
123 |
distance = distance / np.max(distance)
|
124 |
cost = self._flc.get_current_cost().astype(np.float32)
|
125 |
cost = cost / np.max(cost)
|
126 |
gain, loss = self._flc.get_gain_and_loss()
|
127 |
+
gain = gain / (np.max(gain) + EPS)
|
128 |
+
loss = loss / (np.max(loss) + EPS)
|
129 |
dynamic_node_features = np.stack([
|
130 |
facility,
|
131 |
distance[:,0],
|
132 |
distance[:,1],
|
133 |
cost[:,0],
|
134 |
+
cost[:,1],
|
135 |
gain,
|
136 |
loss,
|
137 |
], axis=-1)
|