Spaces:
Runtime error
Runtime error
Update Actor/normalization.py
Browse files- Actor/normalization.py +45 -57
Actor/normalization.py
CHANGED
|
@@ -2,87 +2,75 @@ import torch
|
|
| 2 |
|
| 3 |
|
| 4 |
class Normalization(object):
|
| 5 |
-
|
| 6 |
def __init__(self, actor, normalize_position=False, device='cpu'):
|
| 7 |
-
|
| 8 |
self.normalize_position = normalize_position
|
| 9 |
self.device = device
|
| 10 |
|
| 11 |
graph = actor.graph
|
| 12 |
fleet = actor.fleet
|
| 13 |
|
| 14 |
-
batch_size = graph.distance_matrix.
|
| 15 |
-
num_nodes = graph.distance_matrix.
|
| 16 |
-
num_cars = fleet.start_time.shape[1]
|
| 17 |
|
| 18 |
-
|
| 19 |
-
self.
|
|
|
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
self.earliest_start_time = torch.cat([
|
| 24 |
|
| 25 |
self.mean_positions = graph.node_positions.mean(dim=1)
|
| 26 |
-
self.std_positions =
|
| 27 |
-
|
| 28 |
|
| 29 |
def normalize(self, actor):
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
d = self.greatest_distance.reshape(batch_size, 1, 1).repeat(1, num_nodes, num_nodes)
|
| 36 |
-
actor.graph.distance_matrix = actor.graph.distance_matrix / d
|
| 37 |
-
|
| 38 |
-
t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, num_nodes)
|
| 39 |
-
actor.graph.time_matrix = actor.graph.time_matrix / t
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
actor.graph.end_time = (actor.graph.end_time - s) / t
|
| 45 |
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
actor.fleet.late_time
|
| 49 |
-
|
| 50 |
-
s = actor.fleet.arrival_times.shape
|
| 51 |
-
t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, s[1], s[2])
|
| 52 |
-
actor.fleet.arrival_times = actor.fleet.arrival_times / t
|
| 53 |
-
|
| 54 |
|
|
|
|
| 55 |
if self.normalize_position:
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
actor.graph.node_positions = (actor.graph.node_positions -
|
| 59 |
|
| 60 |
def inverse_normalize(self, actor):
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
d = self.greatest_distance.reshape(batch_size, 1, 1).repeat(1, num_nodes, num_nodes)
|
| 67 |
-
actor.graph.distance_matrix = actor.graph.distance_matrix * d
|
| 68 |
-
|
| 69 |
-
t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, num_nodes)
|
| 70 |
-
actor.graph.time_matrix = actor.graph.time_matrix * t
|
| 71 |
-
|
| 72 |
-
s = self.earliest_start_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, 1)
|
| 73 |
-
t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, 1)
|
| 74 |
-
actor.graph.start_time = actor.graph.start_time * t + s
|
| 75 |
-
actor.graph.end_time = actor.graph.end_time * t + s
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
|
|
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
actor.fleet.arrival_times = actor.fleet.arrival_times * t
|
| 83 |
|
|
|
|
|
|
|
|
|
|
| 84 |
|
|
|
|
| 85 |
if self.normalize_position:
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
actor.graph.node_positions = actor.graph.node_positions *
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
class Normalization(object):
|
|
|
|
| 5 |
def __init__(self, actor, normalize_position=False, device='cpu'):
|
|
|
|
| 6 |
self.normalize_position = normalize_position
|
| 7 |
self.device = device
|
| 8 |
|
| 9 |
graph = actor.graph
|
| 10 |
fleet = actor.fleet
|
| 11 |
|
| 12 |
+
batch_size = graph.distance_matrix.size(0)
|
| 13 |
+
num_nodes = graph.distance_matrix.size(1)
|
|
|
|
| 14 |
|
| 15 |
+
# Normalize scale factors
|
| 16 |
+
self.greatest_drive_time = graph.time_matrix.view(batch_size, -1).max(dim=1)[0] # (B,)
|
| 17 |
+
self.greatest_distance = graph.distance_matrix.view(batch_size, -1).max(dim=1)[0]
|
| 18 |
|
| 19 |
+
fleet_start_flat = fleet.start_time.view(batch_size, -1)
|
| 20 |
+
graph_start_flat = graph.start_time.view(batch_size, -1)
|
| 21 |
+
self.earliest_start_time = torch.cat([fleet_start_flat, graph_start_flat], dim=1).min(dim=1)[0]
|
| 22 |
|
| 23 |
self.mean_positions = graph.node_positions.mean(dim=1)
|
| 24 |
+
self.std_positions = graph.node_positions.std(dim=1)
|
|
|
|
| 25 |
|
| 26 |
def normalize(self, actor):
|
| 27 |
+
batch_size = actor.graph.distance_matrix.size(0)
|
| 28 |
+
num_nodes = actor.graph.distance_matrix.size(1)
|
| 29 |
+
num_cars = actor.fleet.start_time.size(1)
|
| 30 |
|
| 31 |
+
# Normalize graph matrices
|
| 32 |
+
actor.graph.distance_matrix /= self.greatest_distance.view(batch_size, 1, 1)
|
| 33 |
+
actor.graph.time_matrix /= self.greatest_drive_time.view(batch_size, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
# Normalize graph time windows
|
| 36 |
+
st_offset = self.earliest_start_time.view(batch_size, 1, 1)
|
| 37 |
+
st_scale = self.greatest_drive_time.view(batch_size, 1, 1)
|
|
|
|
| 38 |
|
| 39 |
+
actor.graph.start_time = (actor.graph.start_time - st_offset) / st_scale
|
| 40 |
+
actor.graph.end_time = (actor.graph.end_time - st_offset) / st_scale
|
| 41 |
|
| 42 |
+
# Normalize fleet times
|
| 43 |
+
actor.fleet.late_time /= self.greatest_drive_time.view(batch_size, 1, 1)
|
| 44 |
+
actor.fleet.arrival_times /= self.greatest_drive_time.view(batch_size, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
# Normalize positions (optional)
|
| 47 |
if self.normalize_position:
|
| 48 |
+
mean_pos = self.mean_positions.view(batch_size, 1, -1)
|
| 49 |
+
std_pos = self.std_positions.view(batch_size, 1, -1)
|
| 50 |
+
actor.graph.node_positions = (actor.graph.node_positions - mean_pos) / std_pos
|
| 51 |
|
| 52 |
def inverse_normalize(self, actor):
|
| 53 |
+
batch_size = actor.graph.distance_matrix.size(0)
|
| 54 |
+
num_nodes = actor.graph.distance_matrix.size(1)
|
| 55 |
+
num_cars = actor.fleet.start_time.size(1)
|
| 56 |
|
| 57 |
+
# Inverse graph matrices
|
| 58 |
+
actor.graph.distance_matrix *= self.greatest_distance.view(batch_size, 1, 1)
|
| 59 |
+
actor.graph.time_matrix *= self.greatest_drive_time.view(batch_size, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
# Inverse graph time windows
|
| 62 |
+
st_offset = self.earliest_start_time.view(batch_size, 1, 1)
|
| 63 |
+
st_scale = self.greatest_drive_time.view(batch_size, 1, 1)
|
| 64 |
|
| 65 |
+
actor.graph.start_time = actor.graph.start_time * st_scale + st_offset
|
| 66 |
+
actor.graph.end_time = actor.graph.end_time * st_scale + st_offset
|
|
|
|
| 67 |
|
| 68 |
+
# Inverse fleet times
|
| 69 |
+
actor.fleet.late_time *= self.greatest_drive_time.view(batch_size, 1, 1)
|
| 70 |
+
actor.fleet.arrival_times *= self.greatest_drive_time.view(batch_size, 1, 1)
|
| 71 |
|
| 72 |
+
# Inverse normalization of positions
|
| 73 |
if self.normalize_position:
|
| 74 |
+
mean_pos = self.mean_positions.view(batch_size, 1, -1)
|
| 75 |
+
std_pos = self.std_positions.view(batch_size, 1, -1)
|
| 76 |
+
actor.graph.node_positions = actor.graph.node_positions * std_pos + mean_pos
|