# MFLP / facility_location/env/obs_extractor.py
# Last change: "update" (commit a257639) by 苏泓源; file size 7.84 kB.
from typing import Dict, Tuple, Text
import numpy as np
from facility_location.env.facility_location_client import FacilityLocationClient
from facility_location.utils.config import Config
class ObsExtractor:
    """Build fixed-size, padded graph observations from a facility-location client.

    Index layout used by every array this class produces:

    * node index ``0`` is a virtual node that is linked to every instance node
      by the static edges built in ``_pad_edge``;
    * the ``n`` instance nodes occupy indices ``1..n`` (client indices are
      shifted by one);
    * remaining node rows up to ``node_range`` are zero padding, and padded
      edge rows point at index ``node_range - 1``.
    """

    def __init__(self, cfg: 'Config', flc: 'FacilityLocationClient', node_range: int, edge_range: int):
        """Store the collaborators and allocate the reusable buffers.

        Args:
            cfg: Project configuration; stored on the instance but not read here.
            flc: Client queried for the instance and the current solution state.
            node_range: Fixed number of node rows in every observation.
            edge_range: Fixed number of edge rows in every observation.
        """
        self.cfg = cfg
        self._flc = flc
        self._node_range = node_range
        self._edge_range = edge_range

        self._construct_virtual_node_feature()
        self._construct_node_features()
        self._construct_action_mask()

    def _construct_virtual_node_feature(self) -> None:
        """Create the constant feature vector assigned to the virtual node (row 0)."""
        # Dynamic slots, in the same order as the columns stacked in
        # ``_get_obs_graph``: facility, distance min, distance sub-min,
        # cost min, cost sub-min, gain, loss.  All are zero for the virtual node.
        self._virtual_dynamic_node_feature = np.zeros(7, dtype=np.float32)
        # Static slots, in the same order as ``_compute_static_obs``:
        # x, y, demand, average distance, average cost.
        self._virtual_static_node_feature = np.array(
            [0.5, 0.5, 1, 0, 0],
            dtype=np.float32,
        )
        self._virtual_node_feature = np.concatenate([
            self._virtual_dynamic_node_feature,
            self._virtual_static_node_feature,
        ], axis=-1)

    def _construct_node_features(self) -> None:
        """Allocate the ``(node_range, node_dim)`` float32 node-feature buffer."""
        self._node_features = np.zeros(
            (self._node_range, self._virtual_node_feature.size),
            dtype=np.float32,
        )

    def _construct_action_mask(self) -> None:
        """Allocate the two boolean action-mask buffers, initially all ``False``."""
        self._old_facility_mask = np.full(self._node_range, False)
        self._new_facility_mask = np.full(self._node_range, False)

    def get_node_dim(self) -> int:
        """Return the number of features per node (dynamic plus static)."""
        return self._virtual_node_feature.size

    def reset(self) -> None:
        """Re-read the instance from the client and clear all per-episode buffers."""
        self._compute_static_obs()
        self._reset_node_features()
        self._reset_action_mask()

    @staticmethod
    def _normalize_by_max(values: np.ndarray) -> np.ndarray:
        """Scale ``values`` by their maximum, leaving them unscaled when it is 0.

        Dividing by a zero maximum (for example when every gain is 0) would
        fill the observation with NaN/inf, so that case divides by 1 instead.
        """
        values = np.asarray(values)
        peak = np.max(values)
        return values / (peak if peak != 0 else 1)

    def _compute_static_obs(self) -> None:
        """Cache the static node features, node/edge masks and padded static edges."""
        xy, demands, n, _ = self._flc.get_instance()
        if n + 2 > self._node_range:
            # The virtual node, the n instance nodes and the padding index
            # ``node_range - 1`` need n + 2 slots.  This was only reported, not
            # raised, in the original code, so it stays non-fatal here.
            import warnings
            warnings.warn(
                f'Instance has {n} nodes but node_range is {self._node_range}; '
                'the padding index is not a spare node.',
                RuntimeWarning,
                stacklevel=2,
            )
        # ``_pad_edge`` below reads ``self._n``, so it must be set first.
        self._n = n

        avg_distance, avg_cost = self._flc.get_avg_distance_and_cost()
        avg_distance = self._normalize_by_max(avg_distance)
        avg_cost = self._normalize_by_max(avg_cost)

        self._static_node_features = np.stack([
            xy[:, 0],
            xy[:, 1],
            demands,
            avg_distance,
            avg_cost,
        ], axis=-1).astype(np.float32)

        static_adjacency_list = self._flc.get_static_adjacency_list()

        # One virtual node plus n instance nodes are valid.
        obs_node_mask = np.full(1 + n, True)
        self._obs_node_mask = self._pad_mask(obs_node_mask, self._node_range, 'nodes')

        # n virtual edges are prepended to the client's edges by ``_pad_edge``.
        obs_static_edge_mask = np.full(n + static_adjacency_list.shape[0], True)
        self._obs_static_edge_mask = self._pad_mask(obs_static_edge_mask, self._edge_range, 'edges')
        self._static_adjacency_list = self._pad_edge(static_adjacency_list)

    def _reset_node_features(self) -> None:
        """Zero the node buffer, then write the virtual row and the static columns."""
        self._node_features[:, :] = 0
        self._node_features[0] = self._virtual_node_feature
        static_start = len(self._virtual_dynamic_node_feature)
        self._node_features[1:self._n + 1, static_start:] = self._static_node_features

    def _reset_action_mask(self) -> None:
        """Clear both action-mask buffers."""
        self._old_facility_mask[:] = False
        self._new_facility_mask[:] = False

    def get_obs(self, t: int) -> Dict:
        """Return the current observation as a dict of padded arrays.

        Args:
            t: Not used by this method; kept so existing callers keep working.

        Returns:
            Dict with keys ``node_features``, ``static_adjacency_list``,
            ``dynamic_adjacency_list``, ``node_mask``, ``static_edge_mask``
            and ``dynamic_edge_mask``.  ``node_features`` and the static
            arrays are internal buffers that later calls overwrite, so copy
            them if they must outlive the next ``get_obs``/``reset``.
        """
        obs_nodes, obs_static_edges, obs_dynamic_edges, \
            obs_node_mask, obs_static_edge_mask, obs_dynamic_edge_mask = self._get_obs_graph()

        return {
            'node_features': obs_nodes,
            'static_adjacency_list': obs_static_edges,
            'dynamic_adjacency_list': obs_dynamic_edges,
            'node_mask': obs_node_mask,
            'static_edge_mask': obs_static_edge_mask,
            'dynamic_edge_mask': obs_dynamic_edge_mask,
        }

    def _get_obs_graph(
        self,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Refresh the dynamic node columns and assemble all graph arrays.

        Returns:
            ``(node_features, static_edges, dynamic_edges, node_mask,
            static_edge_mask, dynamic_edge_mask)``.
        """
        facility = self._flc.get_current_solution().astype(np.float32)
        distance = self._normalize_by_max(self._flc.get_current_distance().astype(np.float32))
        cost = self._normalize_by_max(self._flc.get_current_cost().astype(np.float32))
        gain, loss = self._flc.get_gain_and_loss()
        gain = self._normalize_by_max(gain)
        loss = self._normalize_by_max(loss)

        # Column order must match the virtual node's dynamic feature slots.
        dynamic_node_features = np.stack([
            facility,
            distance[:, 0],
            distance[:, 1],
            cost[:, 0],
            cost[:, 1],
            gain,
            loss,
        ], axis=-1)
        dynamic_width = len(self._virtual_dynamic_node_feature)
        self._node_features[1:self._n + 1, :dynamic_width] = dynamic_node_features

        dynamic_edges = self._flc.get_dynamic_adjacency_list()
        # Build the mask from the unpadded edge count before padding the edges.
        dynamic_edge_mask = np.full(dynamic_edges.shape[0], True)
        obs_dynamic_edges = self._pad_edge_wo_virtual(dynamic_edges)
        obs_dynamic_edge_mask = self._pad_mask(dynamic_edge_mask, self._edge_range, 'edges')

        return (
            self._node_features,
            self._static_adjacency_list,
            obs_dynamic_edges,
            self._obs_node_mask,
            self._obs_static_edge_mask,
            obs_dynamic_edge_mask,
        )

    def _get_obs_action_mask(self, t: int) -> Tuple[np.ndarray, np.ndarray]:
        """Combine the client's facility and tabu masks into padded action masks.

        Args:
            t: Step index forwarded to the client's ``get_tabu_mask``.

        Returns:
            ``(old_facility_mask, new_facility_mask)``, each of length
            ``node_range``; entries ``1..n`` hold the element-wise AND of the
            facility and tabu masks.  Both are internal buffers reused across
            calls.  (Presumably "old" marks facilities that may be removed and
            "new" marks candidate sites -- confirm against the client.)

        Raises:
            ValueError: If either resulting mask has no ``True`` entry.
        """
        old_facility_mask, new_facility_mask = self._flc.get_facility_mask()
        old_tabu_mask, new_tabu_mask = self._flc.get_tabu_mask(t)
        self._old_facility_mask[1:self._n + 1] = np.logical_and(old_facility_mask, old_tabu_mask)
        self._new_facility_mask[1:self._n + 1] = np.logical_and(new_facility_mask, new_tabu_mask)

        if not np.any(self._old_facility_mask) or not np.any(self._new_facility_mask):
            raise ValueError('The action mask is empty.')

        return self._old_facility_mask, self._new_facility_mask

    @staticmethod
    def _pad_mask(mask: np.ndarray, max_num: int, name: Text) -> np.ndarray:
        """Right-pad a 1-D boolean mask with ``False`` up to ``max_num`` entries.

        Raises:
            ValueError: If ``mask`` already has more than ``max_num`` entries.
        """
        missing = max_num - mask.size
        if missing < 0:
            raise ValueError(f'The number of {name} exceeds the maximum limit.')
        return np.pad(mask, (0, missing), mode='constant', constant_values=False)

    def _pad_edge_rows(self, edge: np.ndarray) -> np.ndarray:
        """Append rows filled with ``node_range - 1`` until ``edge_range`` rows exist.

        Raises:
            ValueError: If ``edge`` already has more than ``edge_range`` rows.
        """
        missing = self._edge_range - edge.shape[0]
        if missing < 0:
            raise ValueError(
                'The number of edges exceeds the maximum limit.'
                f' ({edge.shape[0]} > {self._edge_range})'
            )
        return np.pad(
            edge,
            ((0, missing), (0, 0)),
            mode='constant',
            constant_values=self._node_range - 1,
        )

    def _pad_edge(self, edge: np.ndarray) -> np.ndarray:
        """Prepend the virtual-node edges, shift indices by one and pad.

        The virtual edges are ``(0, i)`` for every instance node ``i`` in
        ``1..n``; the client's edges follow with both endpoints shifted by one.

        Raises:
            ValueError: If the combined edge count exceeds ``edge_range``.
        """
        virtual_edge = np.stack(
            [np.zeros(self._n), np.arange(1, self._n + 1)],
            axis=-1,
        ).astype(np.int32)
        combined = np.concatenate([virtual_edge, edge + 1], axis=0)
        return self._pad_edge_rows(combined)

    def _pad_edge_wo_virtual(self, edge: np.ndarray) -> np.ndarray:
        """Shift edge indices by one and pad, without adding virtual-node edges.

        Raises:
            ValueError: If ``edge`` has more than ``edge_range`` rows.
        """
        # Check the size before shifting, as the original did.
        if edge.shape[0] > self._edge_range:
            raise ValueError(
                'The number of edges exceeds the maximum limit.'
                f' ({edge.shape[0]} > {self._edge_range})'
            )
        return self._pad_edge_rows(edge + 1)