a-ragab-h-m commited on
Commit
b71417f
·
verified ·
1 Parent(s): f2809ce

Update Actor/fleet.py

Browse files
Files changed (1) hide show
  1. Actor/fleet.py +33 -37
Actor/fleet.py CHANGED
@@ -4,58 +4,54 @@ import torch
4
  class Fleet(object):
5
 
6
  def __init__(self, fleet_data, num_nodes, device='cpu'):
7
-
8
  self.device = device
9
-
10
  self.num_nodes = num_nodes
11
 
12
- #These fields are static
13
- self.start_time = fleet_data['start_time']
14
- self.car_start_node = fleet_data['car_start_node']
15
 
16
  self.batch_size = self.start_time.shape[0]
17
  self.num_cars = self.start_time.shape[1]
18
 
 
 
19
 
20
- #records the depot associated to each car
21
- self.depot = self.car_start_node.reshape(self.batch_size, self.num_cars).long()
22
-
23
- a = torch.arange(self.num_nodes).reshape(1, 1, -1).repeat(self.batch_size, self.num_cars, 1).to(self.device)
24
- b = self.depot.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
25
- self.num_depots = ((a == b).float().sum(dim=1) > 0).float().sum(dim=1).long()
26
 
27
- #These fields are dynamic. They will be updated as the tour is computed.
28
- self.time = self.start_time
29
- self.distance = torch.zeros(self.batch_size, self.num_cars, 1).to(self.device)
30
- self.late_time = torch.zeros(self.batch_size, self.num_cars, 1).to(self.device)
31
 
32
- #Path records the nodes that each car visited. Arrival times are the times when the car made it there.
33
- self.path = self.depot.unsqueeze(2)
34
- self.arrival_times = self.time
35
 
 
 
36
 
37
- #node is the current node of each car
38
- self.node = self.depot
39
-
40
- #traversed_nodes indicates which nodes have been visited
41
  self.traversed_nodes = self.initialize_traversed_nodes()
42
 
43
- #indicates whether or not the car has finished its route.
44
- self.finished = torch.zeros(self.batch_size, self.num_cars).to(self.device)
45
-
46
-
47
 
48
  def initialize_traversed_nodes(self):
49
-
50
- a = torch.arange(self.num_nodes).reshape(
51
- 1, -1, 1).repeat(
52
- self.batch_size, 1, self.num_cars).float().to(self.device)
53
-
54
- b = self.depot.reshape(self.batch_size, 1, self.num_cars).repeat(1, self.num_nodes, 1).float()
55
- return ((a == b).float().sum(dim=2) > 0)
56
-
57
 
58
  def construct_vector(self):
59
-
60
- return self.time.reshape(self.batch_size, self.num_cars, 1)
61
-
 
 
4
  class Fleet(object):
5
 
6
  def __init__(self, fleet_data, num_nodes, device='cpu'):
 
7
  self.device = device
 
8
  self.num_nodes = num_nodes
9
 
10
+ # Static fields
11
+ self.start_time = fleet_data['start_time'].to(device)
12
+ self.car_start_node = fleet_data['car_start_node'].to(device)
13
 
14
  self.batch_size = self.start_time.shape[0]
15
  self.num_cars = self.start_time.shape[1]
16
 
17
+ # Depot assignment per car (long for indexing)
18
+ self.depot = self.car_start_node.view(self.batch_size, self.num_cars).long()
19
 
20
+ # Count distinct depots per batch
21
+ node_indices = torch.arange(self.num_nodes, device=self.device).view(1, 1, -1).repeat(self.batch_size, self.num_cars, 1)
22
+ depot_expanded = self.depot.view(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
23
+ self.num_depots = ((node_indices == depot_expanded).float().sum(dim=1) > 0).float().sum(dim=1).long()
 
 
24
 
25
+ # Dynamic fields
26
+ self.time = self.start_time.clone()
27
+ self.distance = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
28
+ self.late_time = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
29
 
30
+ # Path and arrival tracking
31
+ self.path = self.depot.unsqueeze(2) # (B, C, 1)
32
+ self.arrival_times = self.time.clone()
33
 
34
+ # Current location
35
+ self.node = self.depot.clone()
36
 
37
+ # Node visitation mask
 
 
 
38
  self.traversed_nodes = self.initialize_traversed_nodes()
39
 
40
+ # Termination flag
41
+ self.finished = torch.zeros(self.batch_size, self.num_cars, device=self.device)
 
 
42
 
43
  def initialize_traversed_nodes(self):
44
+ """
45
+ Initializes a boolean tensor indicating whether each node has been visited by each car.
46
+ Initially, only the depot node is visited.
47
+ """
48
+ node_indices = torch.arange(self.num_nodes, device=self.device).view(1, -1, 1).repeat(self.batch_size, 1, self.num_cars).float()
49
+ depot_indices = self.depot.view(self.batch_size, 1, self.num_cars).repeat(1, self.num_nodes, 1).float()
50
+ visited = ((node_indices == depot_indices).float().sum(dim=2) > 0) # (B, N)
51
+ return visited
52
 
53
  def construct_vector(self):
54
+ """
55
+ Constructs the input vector used in the decoder (e.g., current time for each car).
56
+ """
57
+ return self.time.view(self.batch_size, self.num_cars, 1)