Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update Actor/fleet.py
Browse files- Actor/fleet.py +33 -37
    	
        Actor/fleet.py
    CHANGED
    
    | @@ -4,58 +4,54 @@ import torch | |
| 4 | 
             
            class Fleet(object):
         | 
| 5 |  | 
| 6 | 
             
                def __init__(self, fleet_data, num_nodes, device='cpu'):
         | 
| 7 | 
            -
             | 
| 8 | 
             
                    self.device = device
         | 
| 9 | 
            -
             | 
| 10 | 
             
                    self.num_nodes = num_nodes
         | 
| 11 |  | 
| 12 | 
            -
                    # | 
| 13 | 
            -
                    self.start_time = fleet_data['start_time']
         | 
| 14 | 
            -
                    self.car_start_node = fleet_data['car_start_node']
         | 
| 15 |  | 
| 16 | 
             
                    self.batch_size = self.start_time.shape[0]
         | 
| 17 | 
             
                    self.num_cars = self.start_time.shape[1]
         | 
| 18 |  | 
|  | |
|  | |
| 19 |  | 
| 20 | 
            -
                    # | 
| 21 | 
            -
                    self. | 
| 22 | 
            -
             | 
| 23 | 
            -
                     | 
| 24 | 
            -
                    b = self.depot.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
         | 
| 25 | 
            -
                    self.num_depots = ((a == b).float().sum(dim=1) > 0).float().sum(dim=1).long()
         | 
| 26 |  | 
| 27 | 
            -
                    # | 
| 28 | 
            -
                    self.time = self.start_time
         | 
| 29 | 
            -
                    self.distance = torch.zeros(self.batch_size, self.num_cars, 1 | 
| 30 | 
            -
                    self.late_time = torch.zeros(self.batch_size, self.num_cars, 1 | 
| 31 |  | 
| 32 | 
            -
                    #Path  | 
| 33 | 
            -
                    self.path = self.depot.unsqueeze(2)
         | 
| 34 | 
            -
                    self.arrival_times = self.time
         | 
| 35 |  | 
|  | |
|  | |
| 36 |  | 
| 37 | 
            -
                    # | 
| 38 | 
            -
                    self.node = self.depot
         | 
| 39 | 
            -
             | 
| 40 | 
            -
                    #traversed_nodes indicates which nodes have been visited
         | 
| 41 | 
             
                    self.traversed_nodes = self.initialize_traversed_nodes()
         | 
| 42 |  | 
| 43 | 
            -
                    # | 
| 44 | 
            -
                    self.finished = torch.zeros(self.batch_size, self.num_cars | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 |  | 
| 48 | 
             
                def initialize_traversed_nodes(self):
         | 
| 49 | 
            -
             | 
| 50 | 
            -
                    a  | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
                     | 
| 55 | 
            -
                     | 
| 56 | 
            -
             | 
| 57 |  | 
| 58 | 
             
                def construct_vector(self):
         | 
| 59 | 
            -
             | 
| 60 | 
            -
                     | 
| 61 | 
            -
             | 
|  | 
|  | |
| 4 | 
             
            class Fleet(object):
         | 
| 5 |  | 
| 6 | 
             
                def __init__(self, fleet_data, num_nodes, device='cpu'):
         | 
|  | |
| 7 | 
             
                    self.device = device
         | 
|  | |
| 8 | 
             
                    self.num_nodes = num_nodes
         | 
| 9 |  | 
| 10 | 
            +
                    # Static fields
         | 
| 11 | 
            +
                    self.start_time = fleet_data['start_time'].to(device)
         | 
| 12 | 
            +
                    self.car_start_node = fleet_data['car_start_node'].to(device)
         | 
| 13 |  | 
| 14 | 
             
                    self.batch_size = self.start_time.shape[0]
         | 
| 15 | 
             
                    self.num_cars = self.start_time.shape[1]
         | 
| 16 |  | 
| 17 | 
            +
                    # Depot assignment per car (long for indexing)
         | 
| 18 | 
            +
                    self.depot = self.car_start_node.view(self.batch_size, self.num_cars).long()
         | 
| 19 |  | 
| 20 | 
            +
                    # Count distinct depots per batch
         | 
| 21 | 
            +
                    node_indices = torch.arange(self.num_nodes, device=self.device).view(1, 1, -1).repeat(self.batch_size, self.num_cars, 1)
         | 
| 22 | 
            +
                    depot_expanded = self.depot.view(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
         | 
| 23 | 
            +
                    self.num_depots = ((node_indices == depot_expanded).float().sum(dim=1) > 0).float().sum(dim=1).long()
         | 
|  | |
|  | |
| 24 |  | 
| 25 | 
            +
                    # Dynamic fields
         | 
| 26 | 
            +
                    self.time = self.start_time.clone()
         | 
| 27 | 
            +
                    self.distance = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
         | 
| 28 | 
            +
                    self.late_time = torch.zeros(self.batch_size, self.num_cars, 1, device=self.device)
         | 
| 29 |  | 
| 30 | 
            +
                    # Path and arrival tracking
         | 
| 31 | 
            +
                    self.path = self.depot.unsqueeze(2)  # (B, C, 1)
         | 
| 32 | 
            +
                    self.arrival_times = self.time.clone()
         | 
| 33 |  | 
| 34 | 
            +
                    # Current location
         | 
| 35 | 
            +
                    self.node = self.depot.clone()
         | 
| 36 |  | 
| 37 | 
            +
                    # Node visitation mask
         | 
|  | |
|  | |
|  | |
| 38 | 
             
                    self.traversed_nodes = self.initialize_traversed_nodes()
         | 
| 39 |  | 
| 40 | 
            +
                    # Termination flag
         | 
| 41 | 
            +
                    self.finished = torch.zeros(self.batch_size, self.num_cars, device=self.device)
         | 
|  | |
|  | |
| 42 |  | 
| 43 | 
             
                def initialize_traversed_nodes(self):
         | 
| 44 | 
            +
                    """
         | 
| 45 | 
            +
                    Initializes a boolean tensor indicating whether each node has been visited by each car.
         | 
| 46 | 
            +
                    Initially, only the depot node is visited.
         | 
| 47 | 
            +
                    """
         | 
| 48 | 
            +
                    node_indices = torch.arange(self.num_nodes, device=self.device).view(1, -1, 1).repeat(self.batch_size, 1, self.num_cars).float()
         | 
| 49 | 
            +
                    depot_indices = self.depot.view(self.batch_size, 1, self.num_cars).repeat(1, self.num_nodes, 1).float()
         | 
| 50 | 
            +
                    visited = ((node_indices == depot_indices).float().sum(dim=2) > 0)  # (B, N)
         | 
| 51 | 
            +
                    return visited
         | 
| 52 |  | 
| 53 | 
             
                def construct_vector(self):
         | 
| 54 | 
            +
                    """
         | 
| 55 | 
            +
                    Constructs the input vector used in the decoder (e.g., current time for each car).
         | 
| 56 | 
            +
                    """
         | 
| 57 | 
            +
                    return self.time.view(self.batch_size, self.num_cars, 1)
         |