a-ragab-h-m commited on
Commit
8a08dc9
·
verified ·
1 Parent(s): b71417f

Update Actor/normalization.py

Browse files
Files changed (1) hide show
  1. Actor/normalization.py +45 -57
Actor/normalization.py CHANGED
@@ -2,87 +2,75 @@ import torch
2
 
3
 
4
  class Normalization(object):
5
-
6
  def __init__(self, actor, normalize_position=False, device='cpu'):
7
-
8
  self.normalize_position = normalize_position
9
  self.device = device
10
 
11
  graph = actor.graph
12
  fleet = actor.fleet
13
 
14
- batch_size = graph.distance_matrix.shape[0]
15
- num_nodes = graph.distance_matrix.shape[1]
16
- num_cars = fleet.start_time.shape[1]
17
 
18
- self.greatest_drive_time = graph.time_matrix.reshape(batch_size, -1).max(dim=1)[0]
19
- self.greatest_distance = graph.distance_matrix.reshape(batch_size, -1).max(dim=1)[0]
 
20
 
21
- a = fleet.start_time.reshape(batch_size, -1)
22
- b = graph.start_time.reshape(batch_size, -1)
23
- self.earliest_start_time = torch.cat([a, b], dim=1).min(dim=1)[0]
24
 
25
  self.mean_positions = graph.node_positions.mean(dim=1)
26
- self.std_positions = torch.std(graph.node_positions, dim=1)
27
-
28
 
29
  def normalize(self, actor):
 
 
 
30
 
31
- batch_size = actor.graph.distance_matrix.shape[0]
32
- num_nodes = actor.graph.distance_matrix.shape[1]
33
- num_cars = actor.fleet.start_time.shape[1]
34
-
35
- d = self.greatest_distance.reshape(batch_size, 1, 1).repeat(1, num_nodes, num_nodes)
36
- actor.graph.distance_matrix = actor.graph.distance_matrix / d
37
-
38
- t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, num_nodes)
39
- actor.graph.time_matrix = actor.graph.time_matrix / t
40
 
41
- s = self.earliest_start_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, 1)
42
- t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, 1)
43
- actor.graph.start_time = (actor.graph.start_time - s) / t
44
- actor.graph.end_time = (actor.graph.end_time - s) / t
45
 
 
 
46
 
47
- t = self.greatest_drive_time.reshape(batch_size)
48
- actor.fleet.late_time = actor.fleet.late_time / t
49
-
50
- s = actor.fleet.arrival_times.shape
51
- t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, s[1], s[2])
52
- actor.fleet.arrival_times = actor.fleet.arrival_times / t
53
-
54
 
 
55
  if self.normalize_position:
56
- m = self.mean_positions.reshape(batch_size, 1, self.mean_positions.shape[-1]).repeat(1, num_nodes, 1)
57
- st = self.std_positions.reshape(batch_size, 1, self.std_positions.shape[-1]).repeat(1, num_nodes, 1)
58
- actor.graph.node_positions = (actor.graph.node_positions - m) / st
59
 
60
  def inverse_normalize(self, actor):
 
 
 
61
 
62
- batch_size = actor.graph.distance_matrix.shape[0]
63
- num_nodes = actor.graph.distance_matrix.shape[1]
64
- num_cars = actor.fleet.start_time.shape[1]
65
-
66
- d = self.greatest_distance.reshape(batch_size, 1, 1).repeat(1, num_nodes, num_nodes)
67
- actor.graph.distance_matrix = actor.graph.distance_matrix * d
68
-
69
- t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, num_nodes)
70
- actor.graph.time_matrix = actor.graph.time_matrix * t
71
-
72
- s = self.earliest_start_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, 1)
73
- t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, num_nodes, 1)
74
- actor.graph.start_time = actor.graph.start_time * t + s
75
- actor.graph.end_time = actor.graph.end_time * t + s
76
 
77
- t = self.greatest_drive_time.reshape(batch_size)
78
- actor.fleet.late_time = actor.fleet.late_time * t
 
79
 
80
- s = actor.fleet.arrival_times.shape
81
- t = self.greatest_drive_time.reshape(batch_size, 1, 1).repeat(1, s[1], s[2])
82
- actor.fleet.arrival_times = actor.fleet.arrival_times * t
83
 
 
 
 
84
 
 
85
  if self.normalize_position:
86
- m = self.mean_positions.reshape(batch_size, 1, self.mean_positions.shape[-1]).repeat(1, num_nodes, 1)
87
- st = self.std_positions.reshape(batch_size, 1, self.std_positions.shape[-1]).repeat(1, num_nodes, 1)
88
- actor.graph.node_positions = actor.graph.node_positions * st + m
 
2
 
3
 
4
  class Normalization(object):
 
5
  def __init__(self, actor, normalize_position=False, device='cpu'):
 
6
  self.normalize_position = normalize_position
7
  self.device = device
8
 
9
  graph = actor.graph
10
  fleet = actor.fleet
11
 
12
+ batch_size = graph.distance_matrix.size(0)
13
+ num_nodes = graph.distance_matrix.size(1)
 
14
 
15
+ # Normalize scale factors
16
+ self.greatest_drive_time = graph.time_matrix.view(batch_size, -1).max(dim=1)[0] # (B,)
17
+ self.greatest_distance = graph.distance_matrix.view(batch_size, -1).max(dim=1)[0]
18
 
19
+ fleet_start_flat = fleet.start_time.view(batch_size, -1)
20
+ graph_start_flat = graph.start_time.view(batch_size, -1)
21
+ self.earliest_start_time = torch.cat([fleet_start_flat, graph_start_flat], dim=1).min(dim=1)[0]
22
 
23
  self.mean_positions = graph.node_positions.mean(dim=1)
24
+ self.std_positions = graph.node_positions.std(dim=1)
 
25
 
26
  def normalize(self, actor):
27
+ batch_size = actor.graph.distance_matrix.size(0)
28
+ num_nodes = actor.graph.distance_matrix.size(1)
29
+ num_cars = actor.fleet.start_time.size(1)
30
 
31
+ # Normalize graph matrices
32
+ actor.graph.distance_matrix /= self.greatest_distance.view(batch_size, 1, 1)
33
+ actor.graph.time_matrix /= self.greatest_drive_time.view(batch_size, 1, 1)
 
 
 
 
 
 
34
 
35
+ # Normalize graph time windows
36
+ st_offset = self.earliest_start_time.view(batch_size, 1, 1)
37
+ st_scale = self.greatest_drive_time.view(batch_size, 1, 1)
 
38
 
39
+ actor.graph.start_time = (actor.graph.start_time - st_offset) / st_scale
40
+ actor.graph.end_time = (actor.graph.end_time - st_offset) / st_scale
41
 
42
+ # Normalize fleet times
43
+ actor.fleet.late_time /= self.greatest_drive_time.view(batch_size, 1, 1)
44
+ actor.fleet.arrival_times /= self.greatest_drive_time.view(batch_size, 1, 1)
 
 
 
 
45
 
46
+ # Normalize positions (optional)
47
  if self.normalize_position:
48
+ mean_pos = self.mean_positions.view(batch_size, 1, -1)
49
+ std_pos = self.std_positions.view(batch_size, 1, -1)
50
+ actor.graph.node_positions = (actor.graph.node_positions - mean_pos) / std_pos
51
 
52
  def inverse_normalize(self, actor):
53
+ batch_size = actor.graph.distance_matrix.size(0)
54
+ num_nodes = actor.graph.distance_matrix.size(1)
55
+ num_cars = actor.fleet.start_time.size(1)
56
 
57
+ # Inverse graph matrices
58
+ actor.graph.distance_matrix *= self.greatest_distance.view(batch_size, 1, 1)
59
+ actor.graph.time_matrix *= self.greatest_drive_time.view(batch_size, 1, 1)
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # Inverse graph time windows
62
+ st_offset = self.earliest_start_time.view(batch_size, 1, 1)
63
+ st_scale = self.greatest_drive_time.view(batch_size, 1, 1)
64
 
65
+ actor.graph.start_time = actor.graph.start_time * st_scale + st_offset
66
+ actor.graph.end_time = actor.graph.end_time * st_scale + st_offset
 
67
 
68
+ # Inverse fleet times
69
+ actor.fleet.late_time *= self.greatest_drive_time.view(batch_size, 1, 1)
70
+ actor.fleet.arrival_times *= self.greatest_drive_time.view(batch_size, 1, 1)
71
 
72
+ # Inverse normalization of positions
73
  if self.normalize_position:
74
+ mean_pos = self.mean_positions.view(batch_size, 1, -1)
75
+ std_pos = self.std_positions.view(batch_size, 1, -1)
76
+ actor.graph.node_positions = actor.graph.node_positions * std_pos + mean_pos