先坤 committed
Commit • db26c81
1 Parent(s): 49b4f2c
add greedrl
Browse files
- .gitignore +20 -0
- CMakeLists.txt +68 -0
- README.md +625 -0
- csrc/common.h +184 -0
- csrc/pybind.cpp +11 -0
- csrc/task_group_priority.cpp +75 -0
- csrc/task_group_priority.cu +93 -0
- csrc/task_group_priority.h +30 -0
- csrc/task_group_split.cpp +69 -0
- csrc/task_group_split.cu +53 -0
- csrc/task_group_split.h +24 -0
- examples/batching/batching.py +165 -0
- examples/cvrp/cvrp.py +88 -0
- examples/cvrp/orts.py +107 -0
- examples/cvrp/solve.py +65 -0
- examples/cvrp/train.py +83 -0
- examples/cvrp/utils.py +65 -0
- examples/dpdp/dpdp.py +191 -0
- examples/pdptw/pdptw.py +136 -0
- examples/runner.py +38 -0
- examples/sdvrp/sdvrp.py +83 -0
- examples/tsp/tsp.py +74 -0
- examples/vrptw/vrptw.py +141 -0
- greedrl/.gitignore +2 -0
- greedrl/__init__.py +8 -0
- greedrl/agent.py +203 -0
- greedrl/const.py +7 -0
- greedrl/decode.py +196 -0
- greedrl/dense.py +31 -0
- greedrl/encode.py +349 -0
- greedrl/feature.py +63 -0
- greedrl/function.py +5 -0
- greedrl/norm.py +25 -0
- greedrl/pyenv.py +383 -0
- greedrl/solver.py +625 -0
- greedrl/utils.py +65 -0
- greedrl/variable.py +478 -0
- images/GREEDRL-Framwork.png +0 -0
- images/GREEDRL-Framwork_en.png +0 -0
- images/GREEDRL-Logo-Original-640.png +0 -0
- images/GREEDRL-Network.png +0 -0
- requirements.txt +7 -0
- setup.py +44 -0
- test/all_test.py +7 -0
- test/basetest.py +8 -0
- test/function_test.py +79 -0
- test/solver_test.py +40 -0
.gitignore
ADDED
@@ -0,0 +1,20 @@
.idea
*.tar.gz
logs
**/__pycache__
data
*.log
*.pkl
*.pt
**/build/
**/dist/
**/*.egg-info
.DS_Store
.nfs*
*.so
*.dylib
*.iml
target
**/nohup.out
*.pth
**/.flattened-pom.xml
CMakeLists.txt
ADDED
@@ -0,0 +1,68 @@
cmake_minimum_required(VERSION 2.8.12)
project(greedrl_C_ LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 14)

find_package(PythonInterp REQUIRED)
execute_process(COMMAND "python" "-c"
"
import os
import torch
from distutils import sysconfig as s
print(s.get_python_inc(plat_specific=True))
print(s.get_config_var('EXT_SUFFIX'))
print(os.path.dirname(torch.__file__))
"
    RESULT_VARIABLE _PYTHON_SUCCESS
    OUTPUT_VARIABLE _PYTHON_VALUES
    ERROR_VARIABLE _PYTHON_ERROR_VALUE)

if(NOT _PYTHON_SUCCESS MATCHES 0)
    message("_PYTHON_SUCCESS: ${_PYTHON_SUCCESS}")
    message("_PYTHON_VALUES: ${_PYTHON_VALUES}")
    message("_PYTHON_ERROR_VALUE: ${_PYTHON_ERROR_VALUE}")
    message(FATAL_ERROR "get python config error!")
endif()

string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
list(GET _PYTHON_VALUES 0 PYTHON_INCLUDE_DIR)
list(GET _PYTHON_VALUES 1 PYTHON_EXT_SUFFIX)
list(GET _PYTHON_VALUES 2 TORCH_HOME)

include_directories(
    ${PYTHON_INCLUDE_DIR}
    ${TORCH_HOME}/include
    ${TORCH_HOME}/include/TH
    ${TORCH_HOME}/include/THC
    ${TORCH_HOME}/include/torch/csrc/api/include
)

string(LENGTH "${CMAKE_SOURCE_DIR}/" SOURCE_PATH_LENGTH)
add_compile_options(-DSOURCE_PATH_LENGTH=${SOURCE_PATH_LENGTH})
add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0 -fvisibility=hidden -fopenmp)

if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
    add_link_options(-undefined dynamic_lookup)
endif()

file(GLOB_RECURSE CSRC_CPP csrc/*.cpp)

add_library(greedrl_c MODULE ${CSRC_CPP})
set_target_properties(greedrl_c PROPERTIES PREFIX "")
set_target_properties(greedrl_c PROPERTIES SUFFIX "${PYTHON_EXT_SUFFIX}")
target_compile_options(greedrl_c PRIVATE -Wno-sign-conversion -O3)
target_link_libraries(greedrl_c c10 torch torch_cpu torch_python)
target_link_directories(greedrl_c PRIVATE ${TORCH_HOME}/lib)

find_package(CUDA)
if(CUDA_FOUND)
    enable_language(CUDA)
    file(GLOB_RECURSE CSRC_CU csrc/*.cu)
    add_library(greedrl_cu OBJECT ${CSRC_CU})
    target_compile_options(greedrl_cu PRIVATE -keep -Xptxas -v --expt-relaxed-constexpr --expt-extended-lambda -O3)
    set_target_properties(greedrl_cu PROPERTIES POSITION_INDEPENDENT_CODE ON CUDA_ARCHITECTURES "70;75;80")
    add_compile_definitions(CUDA_FOUND)
    include_directories(${CUDA_INCLUDE_DIRS})
    target_link_libraries(greedrl_c torch_cuda greedrl_cu)
    target_link_directories(greedrl_c PRIVATE ${TORCH_HOME}/lib)
endif()
README.md
CHANGED
@@ -1,3 +1,628 @@
---
license: apache-2.0
pipeline_tag: reinforcement-learning
tags:
- Deep Reinforcement Learning
- Combinatorial Optimization
- Reinforcement Learning
- Vehicle Routing Problem
---

![](./images/GREEDRL-Logo-Original-640.png)


# 🤠GreedRL

## Overview

- 🤠GreedRL is a fast and general framework for **Combinatorial Optimization Problems (COPs)**, based on **Deep Reinforcement Learning (DRL)**.

- 🤠GreedRL achieves **1200× faster computation and 3% better solution quality** than [Google OR-Tools](https://developers.google.com/optimization) on large-scale (>=1000 nodes) CVRPs.

## 🏆Award

[INFORMS 2021 Franz Edelman Award finalists](https://www.informs.org/Resource-Center/Video-Library/Edelman-Competition-Videos/2021-Edelman-Competition-Videos/2021-Edelman-Finalist-Alibaba) for Achievement in Operations Research and the Management Sciences (recognized for our work on the Cainiao Network VRP algorithm).


## Main features

* **GENERAL**

  🤠GreedRL provides **a high-level abstraction for COPs** and can solve many problem types, such as TSP, CVRP, VRPTW, PDPTW, SDVRP, DPDP, Order Batching, etc.

* **HIGH-PERFORMANCE**

  🤠GreedRL speeds up DRL environment (Env) simulation with **CUDA and C++ implementations**.

* **USER-FRIENDLY**

  🤠GreedRL provides **user-friendly modeling of COPs**: users only need to declare the constraints, objectives and variables of a COP. For more examples, please refer to [COPs Modeling examples](https://huggingface.co/Cainiao-AI/GreedRL/blob/main/README.md#cops-modeling-examples).

## Editions

We provide an open-source Community Edition and an Enterprise Edition of 🤠GreedRL.

- **The Community Edition** is now released and available to [download](https://huggingface.co/Cainiao-AI/GreedRL).
- **The Enterprise Edition** has a high-performance implementation that achieves faster computing speed, especially when solving large-scale COPs. For more information, please contact <a href="mailto:jiangwen.wjw@alibaba-inc.com">us</a>.


## Architecture
![](./images/GREEDRL-Framwork_en.png)

## COPs Modeling examples

### Capacitated Vehicle Routing Problem (CVRP)
<details>
<summary>CVRP</summary>

```python
from greedrl.feature import *
from greedrl.variable import *
from greedrl.function import *
from greedrl import Problem, Solution, Solver
from greedrl import runner

features = [continuous_feature('task_demand'),
            continuous_feature('worker_weight_limit'),
            continuous_feature('distance_matrix'),
            variable_feature('distance_this_to_task'),
            variable_feature('distance_task_to_end')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_weight'),
             worker_variable('worker_weight_limit'),
             worker_used_resource('worker_used_weight', task_require='task_weight'),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
             edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
             edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_task(self):
        # tasks that are already finished
        mask = self.task_demand_now <= 0
        # vehicle capacity constraint
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
        return mask

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_end(self):
        return self.distance_last_to_this

    def step_task(self):
        return self.distance_last_to_this
```

</details>

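The capacity mask in the CVRP example relies on ordinary torch broadcasting: `worker_weight_limit` has shape `(batch,)` while the per-task quantities have shape `(batch, task_num)`. A minimal, self-contained sketch of that arithmetic, with made-up toy values:

```python
import torch

# hypothetical toy instance: batch of 2, 3 tasks each
task_demand_now = torch.tensor([[2, 0, 1], [1, 1, 1]])    # remaining demand per task
task_weight = torch.tensor([[10, 10, 30], [20, 20, 20]])  # weight per unit of demand
worker_weight_limit = torch.tensor([50, 40])              # vehicle capacity per batch row
worker_used_weight = torch.tensor([30, 0])                # weight already loaded

# the same expressions as in mask_task() above
mask = task_demand_now <= 0                               # finished tasks
remaining = worker_weight_limit - worker_used_weight      # shape (batch,)
mask = mask | (task_demand_now * task_weight > remaining[:, None])
print(mask)  # True where a task must not be selected next
```
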
### Pickup and Delivery Problem with Time Windows (PDPTW)
<details>
<summary>PDPTW</summary>

```python
from greedrl.model import runner
from greedrl.feature import *
from greedrl.variable import *
from greedrl.function import *
from greedrl import Problem, Solution, Solver

features = [local_category('task_group'),
            global_category('task_priority', 2),
            variable_feature('distance_this_to_task'),
            variable_feature('distance_task_to_end')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_weight'),
             feature_variable('task_group'),
             feature_variable('task_priority'),
             feature_variable('task_due_time2', feature='task_due_time'),
             task_variable('task_due_time'),
             task_variable('task_service_time'),
             task_variable('task_due_time_penalty'),
             worker_variable('worker_basic_cost'),
             worker_variable('worker_distance_cost'),
             worker_variable('worker_due_time'),
             worker_variable('worker_weight_limit'),
             worker_used_resource('worker_used_weight', task_require='task_weight'),
             worker_used_resource('worker_used_time', 'distance_matrix', 'task_service_time', 'task_ready_time',
                                  'worker_ready_time'),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
             edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
             edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_worker_end(self):
        return task_group_split(self.task_group, self.task_demand_now <= 0)

    def mask_task(self):
        mask = self.task_demand_now <= 0
        mask |= task_group_priority(self.task_group, self.task_priority, mask)

        worker_used_time = self.worker_used_time[:, None] + self.distance_this_to_task
        mask |= (worker_used_time > self.task_due_time2) & (self.task_priority == 0)

        # capacity constraint
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
        return mask

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_start(self):
        return self.worker_basic_cost

    def step_worker_end(self):
        feasible = self.worker_used_time <= self.worker_due_time
        return self.distance_last_to_this * self.worker_distance_cost, feasible

    def step_task(self):
        worker_used_time = self.worker_used_time - self.task_service_time
        feasible = worker_used_time <= self.task_due_time
        feasible &= worker_used_time <= self.worker_due_time
        cost = self.distance_last_to_this * self.worker_distance_cost
        return torch.where(feasible, cost, cost + self.task_due_time_penalty), feasible
```

</details>


### VRP with Time Windows (VRPTW)
<details>
<summary>VRPTW</summary>

```python
from greedrl import Problem, Solution, Solver
from greedrl.feature import *
from greedrl.variable import *
from greedrl.function import *
from greedrl.model import runner
from greedrl.myenv import VrptwEnv

features = [continuous_feature('worker_weight_limit'),
            continuous_feature('worker_ready_time'),
            continuous_feature('worker_due_time'),
            continuous_feature('worker_basic_cost'),
            continuous_feature('worker_distance_cost'),
            continuous_feature('task_demand'),
            continuous_feature('task_weight'),
            continuous_feature('task_ready_time'),
            continuous_feature('task_due_time'),
            continuous_feature('task_service_time'),
            continuous_feature('distance_matrix')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_weight'),
             feature_variable('task_due_time'),
             feature_variable('task_ready_time'),
             feature_variable('task_service_time'),
             worker_variable('worker_weight_limit'),
             worker_variable('worker_due_time'),
             worker_variable('worker_basic_cost'),
             worker_variable('worker_distance_cost'),
             worker_used_resource('worker_used_weight', task_require='task_weight'),
             worker_used_resource('worker_used_time', 'distance_matrix', 'task_service_time', 'task_ready_time',
                                  'worker_ready_time'),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
             edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
             edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_task(self):
        # tasks that are already finished
        mask = self.task_demand_now <= 0
        # vehicle capacity constraint
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]

        worker_used_time = self.worker_used_time[:, None] + self.distance_this_to_task
        mask |= worker_used_time > self.task_due_time

        worker_used_time = torch.max(worker_used_time, self.task_ready_time)
        worker_used_time += self.task_service_time
        worker_used_time += self.distance_task_to_end
        mask |= worker_used_time > self.worker_due_time[:, None]

        return mask

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_start(self):
        return self.worker_basic_cost

    def step_worker_end(self):
        return self.distance_last_to_this * self.worker_distance_cost

    def step_task(self):
        return self.distance_last_to_this * self.worker_distance_cost
```

</details>

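The VRPTW `mask_task` above chains three feasibility checks: arriving after a task's due time, waiting until its ready time, and still returning before the worker's due time. A toy sketch of that arithmetic with made-up values:

```python
import torch

# hypothetical toy data: batch of 1, 3 candidate tasks
worker_used_time = torch.tensor([5.0])                    # time already spent, shape (batch,)
distance_this_to_task = torch.tensor([[1.0, 4.0, 9.0]])   # travel time to each task
task_ready_time = torch.tensor([[0.0, 8.0, 0.0]])
task_due_time = torch.tensor([[10.0, 12.0, 12.0]])
task_service_time = torch.tensor([[1.0, 1.0, 1.0]])
distance_task_to_end = torch.tensor([[2.0, 2.0, 2.0]])
worker_due_time = torch.tensor([14.0])

# the same steps as in the VRPTW mask_task() above
arrival = worker_used_time[:, None] + distance_this_to_task
mask = arrival > task_due_time                            # too late to serve the task
t = torch.max(arrival, task_ready_time)                   # wait until the ready time
t = t + task_service_time + distance_task_to_end          # serve, then return to depot
mask |= t > worker_due_time[:, None]                      # must finish before worker due time
print(mask)  # tensor([[False, False,  True]])
```
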
### Travelling Salesman Problem (TSP)
<details>
<summary>TSP</summary>

```python
from greedrl.feature import *
from greedrl.variable import *
from greedrl import Problem
from greedrl import runner

features = [continuous_feature('task_location'),
            variable_feature('distance_this_to_task'),
            variable_feature('distance_task_to_end')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
             edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
             edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True),
             edge_variable('distance_last_to_loop', feature='distance_matrix', last_to_loop=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_task(self):
        mask = self.task_demand_now <= 0
        return mask

    def mask_worker_end(self):
        return torch.any(self.task_demand_now > 0, 1)

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_end(self):
        return self.distance_last_to_loop

    def step_task(self):
        return self.distance_last_to_this
```

</details>

### Split Delivery Vehicle Routing Problem (SDVRP)
<details>
<summary>SDVRP</summary>

```python
from greedrl.feature import *
from greedrl.variable import *
from greedrl import Problem
from greedrl import runner

features = [continuous_feature('task_demand'),
            continuous_feature('worker_weight_limit'),
            continuous_feature('distance_matrix'),
            variable_feature('distance_this_to_task'),
            variable_feature('distance_task_to_end')]

variables = [task_demand_now('task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_weight'),
             task_variable('task_weight_this', feature='task_weight'),
             worker_variable('worker_weight_limit'),
             worker_used_resource('worker_used_weight', task_require='task_weight'),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True)]


class Constraint:

    def do_task(self):
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        return torch.min(self.task_demand_this, worker_weight_limit // self.task_weight_this)

    def mask_task(self):
        mask = self.task_demand <= 0
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        mask |= self.task_weight > worker_weight_limit[:, None]
        return mask

    def finished(self):
        return torch.all(self.task_demand <= 0, 1)


class Objective:

    def step_worker_end(self):
        return self.distance_last_to_this

    def step_task(self):
        return self.distance_last_to_this
```

</details>

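The SDVRP `do_task` above is what makes deliveries splittable: the worker serves `min(remaining demand, what still fits in the vehicle)` rather than the full demand. A toy sketch of that expression with made-up numbers:

```python
import torch

task_demand_this = torch.tensor([[7]])    # remaining demand of the chosen task
task_weight_this = torch.tensor([[2]])    # weight per unit of demand
worker_weight_limit = torch.tensor([10])  # vehicle capacity
worker_used_weight = torch.tensor([4])    # already loaded

remaining = worker_weight_limit - worker_used_weight          # 6 units of capacity left
served = torch.min(task_demand_this, remaining // task_weight_this)
print(served)  # tensor([[3]]): serve 3 of 7 units; the rest is left for another visit
```
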
### Realistic Business Scenario
<details>
<summary>real-time Dynamic Pickup and Delivery Problem (DPDP)</summary>

```python
from greedrl.feature import *
from greedrl.variable import *
from greedrl.function import *
from greedrl import Problem
from greedrl import runner

features = [local_category('task_order'),
            global_category('task_type', 2),
            global_category('task_new_order', 2),
            variable_feature('time_this_to_task'),
            continuous_feature('x_time_matrix'),
            continuous_feature('task_due_time_x'),
            continuous_feature('worker_task_mask')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             task_variable('task_pickup_this', feature='task_pickup'),
             task_variable('task_due_time_this', feature='task_due_time'),
             feature_variable('task_order', feature='task_order'),
             feature_variable('task_type', feature='task_type'),
             feature_variable('task_new_pickup', feature='task_new_pickup'),
             feature_variable('worker_task_mask', feature='worker_task_mask'),
             worker_count_now('worker_count_now', feature='worker_count'),
             worker_variable('worker_min_old_task_this', feature='worker_min_old_task'),
             worker_variable('worker_max_new_order_this', feature='worker_max_new_order'),
             worker_variable('worker_task_mask_this', feature='worker_task_mask'),
             worker_used_resource('worker_used_old_task', task_require='task_old'),
             worker_used_resource('worker_used_new_order', task_require='task_new_pickup'),
             worker_used_resource('worker_used_time', edge_require='time_matrix'),
             edge_variable('time_this_to_task', feature='x_time_matrix', this_to_task=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_worker_start(self):
        mask = self.worker_count_now <= 0

        finished = self.task_demand_now <= 0
        worker_task_mask = self.worker_task_mask | finished[:, None, :]
        mask |= torch.all(worker_task_mask, 2)

        return mask

    def mask_worker_end(self):
        mask = self.worker_used_old_task < self.worker_min_old_task_this
        mask |= task_group_split(self.task_order, self.task_demand_now <= 0)
        return mask

    def mask_task(self):
        mask = self.task_demand_now <= 0

        mask |= task_group_priority(self.task_order, self.task_type, mask)

        worker_max_new_order = self.worker_max_new_order_this - self.worker_used_new_order
        mask |= self.task_new_pickup > worker_max_new_order[:, None]

        mask |= self.worker_task_mask_this

        return mask

    def finished(self):
        worker_mask = self.worker_count_now <= 0
        task_mask = self.task_demand_now <= 0
        worker_task_mask = worker_mask[:, :, None] | task_mask[:, None, :]

        worker_task_mask |= self.worker_task_mask
        batch_size = worker_task_mask.size(0)
        worker_task_mask = worker_task_mask.view(batch_size, -1)
        return worker_task_mask.all(1)


class Objective:

    def step_task(self):
        over_time = (self.worker_used_time - self.task_due_time_this).clamp(min=0)
        pickup_time = self.worker_used_time * self.task_pickup_this
        return self.worker_used_time + over_time + pickup_time

    def step_finish(self):
        return self.task_demand_now.sum(1) * 1000
```

</details>

### Order Batching Problem
<details>
<summary>Batching</summary>

```python
from greedrl import Problem, Solver
from greedrl.feature import *
from greedrl.variable import *
from greedrl import runner


features = [local_feature('task_area'),
            local_feature('task_roadway'),
            local_feature('task_area_group'),
            sparse_local_feature('task_item_id', 'task_item_num'),
            sparse_local_feature('task_item_owner_id', 'task_item_num'),
            variable_feature('worker_task_item'),
            variable_feature('worker_used_roadway'),
            variable_feature('worker_used_area')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_item_id'),
             feature_variable('task_item_num'),
             feature_variable('task_item_owner_id'),
             feature_variable('task_area'),
             feature_variable('task_area_group'),
             feature_variable('task_load'),
             feature_variable('task_group'),
             worker_variable('worker_load_limit'),
             worker_variable('worker_area_limit'),
             worker_variable('worker_area_group_limit'),
             worker_task_item('worker_task_item', item_id='task_item_id', item_num='task_item_num'),
             worker_task_item('worker_task_item_owner', item_id='task_item_owner_id', item_num='task_item_num'),
             worker_used_resource('worker_used_load', task_require='task_load'),
             worker_used_resource('worker_used_area', task_require='task_area'),
             worker_used_resource('worker_used_roadway', task_require='task_roadway'),
             worker_used_resource('worker_used_area_group', task_require='task_area_group')]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_worker_end(self):
        return self.worker_used_load < self.worker_load_limit

    def mask_task(self):
        # completed tasks
        mask = self.task_demand_now <= 0
        # mask |= task_group_priority(self.task_group, self.task_out_stock_time, mask)

        NT = self.task_item_id.size(1)
        worker_task_item = self.worker_task_item[:, None, :]
        worker_task_item = worker_task_item.expand(-1, NT, -1)
        task_item_in_worker = worker_task_item.gather(2, self.task_item_id.long())
        task_item_in_worker = (task_item_in_worker > 0) & (self.task_item_num > 0)

        worker_task_item_owner = self.worker_task_item_owner[:, None, :]
        worker_task_item_owner = worker_task_item_owner.expand(-1, NT, -1)
        task_item_owner_in_worker = worker_task_item_owner.gather(2, self.task_item_owner_id.long())
        task_item_owner_in_worker = (task_item_owner_in_worker > 0) & (self.task_item_num > 0)

        # the same SKU with a different owner must not join this picking order
        mask |= torch.any(task_item_in_worker & ~task_item_owner_in_worker, 2)

        worker_load_limit = self.worker_load_limit - self.worker_used_load
        mask |= (self.task_load > worker_load_limit[:, None])

        task_area = self.task_area + self.worker_used_area[:, None, :]
        task_area_num = task_area.clamp(0, 1).sum(2, dtype=torch.int32)
        mask |= (task_area_num > self.worker_area_limit[:, None])

        task_area_group = self.task_area_group + self.worker_used_area_group[:, None, :]
        task_area_group_num = task_area_group.clamp(0, 1).sum(2, dtype=torch.int32)
        mask |= (task_area_group_num > self.worker_area_group_limit[:, None])

        return mask

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_end(self):
        area_num = self.worker_used_area.clamp(0, 1).sum(1)
        roadway_num = self.worker_used_roadway.clamp(0, 1).sum(1)
        item_num = self.worker_task_item.clamp(0, 1).sum(1)
        penalty = (self.worker_load_limit - self.worker_used_load) * 10
        return area_num * 100 + roadway_num * 10 + item_num + penalty
```

</details>

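The trickiest step in the batching `mask_task` above is the `gather` that tests, per candidate task, whether any of its items (or item owners) already appears in the worker's picking order. A toy sketch of that indexing pattern, with made-up shapes and values:

```python
import torch

# hypothetical: batch of 1, 2 candidate tasks, item-id vocabulary of 4
worker_task_item = torch.tensor([[0, 2, 0, 1]])        # item counts already picked, (batch, num_items)
task_item_id = torch.tensor([[[1, 3], [0, 2]]])        # item ids per task, (batch, NT, slots)
task_item_num = torch.tensor([[[5, 0], [1, 1]]])       # item counts per task (0 = empty slot)

NT = task_item_id.size(1)
w = worker_task_item[:, None, :].expand(-1, NT, -1)    # (batch, NT, num_items)
in_worker = w.gather(2, task_item_id.long())           # count of each task item in the worker
in_worker = (in_worker > 0) & (task_item_num > 0)      # item present in both task and worker
print(in_worker)  # tensor([[[ True, False], [False, False]]])
```
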
# Getting started

## Description
We are delighted to release the 🤠GreedRL Community Edition, together with example training and testing scripts for the standard Capacitated VRP (CVRP). You can download it and get started.

## Test environment
🤠GreedRL Community Edition has been tested on Ubuntu 18.04 with GCC v7.5.0, CUDA 11.4, and a [Miniconda](https://docs.conda.io/en/latest/miniconda.html#system-requirements) distribution with Python 3.8. We recommend using a similar configuration to avoid compilation issues.

## Installation
First, clone the repository.
```shell
$ git clone https://huggingface.co/Cainiao-AI/GreedRL
```
Then, create and activate a python environment using conda, and install the required packages.
```shell
$ conda create -n python38 python==3.8
$ source activate python38
$ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113
```
Finally, compile the extension and add the resulting `greedrl` library to the `PYTHONPATH`.
```shell
$ python setup.py build
$ export PYTHONPATH={your_current_path}/build/lib.linux-x86_64-cpython-38/:$PYTHONPATH
```

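To sanity-check the build, importing the package should succeed once the `PYTHONPATH` export above is in place (a minimal check, not part of the shipped scripts):

```python
# minimal sanity check, assuming the PYTHONPATH export above
import greedrl
print(greedrl.__file__)
```
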
## CVRP Training

1. Training data

We use generated data for the training phase: the customer and depot locations are sampled uniformly at random in the unit square [0,1] x [0,1]. For CVRP, we assume that the demand of each node is a discrete number in {1,...,9}, chosen uniformly at random, and each vehicle has a default capacity of 50 (a data-generation sketch follows after the training command below).

2. Start training
```shell
$ cd examples/cvrp
$ python train.py --model_filename cvrp_100.pt --problem_size 100
```

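As referenced above, an instance with these properties can be generated in a few lines. This is only a sketch of the described data distribution; the shipped `examples/cvrp/utils.py` may structure instances differently:

```python
import torch

def random_cvrp_instance(problem_size=100, capacity=50):
    # depot and customer coordinates, uniform in the unit square [0,1] x [0,1]
    depot = torch.rand(1, 2)
    customers = torch.rand(problem_size, 2)
    # integer demands drawn uniformly from {1, ..., 9}
    demand = torch.randint(1, 10, (problem_size,))
    return depot, customers, demand, capacity

depot, customers, demand, capacity = random_cvrp_instance()
print(customers.shape, int(demand.min()), int(demand.max()))
```
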
## CVRP Testing

After training, you will get a trained model, such as `cvrp_100.pt`, that you can use for testing.

```shell
$ cd examples/cvrp
$ python solve.py --device cpu --model_name cvrp_100.pt --problem_size 100
```

# Support
We look forward to you downloading 🤠GreedRL, using it, and opening discussions if you encounter any problems or have ideas for an even better experience.
For commercial enquiries, please contact <a href="mailto:jiangwen.wjw@alibaba-inc.com">us</a>.

# Citation
```
@article{hu2022alibaba,
  title={Alibaba vehicle routing algorithms enable rapid pick and delivery},
  author={Hu, Haoyuan and Zhang, Ying and Wei, Jiangwen and Zhan, Yang and Zhang, Xinhui and Huang, Shaojian and Ma, Guangrui and Deng, Yuming and Jiang, Siwei},
  journal={INFORMS Journal on Applied Analytics},
  volume={52},
  number={1},
  pages={27--41},
  year={2022},
  publisher={INFORMS}
}
```
csrc/common.h
ADDED
@@ -0,0 +1,184 @@
#pragma once
#include <cfloat>
#include <climits>
#include <cstdint>
#include <cstdarg>  // for va_list/va_start/va_end used below
#include <limits>
#include <chrono>
#include <stdexcept>
#include <torch/extension.h>

#define ASSERT(c) assert(c)
#define ALIGN(v, n) ((v + n - 1) / n * n)
#define INF std::numeric_limits<float>::infinity()
#define __FILENAME__ (__FILE__ + SOURCE_PATH_LENGTH)

#define GRL_ERROR(format, args...) \
    greedrl_error(__FILENAME__, __LINE__, format, ##args); \

#define GRL_CHECK(flag, format, args...) \
    greedrl_check(__FILENAME__, __LINE__, flag, format, ##args); \

#define MALLOC(ptr, T, size) \
    ptr = (T*) malloc(sizeof(T) * (size)); \
    GRL_CHECK(ptr != nullptr, "out of memory!"); \

#define GALLOC(ptr, T, size) \
    GRL_CHECK((size) > 0, "malloc 0 bytes"); \
    T* const ptr = (T*) malloc(sizeof(T) * (size)); \
    GRL_CHECK(ptr != nullptr, "out of memory!"); \
    AllocGuard ptr##_##alloc##_##guard(ptr); \

#define REALLOC(ptr, T, size) \
    GRL_CHECK((size) > 0, "malloc 0 bytes"); \
    ptr = (T*) realloc(ptr, sizeof(T) * (size)); \
    GRL_CHECK(ptr != nullptr, "out of memory!"); \

#define GRL_CHECK_TENSOR(tensor, device, allow_sub_contiguous, allow_null, ...) \
    greedrl_check_tensor(__FILENAME__, __LINE__, tensor, #tensor, device, \
                         allow_sub_contiguous, allow_null, {__VA_ARGS__}); \

const int GRL_WORKER_START = 0;
const int GRL_WORKER_END = 1;
const int GRL_TASK = 2;
const int GRL_FINISH = 3;

const int MAX_BATCH_SIZE = 100000;
const int MAX_TASK_COUNT = 5120;
const int MAX_SHARED_MEM = 48128;

using String = std::string;
using Device = torch::Device;
using Tensor = torch::Tensor;
using TensorMap = std::map<String, Tensor>;
using TensorList = std::vector<Tensor>;


inline void greedrl_error(const char* const file, const int64_t line,
                          const char* const format, ...)
{
    const int N = 2048;
    static char buf[N];

    va_list args;
    va_start(args, format);
    int n = vsnprintf(buf, N, format, args);
    va_end(args);

    if(n < N)
    {
        snprintf(buf+n, N-n, " at %s:%ld", file, line);
    }

    throw std::runtime_error(buf);
}

inline void greedrl_check(const char* const file, const int64_t line,
                          const bool flag, const char* const format, ...)
{
    if(flag)
    {
        return;
    }

    const int N = 2048;
    static char buf[N];

    va_list args;
    va_start(args, format);
    int n = vsnprintf(buf, N, format, args);
    va_end(args);

    if(n < N)
    {
        snprintf(buf+n, N-n, " at %s:%ld", file, line);
    }

    throw std::runtime_error(buf);
}

// contiguous except the 1st dimension
inline bool is_sub_contiguous(const Tensor& tensor)
{
    int dim = tensor.dim();
    if(dim==1) return true;

    auto sizes = tensor.sizes();
    auto strides = tensor.strides();

    if(strides[dim-1] != 1) return false;

    int s = 1;
    for(int i=dim-2; i>0; i--)
    {
        s *= sizes[i+1];
        if(strides[i] != s) return false;
    }

    return true;
};

inline void greedrl_check_tensor(const char* const file,
                                 const int line,
                                 const Tensor& tensor,
                                 const String& name,
                                 const Device& device,
                                 bool allow_sub_contiguous,
                                 bool allow_null,
                                 std::initializer_list<int> sizes)
{
    greedrl_check(file, line, tensor.numel() < 1000 * 1000 * 1000, "tensor size too large");

    auto device2 = tensor.device();
    greedrl_check(file, line, device2==device,
                  "'%s' device is %s, but expect %s",
                  name.c_str(), device2.str().c_str(), device.str().c_str());

    bool is_contiguous = allow_sub_contiguous ? is_sub_contiguous(tensor) : tensor.is_contiguous();
    greedrl_check(file, line, is_contiguous, "'%s' is not contiguous", name.c_str());

    if(allow_null && tensor.data_ptr() == nullptr) return;

    if(tensor.dim() != sizes.size())
    {
        greedrl_error(file, line, "'%s' dim is %d, but expect %d", name.c_str(), (int)tensor.dim(), (int)sizes.size());
    }
    int i=0;
    for(auto s:sizes)
    {
        greedrl_check(file, line, tensor.size(i)==s, "'%s' size(%d) is %d, but expect %d", name.c_str(), i, (int)tensor.size(i), s);
        i++;
    }
}


#ifdef CUDA_FOUND

#include <cuda_runtime_api.h>

#define GRL_CHECK_CUDA(error)\
    greedrl_check_cuda(error, __FILENAME__, __LINE__);

inline void greedrl_check_cuda(const cudaError_t& error,
                               const char* file, const int64_t line)
{
    if(error==cudaSuccess)
    {
        return;
    }

    const int N = 2048;
    static char buf[N];
    snprintf(buf, N, "%s, at %s:%ld", cudaGetErrorString(error), file, line);
    throw std::runtime_error(buf);
}

cudaDeviceProp& cuda_get_device_prop(int i);

#endif
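The `is_sub_contiguous` check above accepts tensors that are dense in every dimension except the first, e.g. a strided batch dimension. The condition is easy to see from Python (this only illustrates the semantics; it does not call into the extension):

```python
import torch

t = torch.arange(40).reshape(10, 4)
u = t[::2]                    # stride over the first (batch) dimension only
print(u.is_contiguous())      # False: strides are (8, 1) for sizes (5, 4)
# the inner dimension is still densely packed, so is_sub_contiguous()
# would accept u when allow_sub_contiguous is set
```
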
csrc/pybind.cpp
ADDED
@@ -0,0 +1,11 @@
#include <pybind11/pybind11.h>
#include "task_group_split.h"
#include "task_group_priority.h"

namespace py = pybind11;

PYBIND11_MODULE(greedrl_c, m) {
    m.def("task_group_split", &task_group_split);
    m.def("task_group_priority", &task_group_priority);
}
csrc/task_group_priority.cpp
ADDED
@@ -0,0 +1,75 @@
#include "task_group_priority.h"

// For each batch row: find the minimum priority among the unvisited
// tasks of each group, then mark every task whose priority differs
// from that minimum.
void task_group_priority_cpu(
    int* group, int* priority, bool* value, bool* output,
    int batch_size, int task_num, int group_num)
{
    auto temp = torch::make_unique<int[]>(group_num);
    for(int b=0; b<batch_size; b++)
    {
        for(int i=0; i<group_num; i++){
            temp[i] = std::numeric_limits<int>::max();
        }

        for(int i=0; i<task_num; i++){
            if(value[i]){
                continue;
            }
            int g = group[i];
            int p = priority[i];
            if(p < temp[g]){
                temp[g] = p;
            }
        }

        for(int i=0; i<task_num; i++){
            int g = group[i];
            output[i] = priority[i]!=temp[g];
        }

        group += task_num;
        priority += task_num;
        value += task_num;
        output += task_num;
    }
};

auto task_group_priority(
    const torch::Tensor& group,
    const torch::Tensor& priority,
    const torch::Tensor& value) -> torch::Tensor
{
    auto device = group.device();

    const int batch_size = group.size(0);
    const int task_num = group.size(1);
    const int group_num = group.max().item<int>() + 1;

    const int min_group = group.min().item<int>();

    GRL_CHECK(group_num <= task_num && min_group >= 0, "group value error");

    GRL_CHECK_TENSOR(group, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(priority, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(value, device, false, false, batch_size, task_num);

    auto output = torch::zeros({batch_size, task_num}, torch::dtype(torch::kBool).device(device));

    switch(device.type())
    {
    case torch::kCPU:
        task_group_priority_cpu(group.data_ptr<int>(), priority.data_ptr<int>(), value.data_ptr<bool>(),
                                output.data_ptr<bool>(), batch_size, task_num, group_num);
        break;
#ifdef CUDA_FOUND
    case torch::kCUDA:
        task_group_priority_cuda(group.data_ptr<int>(), priority.data_ptr<int>(), value.data_ptr<bool>(),
                                 output.data_ptr<bool>(), batch_size, task_num, group_num, device.index());
        break;
#endif
    default:
        GRL_ERROR("unsupported device: %s", device.str().c_str());
    }

    return output;
};
csrc/task_group_priority.cu
ADDED
@@ -0,0 +1,93 @@
#include "task_group_priority.h"

__global__ void task_group_priority_kernel(
    int* group, int* priority, bool* value, bool* output,
    int batch_size, int task_num, int group_num)
{
    // one block per batch row
    group += blockIdx.x * task_num;
    priority += blockIdx.x * task_num;
    value += blockIdx.x * task_num;
    output += blockIdx.x * task_num;

    extern __shared__ int temp[];

    for(int i=threadIdx.x; i<group_num; i+=blockDim.x)
    {
        temp[i] = std::numeric_limits<int>::max();
    }

    __syncthreads();

    // min priority among unvisited tasks of each group
    for(int i=threadIdx.x; i<task_num; i+=blockDim.x){
        if(value[i]){
            continue;
        }
        int g = group[i];
        int p = priority[i];
        atomicMin(&temp[g], p);
    }

    __syncthreads();

    for(int i=threadIdx.x; i<task_num; i+=blockDim.x){
        int g = group[i];
        output[i] = priority[i]!=temp[g];
    }
};

// alternative accessor-based kernel; not launched by task_group_priority_cuda below
template<typename _Tg, typename _Tp>
__global__ void cuda_do_task_group_priority(
    const torch::PackedTensorAccessor<_Tg,2,torch::RestrictPtrTraits> group,
    const torch::PackedTensorAccessor<_Tp,2,torch::RestrictPtrTraits> priority,
    const torch::PackedTensorAccessor<bool,2,torch::RestrictPtrTraits> value,
    torch::PackedTensorAccessor<bool,2,torch::RestrictPtrTraits> result,
    const _Tg NG)
{
    const int NP = group.size(0);
    const int NT = group.size(1);
    const int p = blockIdx.x * blockDim.x + threadIdx.x;
    if(p < NP)
    {
        extern __shared__ char _temp[];
        auto temp = reinterpret_cast<_Tp*>(_temp);
        temp += (threadIdx.x * NG);
        for(_Tg g=0; g<NG; g++){
            temp[g] = std::numeric_limits<_Tp>::max();
        }

        for(int t=0; t<NT; t++){
            if(value[p][t]){
                continue;
            }
            _Tg g = group[p][t];
            _Tp _p = priority[p][t];
            if(_p < temp[g]){
                temp[g] = _p;
            }
        }

        for(int t=0; t<NT; t++){
            _Tg g = group[p][t];
            if(priority[p][t]==temp[g]){
                result[p][t] = false;
            }
        }
    }
};

void task_group_priority_cuda(
    int* group, int* priority, bool* value, bool* output,
    const int batch_size, const int task_num, const int group_num, const int device)
{
    const int shared_mem = group_num * sizeof(int);

    GRL_CHECK_CUDA(cudaSetDevice(device));

    task_group_priority_kernel<<<batch_size, 256, shared_mem>>>(
        group, priority, value, output, batch_size, task_num, group_num);

    GRL_CHECK_CUDA(cudaGetLastError());
};
csrc/task_group_priority.h
ADDED
@@ -0,0 +1,30 @@
#pragma once

#include "./common.h"

/**
 * Tasks are divided into groups, and tasks within a group must be
 * visited in order of their priority. For each group, the minimum
 * priority value among its unvisited tasks is computed; the output
 * is false if a task's priority equals that minimum, otherwise true.
 *
 * group:    each task's group, shape (batch_size, task_num)
 * priority: each task's priority, shape (batch_size, task_num)
 * value:    whether each task is visited, shape (batch_size, task_num)
 *
 * output:   the result, shape (batch_size, task_num)
 */
auto task_group_priority(
    const torch::Tensor& group,
    const torch::Tensor& priority,
    const torch::Tensor& value) -> torch::Tensor;

void task_group_priority_cpu(
    int* group, int* priority, bool* value, bool* output,
    int batch_size, int task_num, int group_num);

void task_group_priority_cuda(
    int* group, int* priority, bool* value, bool* output,
    int batch_size, int task_num, int group_num, int device);
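Going by the pybind bindings and the tensor checks, `task_group_priority` takes three `(batch_size, task_num)` tensors, `int32` for `group` and `priority` and `bool` for `value`. A hypothetical call, assuming the compiled `greedrl_c` module is importable:

```python
import torch
import greedrl_c  # the compiled extension from csrc/pybind.cpp

group    = torch.tensor([[0, 0, 1, 1]], dtype=torch.int32)  # two groups of two tasks
priority = torch.tensor([[0, 1, 1, 0]], dtype=torch.int32)
visited  = torch.tensor([[True, False, False, False]])

masked = greedrl_c.task_group_priority(group, priority, visited)
# group 0: task 0 is visited, so task 1 holds the minimum remaining priority -> not masked
# group 1: task 3 has the minimum priority -> not masked; task 2 is masked
print(masked)  # tensor([[ True, False,  True, False]])
```
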
csrc/task_group_split.cpp
ADDED
@@ -0,0 +1,69 @@
#include "task_group_split.h"

void task_group_split_cpu(
    int* group, bool* value, bool* output,
    const int batch_size, const int task_num, const int group_num)
{
    auto temp = torch::make_unique<bool[]>(group_num);
    for(int b=0; b<batch_size; b++)
    {
        for(int i=0; i<group_num; i++){
            temp[i] = false;
        }

        // mark groups that contain at least one visited task
        for(int i=0; i<task_num; i++){
            if(value[i]){
                int g = group[i];
                temp[g] = true;
            }
        }

        // a batch row is "split" if some group mixes visited and unvisited tasks
        output[b] = false;
        for(int i=0; i<task_num; i++){
            int g = group[i];
            if(temp[g] && !value[i]){
                output[b] = true;
                break;
            }
        }

        group += task_num;
        value += task_num;
    }
};


auto task_group_split(
    const Tensor& group, const Tensor& value) -> Tensor
{
    auto device = group.device();
    const int batch_size = group.size(0);
    const int task_num = group.size(1);
    const int group_num = group.max().item<int>() + 1;
    const int min_group = group.min().item<int>();

    GRL_CHECK(group_num <= task_num && min_group >= 0, "group value error");

    GRL_CHECK_TENSOR(group, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(value, device, false, false, batch_size, task_num);

    auto output = torch::zeros({batch_size}, torch::dtype(torch::kBool).device(device));

    switch(device.type())
    {
    case torch::kCPU:
        task_group_split_cpu(group.data_ptr<int>(), value.data_ptr<bool>(),
                             output.data_ptr<bool>(), batch_size, task_num, group_num);
        break;
#ifdef CUDA_FOUND
    case torch::kCUDA:
        task_group_split_cuda(group.data_ptr<int>(), value.data_ptr<bool>(),
                              output.data_ptr<bool>(), batch_size, task_num, group_num, device.index());
        break;
#endif
    default:
        GRL_ERROR("unsupported device: %s", device.str().c_str());
    }

    return output;
};
csrc/task_group_split.cu
ADDED
@@ -0,0 +1,53 @@
#include "task_group_split.h"

__global__ void task_group_split_kernel(
    int* group, bool* value, bool* output,
    const int batch_size, const int task_num, const int group_num)
{
    // one block per batch row
    group += blockIdx.x * task_num;
    value += blockIdx.x * task_num;
    extern __shared__ bool temp[];

    __shared__ bool split;
    if(threadIdx.x == 0) split = false;

    for(int i=threadIdx.x; i<group_num; i+=blockDim.x)
    {
        temp[i] = false;
    }

    __syncthreads();

    // mark groups that contain at least one visited task
    for(int i=threadIdx.x; i<task_num; i+=blockDim.x)
    {
        int g = group[i];
        if(value[i]) temp[g] = true;
    }

    __syncthreads();

    // a group is split if it also contains an unvisited task
    for(int i=threadIdx.x; i<task_num; i+=blockDim.x)
    {
        int g = group[i];
        if(temp[g] && !value[i]) split = true;
    }

    __syncthreads();

    if(threadIdx.x == 0) output[blockIdx.x] = split;
};

void task_group_split_cuda(
    int* group, bool* value, bool* output,
    const int batch_size, const int task_num, const int group_num, const int device)
{
    const int shared_mem = group_num * sizeof(bool);

    GRL_CHECK_CUDA(cudaSetDevice(device));

    task_group_split_kernel<<<batch_size, 256, shared_mem>>>(
        group, value, output, batch_size, task_num, group_num);

    GRL_CHECK_CUDA(cudaGetLastError());
};
csrc/task_group_split.h
ADDED
@@ -0,0 +1,24 @@
#pragma once

#include "./common.h"

/**
 * Tasks are divided into groups.
 * If the tasks in every group are either all visited or all unvisited,
 * the output is false; otherwise the output is true.
 *
 * group:  each task's group, shape (batch_size, task_num)
 * value:  whether each task is visited, shape (batch_size, task_num)
 *
 * output: the result, shape (batch_size,)
 */
auto task_group_split(const Tensor& group, const Tensor& value) -> Tensor;

void task_group_split_cpu(
    int* group, bool* value, bool* output,
    const int batch_size, const int task_num, const int group_num);

void task_group_split_cuda(
    int* group, bool* value, bool* output,
    const int batch_size, const int task_num, const int group_num, const int device);
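`task_group_split` is what the PDPTW example uses in `mask_worker_end`: a route may not end while some group (e.g. a pickup-delivery pair) is only partially served, as I read the example. A hypothetical call with made-up values, assuming the compiled `greedrl_c` module is importable:

```python
import torch
import greedrl_c  # the compiled extension from csrc/pybind.cpp

group   = torch.tensor([[0, 0, 1, 1]], dtype=torch.int32)
visited = torch.tensor([[True, False, False, False]])  # group 0 is half-served

split = greedrl_c.task_group_split(group, visited)
print(split)  # tensor([True]): some group mixes visited and unvisited tasks
```
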
examples/batching/batching.py
ADDED
@@ -0,0 +1,165 @@
import json

import torch

from greedrl import Problem, Solver
from greedrl.feature import *
from greedrl.variable import *

features = [local_feature('task_area'),
            local_feature('task_roadway'),
            local_feature('task_area_group'),
            sparse_local_feature('task_item_id', 'task_item_num'),
            sparse_local_feature('task_item_owner_id', 'task_item_num'),
            variable_feature('worker_task_item'),
            variable_feature('worker_used_roadway'),
            variable_feature('worker_used_area')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_item_id'),
             feature_variable('task_item_num'),
             feature_variable('task_item_owner_id'),
             feature_variable('task_area'),
             feature_variable('task_area_group'),
             feature_variable('task_load'),
             feature_variable('task_group'),
             worker_variable('worker_load_limit'),
             worker_variable('worker_area_limit'),
             worker_variable('worker_area_group_limit'),
             worker_task_item('worker_task_item', item_id='task_item_id', item_num='task_item_num'),
             worker_task_item('worker_task_item_owner', item_id='task_item_owner_id', item_num='task_item_num'),
             worker_used_resource('worker_used_load', task_require='task_load'),
             worker_used_resource('worker_used_area', task_require='task_area'),
             worker_used_resource('worker_used_roadway', task_require='task_roadway'),
             worker_used_resource('worker_used_area_group', task_require='task_area_group')]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_worker_end(self):
        return self.worker_used_load < self.worker_load_limit

    def mask_task(self):
        # tasks that are already finished
        mask = self.task_demand_now <= 0
        # mask |= task_group_priority(self.task_group, self.task_out_stock_time, mask)

        NT = self.task_item_id.size(1)
        worker_task_item = self.worker_task_item[:, None, :]
        worker_task_item = worker_task_item.expand(-1, NT, -1)
        task_item_in_worker = worker_task_item.gather(2, self.task_item_id.long())
        task_item_in_worker = (task_item_in_worker > 0) & (self.task_item_num > 0)

        worker_task_item_owner = self.worker_task_item_owner[:, None, :]
        worker_task_item_owner = worker_task_item_owner.expand(-1, NT, -1)
        task_item_owner_in_worker = worker_task_item_owner.gather(2, self.task_item_owner_id.long())
        task_item_owner_in_worker = (task_item_owner_in_worker > 0) & (self.task_item_num > 0)

        # the same SKU with different owners must not appear in one picking order
        mask |= torch.any(task_item_in_worker & ~task_item_owner_in_worker, 2)

        worker_load_limit = self.worker_load_limit - self.worker_used_load
        mask |= (self.task_load > worker_load_limit[:, None])

        task_area = self.task_area + self.worker_used_area[:, None, :]
        task_area_num = task_area.clamp(0, 1).sum(2, dtype=torch.int32)
        mask |= (task_area_num > self.worker_area_limit[:, None])

        task_area_group = self.task_area_group + self.worker_used_area_group[:, None, :]
        task_area_group_num = task_area_group.clamp(0, 1).sum(2, dtype=torch.int32)
        mask |= (task_area_group_num > self.worker_area_group_limit[:, None])

        return mask

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_end(self):
        area_num = self.worker_used_area.clamp(0, 1).sum(1)
        roadway_num = self.worker_used_roadway.clamp(0, 1).sum(1)
        item_num = self.worker_task_item.clamp(0, 1).sum(1)
        penalty = (self.worker_load_limit - self.worker_used_load) * 10
        return area_num * 100 + roadway_num * 10 + item_num + penalty


def make_problem_from_json(data):
    if isinstance(data, str):
        data = json.loads(data)
    problem = Problem()
    problem.id = data["id"]
    if 'uuid' in data:
        problem.uuid = data["uuid"]

    problem.task_item_id = torch.tensor(data["task_item_id"], dtype=torch.int32)
    problem.task_item_owner_id = torch.tensor(data["task_item_owner_id"], dtype=torch.int32)
    problem.task_item_num = torch.tensor(data["task_item_num"], dtype=torch.int32)
    problem.task_area = torch.tensor(data["task_area"], dtype=torch.int32)
    problem.task_roadway = torch.tensor(data["task_roadway"], dtype=torch.int32)
    problem.task_out_stock_time = torch.tensor(data["task_out_stock_time"], dtype=torch.int32)
    problem.task_area_group = torch.tensor(data["task_area_group"], dtype=torch.int32)

    NT = problem.task_item_id.size(0)
    problem.task_load = torch.ones(NT, dtype=torch.int32)
    problem.task_group = torch.zeros(NT, dtype=torch.int32)
    problem.task_demand = torch.ones(NT, dtype=torch.int32)

    problem.worker_load_limit = torch.tensor(data["worker_load_limit"], dtype=torch.int32)
    problem.worker_area_limit = torch.tensor(data["worker_area_limit"], dtype=torch.int32)
    problem.worker_area_group_limit = torch.tensor(data["worker_area_group_limit"], dtype=torch.int32)

    problem.features = features
    problem.variables = variables
    problem.constraint = Constraint
    problem.objective = Objective

    return problem


def make_problem(batch_count, batch_size=1, task_count=100):
    assert batch_size == 1

    NT = task_count
    problem_list = []
    for i in range(batch_count):
        problem = Problem()
        problem.id = i

        device = Solver().device
        p = torch.ones(NT, 1000, dtype=torch.float32, device=device)
        problem.task_item_id = torch.multinomial(p, 10).to(torch.int32).cpu()
        problem.task_item_owner_id = torch.multinomial(p, 10).to(torch.int32).cpu()
        problem.task_item_num = torch.randint(0, 5, (NT, 10), dtype=torch.int32)
        problem.task_area = torch.randint(0, 5, (NT, 10), dtype=torch.int32).clamp(0, 1)
        problem.task_roadway = torch.randint(0, 5, (NT, 200), dtype=torch.int32).clamp(0, 1)
        problem.task_area_group = torch.randint(0, 5, (NT, 10), dtype=torch.int32).clamp(0, 1)

        problem.task_load = torch.ones(NT, dtype=torch.int32)
        problem.task_group = torch.zeros(NT, dtype=torch.int32)
        problem.task_demand = torch.ones(NT, dtype=torch.int32)

        problem.worker_load_limit = torch.tensor([20], dtype=torch.int32)
        problem.worker_area_limit = torch.tensor([10], dtype=torch.int32)
        problem.worker_area_group_limit = torch.tensor([10], dtype=torch.int32)

        problem.features = features
        problem.variables = variables
        problem.constraint = Constraint
        problem.objective = Objective

        problem_list.append(problem)

    return problem_list


if __name__ == '__main__':
    import sys
    import os.path as osp
    sys.path.append(osp.join(osp.dirname(__file__), '../'))
    import runner

    runner.run(make_problem)

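A minimal driver sketch for this example, following the Solver API used by examples/runner.py below; 'batching.pt' is a hypothetical checkpoint name, and the short step budget is illustrative only:

from greedrl import Solver
from batching import make_problem

problem_list = make_problem(batch_count=1)
solver = Solver(device='cpu')
# short training run, then sampled decoding on the same instances
solver.train('batching.pt', problem_list, problem_list, batch_size=32, max_steps=100)
solution = solver.solve(problem_list[0], batch_size=32)
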
examples/cvrp/cvrp.py
ADDED
@@ -0,0 +1,88 @@
import torch

from greedrl.feature import *
from greedrl.variable import *
from greedrl import Problem

features = [continuous_feature('task_demand'),
            continuous_feature('worker_weight_limit'),
            continuous_feature('distance_matrix'),
            variable_feature('distance_this_to_task'),
            variable_feature('distance_task_to_end')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_weight'),
             worker_variable('worker_weight_limit'),
             worker_used_resource('worker_used_weight', task_require='task_weight'),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
             edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
             edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_task(self):
        # tasks that are already finished
        mask = self.task_demand_now <= 0
        # vehicle capacity constraint
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
        return mask

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_end(self):
        return self.distance_last_to_this

    def step_task(self):
        return self.distance_last_to_this


def make_problem(batch_count, batch_size=1, task_count=100):
    assert task_count in (100, 1000, 2000, 5000)

    weight_limit = 50
    problem_list = []
    for i in range(batch_count):
        problem = Problem(True)
        problem.id = torch.arange(batch_size) + i * batch_size

        problem.worker_weight_limit = torch.full((batch_size, 1), weight_limit, dtype=torch.int32)

        N = task_count
        problem.task_demand = torch.randint(1, 10, (batch_size, N), dtype=torch.int32)
        problem.task_demand_x = problem.task_demand.float() / weight_limit

        # weight of one unit of task_demand
        problem.task_weight = torch.ones(batch_size, N, dtype=torch.int32)

        loc = torch.rand(batch_size, N + 1, 2, dtype=torch.float32)
        problem.task_location = loc[:, 1:, :]
        problem.worker_location = loc[:, 0:1, :]

        distance_matrix = torch.norm(loc[:, :, None, :] - loc[:, None, :, :], dim=3)
        problem.distance_matrix = distance_matrix

        problem.features = features
        problem.variables = variables
        problem.constraint = Constraint
        problem.objective = Objective

        problem_list.append(problem)

    return problem_list


if __name__ == '__main__':
    import sys
    import os.path as osp
    sys.path.append(osp.join(osp.dirname(__file__), '../'))
    import runner

    runner.run(make_problem)

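To see what mask_task computes, here is a toy run of the capacity constraint on hand-made tensors. The shapes follow the batched layout above; the numbers are illustrative only:

import torch

task_demand_now = torch.tensor([[3, 9, 0]])          # remaining demand per task
task_weight = torch.ones(1, 3, dtype=torch.int32)    # one unit of demand weighs 1
remaining = torch.tensor([5])                        # worker_weight_limit - worker_used_weight

mask = task_demand_now <= 0                          # finished tasks
mask |= task_demand_now * task_weight > remaining[:, None]
print(mask)  # tensor([[False,  True,  True]]) -> task 1 too heavy, task 2 already done
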
examples/cvrp/orts.py
ADDED
@@ -0,0 +1,107 @@
import sys
import time
import torch
import argparse
import utils
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor

from ortools.constraint_solver import pywrapcp
from ortools.constraint_solver import routing_enums_pb2


def solve(problem, i, max_time):
    scale = 100000
    size = problem.task_demand.size(1)
    demand = [0] + problem.task_demand[i].tolist()
    capacity = problem.worker_weight_limit[i].tolist()
    distance = (problem.distance_matrix[i] * scale + 0.5).to(torch.int32).tolist()

    queue = mp.Queue()
    p = mp.Process(target=do_solve, args=(size, demand, capacity, distance, max_time, queue))
    p.start()
    p.join()

    return queue.get() / scale, queue.get()


def do_solve(size, demand, capacity, distance, max_time, queue):
    capacity = capacity * size

    manager = pywrapcp.RoutingIndexManager(size + 1, size, 0)
    routing = pywrapcp.RoutingModel(manager)

    def distance_callback(from_index, to_index):
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return distance[from_node][to_node]

    distance_callback_index = routing.RegisterTransitCallback(distance_callback)
    routing.SetArcCostEvaluatorOfAllVehicles(distance_callback_index)

    def demand_callback(from_index):
        from_node = manager.IndexToNode(from_index)
        return demand[from_node]

    demand_callback_index = routing.RegisterUnaryTransitCallback(demand_callback)
    routing.AddDimensionWithVehicleCapacity(demand_callback_index, 0, capacity, True, 'capacity')

    params = pywrapcp.DefaultRoutingSearchParameters()
    params.first_solution_strategy = routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC
    params.local_search_metaheuristic = routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH
    params.time_limit.seconds = max_time

    start_time = time.time()
    solution = routing.SolveWithParameters(params)
    spent_time = time.time() - start_time

    queue.put(solution.ObjectiveValue())
    queue.put(spent_time)


def run_orts(task, max_time):
    problem, i = task
    return solve(problem, i, max_time)


def main(args):
    print("args: {}".format(vars(args)))
    problem_size = args.problem_size
    problem_count = args.problem_count
    batch_size = args.batch_size

    assert problem_count % batch_size == 0
    batch_count = problem_count // batch_size
    problem_list = utils.make_problem(batch_count, batch_size, problem_size)

    executor = ThreadPoolExecutor(max_workers=args.threads)
    task_list = [(p, i) for p in problem_list for i in range(batch_size)]

    total_cost = 0
    total_time = 0
    for cost, elapse in executor.map(run_orts, task_list, [args.max_time] * problem_count):
        total_cost += cost
        total_time += elapse

    avg_cost = total_cost / problem_count
    avg_time = total_time / problem_count
    print()
    print("-----------------------------------------------------")
    print("avg_cost: {:.4f}".format(avg_cost))
    print("avg_time: {:.6f}s".format(avg_time))
    print("total_count: {}".format(problem_count))
    print("-----------------------------------------------------\n")
    sys.stdout.flush()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--threads', default=20, type=int, help='number of threads')
    parser.add_argument('--max_time', default=60, type=int, help='the time limit for the search in seconds')

    parser.add_argument('--problem_size', default=100, type=int, choices=[100, 1000, 2000, 5000], help='problem size')
    parser.add_argument('--problem_count', default=128, type=int, help='total number of generated problem instances')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size for feedforwarding')

    args = parser.parse_args()
    main(args)

examples/cvrp/solve.py
ADDED
@@ -0,0 +1,65 @@
import os
import sys
import time
import torch
import argparse
import utils
from greedrl import Solver

torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def do_solve(args):
    print("args: {}".format(vars(args)))

    problem_size = args.problem_size
    problem_count = args.problem_count
    batch_size = args.batch_size
    assert problem_count % batch_size == 0
    batch_count = problem_count // batch_size

    problem_list = utils.make_problem(batch_count, batch_size, problem_size)

    solver = Solver(device=args.device)

    model_path = os.path.join('./', args.model_name)
    solver.load_agent(model_path)

    total_cost = 0

    if solver.device.type == 'cuda':
        torch.cuda.synchronize()

    start_time = time.time()
    for problem in problem_list:
        solution = solver.solve(problem, greedy=False, batch_size=batch_size)
        total_cost += solution.cost.sum().item()

    if solver.device.type == 'cuda':
        torch.cuda.synchronize()

    total_time = time.time() - start_time

    avg_cost = total_cost / problem_count
    avg_time = total_time / problem_count
    print()
    print("-----------------------------------------------------")
    print("avg_cost: {:.4f}".format(avg_cost))
    print("avg_time: {:.6f}s".format(avg_time))
    print("total_count: {}".format(problem_count))
    print("-----------------------------------------------------\n")
    sys.stdout.flush()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cpu', choices=['cpu', 'cuda'], help="choose a device")
    parser.add_argument('--model_name', default='cvrp_100.pt', choices=['cvrp_100.pt', 'cvrp_1000.pt', 'cvrp_2000.pt', 'cvrp_5000.pt'], help="choose a model")
    parser.add_argument('--problem_size', default=100, type=int, choices=[100, 1000, 2000, 5000], help='problem size')
    parser.add_argument('--problem_count', default=128, type=int, help='total number of generated problem instances')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size for feedforwarding')

    args = parser.parse_args()
    do_solve(args)

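For reference, a minimal interactive equivalent of this script, assuming a trained checkpoint named cvrp_100.pt already exists in the working directory:

import utils
from greedrl import Solver

problem = utils.make_problem(1, 128, 100)[0]
solver = Solver(device='cpu')
solver.load_agent('cvrp_100.pt')  # one of the checkpoints listed above
solution = solver.solve(problem, greedy=False, batch_size=128)
print(solution.cost.sum().item())
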
examples/cvrp/train.py
ADDED
@@ -0,0 +1,83 @@
import sys
import math
import argparse
import torch.distributed as dist
import torch.multiprocessing as mp
import utils
from greedrl import Solver


def do_train(args, rank):
    world_size = args.world_size
    model_filename = args.model_filename
    problem_size = args.problem_size
    batch_size = args.batch_size

    index = model_filename.rfind('.')
    if world_size > 1:
        stdout_filename = '{}_r{}.log'.format(model_filename[0:index], rank)
    else:
        stdout_filename = '{}.log'.format(model_filename[0:index])

    stdout = open(stdout_filename, 'a')
    sys.stdout = stdout
    sys.stderr = stdout

    print("args: {}".format(vars(args)))
    if world_size > 1:
        dist.init_process_group('NCCL', init_method='tcp://127.0.0.1:29500',
                                rank=rank, world_size=world_size)

    problem_batch_size = 8
    batch_count = 0
    if problem_size == 100:
        batch_count = math.ceil(10000 / problem_batch_size)
    elif problem_size == 1000:
        batch_count = math.ceil(200 / problem_batch_size)
    elif problem_size == 2000:
        batch_count = math.ceil(100 / problem_batch_size)
    elif problem_size == 5000:
        batch_count = math.ceil(10 / problem_batch_size)
    else:
        raise Exception("unsupported problem size: {}".format(problem_size))

    nn_args = {
        'encode_norm': 'instance',
        'encode_layers': 6,
        'decode_rnn': 'LSTM'
    }

    device = None if world_size == 1 else 'cuda:{}'.format(rank)
    solver = Solver(device, nn_args)

    train_dataset = utils.Dataset(None, problem_batch_size, problem_size)
    valid_dataset = utils.Dataset(batch_count, problem_batch_size, problem_size)

    solver.train(model_filename, train_dataset, valid_dataset,
                 train_dataset_workers=5,
                 batch_size=batch_size,
                 memopt=10,
                 topk_size=1,
                 init_lr=1e-4,
                 valid_steps=500,
                 warmup_steps=0)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--model_filename', type=str, help='model file name')
    parser.add_argument('--problem_size', default=100, type=int, choices=[100, 1000, 2000, 5000], help='problem size')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size for training')

    args = parser.parse_args()

    processes = []
    for rank in range(args.world_size):
        p = mp.Process(target=do_train, args=(args, rank))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

examples/cvrp/utils.py
ADDED
@@ -0,0 +1,65 @@
from greedrl.feature import *
from cvrp import make_problem as make_cvrp_problem
from torch.utils.data import IterableDataset


def make_problem(batch_count, batch_size, task_count):
    features = [continuous_feature('task_demand_x'),
                continuous_feature('distance_matrix')]

    problem_list = make_cvrp_problem(batch_count, batch_size, task_count)
    for problem in problem_list:
        problem.features = features

    return problem_list


class Dataset(IterableDataset):
    def __init__(self, batch_count, batch_size, task_count):
        self._batch_size = batch_size
        self._task_count = task_count
        self._batch_count = batch_count
        self._index = 0

    def __iter__(self):
        self._index = 0
        return self

    def __next__(self):
        # batch_count of None makes the dataset effectively infinite
        if self._batch_count is not None \
                and self._index >= self._batch_count:
            raise StopIteration()

        p = make_problem(1, self._batch_size, self._task_count)[0]
        self._index += 1
        return p


def write_vrplib(filename, name, size, demand, capacity, location):
    with open(filename, 'w') as f:
        f.write('\n'.join([
            "{} : {}".format(k, v)
            for k, v in (
                ('NAME', name),
                ('TYPE', 'CVRP'),
                ('COMMENT', 'NONE'),
                ('DIMENSION', size + 1),
                ('EDGE_WEIGHT_TYPE', 'EUC_2D'),
                ('CAPACITY', capacity)
            )
        ]))

        f.write('\n')
        f.write('NODE_COORD_SECTION\n')

        f.write('\n'.join(['{}\t{}\t{}'.format(i + 1, x, y) for i, (x, y) in enumerate(location)]))

        f.write('\n')
        f.write('DEMAND_SECTION\n')
        f.write('\n'.join(['{}\t{}'.format(i + 1, d) for i, d in enumerate([0] + demand)]))

        f.write('\n')
        f.write('DEPOT_SECTION\n')
        f.write('1\n')
        f.write('-1\n')
        f.write('EOF\n')

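A small usage sketch for write_vrplib, with hypothetical data: three customers, depot first in the location list (so location has size + 1 entries):

from utils import write_vrplib

write_vrplib('toy.vrp', 'toy', 3,
             demand=[1, 2, 3], capacity=10,
             location=[(0, 0), (1, 0), (0, 1), (1, 1)])
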
examples/dpdp/dpdp.py
ADDED
@@ -0,0 +1,191 @@
import json

import torch

from greedrl.feature import *
from greedrl.variable import *
from greedrl.function import *
from greedrl import Problem

features = [local_category('task_order'),
            global_category('task_type', 2),
            global_category('task_new_order', 2),
            variable_feature('time_this_to_task'),
            continuous_feature('x_time_matrix'),
            continuous_feature('task_due_time_x'),
            continuous_feature('worker_task_mask')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             task_variable('task_pickup_this', feature='task_pickup'),
             task_variable('task_due_time_this', feature='task_due_time'),
             feature_variable('task_order', feature='task_order'),
             feature_variable('task_type', feature='task_type'),
             feature_variable('task_new_pickup', feature='task_new_pickup'),
             feature_variable('worker_task_mask', feature='worker_task_mask'),
             worker_count_now('worker_count_now', feature='worker_count'),
             worker_variable('worker_min_old_task_this', feature='worker_min_old_task'),
             worker_variable('worker_max_new_order_this', feature='worker_max_new_order'),
             worker_variable('worker_task_mask_this', feature='worker_task_mask'),
             worker_used_resource('worker_used_old_task', task_require='task_old'),
             worker_used_resource('worker_used_new_order', task_require='task_new_pickup'),
             worker_used_resource('worker_used_time', edge_require='time_matrix'),
             edge_variable('time_this_to_task', feature='x_time_matrix', this_to_task=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_worker_start(self):
        mask = self.worker_count_now <= 0

        finished = self.task_demand_now <= 0
        worker_task_mask = self.worker_task_mask | finished[:, None, :]
        mask |= torch.all(worker_task_mask, 2)

        return mask

    def mask_worker_end(self):
        mask = self.worker_used_old_task < self.worker_min_old_task_this
        mask |= task_group_split(self.task_order, self.task_demand_now <= 0)
        return mask

    def mask_task(self):
        mask = self.task_demand_now <= 0

        mask |= task_group_priority(self.task_order, self.task_type, mask)

        worker_max_new_order = self.worker_max_new_order_this - self.worker_used_new_order
        mask |= self.task_new_pickup > worker_max_new_order[:, None]

        mask |= self.worker_task_mask_this

        return mask

    def finished(self):
        worker_mask = self.worker_count_now <= 0
        task_mask = self.task_demand_now <= 0
        worker_task_mask = worker_mask[:, :, None] | task_mask[:, None, :]

        worker_task_mask |= self.worker_task_mask
        batch_size = worker_task_mask.size(0)
        worker_task_mask = worker_task_mask.view(batch_size, -1)
        return worker_task_mask.all(1)


class Objective:

    def step_task(self):
        over_time = (self.worker_used_time - self.task_due_time_this).clamp(min=0)
        pickup_time = self.worker_used_time * self.task_pickup_this
        return self.worker_used_time + over_time + pickup_time

    def step_finish(self):
        return self.task_demand_now.sum(1) * 1000


def preprocess(problem):
    NW, NT = problem.worker_task_mask.size()

    worker_task_old = torch.ones(NW, NT, dtype=torch.int32)
    new_task_mask = problem.task_new_order[None, :].expand(NW, NT)
    worker_task_old[new_task_mask] = 0
    worker_task_old[problem.worker_task_mask] = 0
    assert torch.all(worker_task_old.sum(0) <= 1)
    problem.worker_min_old_task = worker_task_old.sum(1)

    problem.worker_count = torch.ones(NW, dtype=torch.int32)
    problem.task_demand = torch.ones(NT, dtype=torch.int32)
    problem.task_pickup = (problem.task_type == 0).to(torch.int32)

    task_old = torch.ones(NT, dtype=torch.int32)
    task_old[problem.task_new_order] = 0
    problem.task_old = task_old

    task_new_pickup = torch.ones(NT, dtype=torch.int32)
    task_new_pickup[problem.task_type >= 1] = 0
    task_new_pickup[~problem.task_new_order] = 0
    problem.task_new_pickup = task_new_pickup

    problem.task_due_time_x = problem.task_due_time.float() / 900
    problem.x_time_matrix = problem.time_matrix.float() / 900

    problem.features = features
    problem.variables = variables
    problem.constraint = Constraint
    problem.objective = Objective

    return problem


def make_problem_from_json(data):
    data = json.loads(data)

    problem = Problem()

    problem.id = data['id']
    problem.task_order = torch.tensor(data['task_order'], dtype=torch.int32)
    problem.task_type = torch.tensor(data['task_type'], dtype=torch.int32)
    problem.task_new_order = torch.tensor(data['task_new_order'], dtype=torch.bool)
    problem.task_due_time = torch.tensor(data['task_due_time'], dtype=torch.int32)

    problem.worker_max_new_order = torch.tensor(data['worker_max_new_order'], dtype=torch.int32)
    problem.worker_task_mask = torch.tensor(data['worker_task_mask'], dtype=torch.bool)
    problem.time_matrix = torch.tensor(data['time_matrix'], dtype=torch.int32)

    NW, NT = problem.worker_task_mask.size()

    assert problem.task_order.size() == (NT,), "task_order size error"
    assert problem.task_type.size() == (NT,), "task_type size error"
    assert problem.task_new_order.size() == (NT,), "task_new_order size error"
    assert problem.task_due_time.size() == (NT,), "task_due_time size error"
    assert problem.worker_max_new_order.size() == (NW,), "worker_max_new_order size error"
    assert problem.time_matrix.size() == (NW + NT, NW + NT), "time_matrix size error"

    return preprocess(problem)


def make_problem(batch_count, batch_size=1, task_count=100):
    assert batch_size == 1
    assert task_count == 100

    NW = 100
    NT = task_count
    NO = NT // 2  # number of orders; each order has a pickup task and a delivery task
    problem_list = []
    for i in range(batch_count):
        problem = Problem()

        # user-provided data
        problem.worker_max_new_order = torch.full((NW,), 2, dtype=torch.int32)

        task_order = torch.arange(NO, dtype=torch.int32)
        problem.task_order = torch.cat([task_order, task_order], 0)

        task_type = torch.zeros(NO, dtype=torch.int32)
        problem.task_type = torch.cat([task_type, task_type + 1], 0)

        problem.task_new_order = torch.ones(NT, dtype=torch.bool)

        task_due_time = torch.randint(1000, 1800, (NO,), dtype=torch.int32)
        problem.task_due_time = torch.cat([task_due_time, task_due_time + 1800], 0)

        worker_task_mask = torch.rand(NW, NO) < 0.9
        problem.worker_task_mask = torch.cat([worker_task_mask, worker_task_mask], 1)

        loc = torch.rand(NW + NT, 2, dtype=torch.float32)
        time_matrix = torch.norm(loc[:, None, :] - loc[None, :, :], dim=2) * 1000
        problem.time_matrix = time_matrix.to(torch.int32)

        problem_list.append(preprocess(problem))

    return problem_list


if __name__ == '__main__':
    import sys
    import os.path as osp
    sys.path.append(osp.join(osp.dirname(__file__), '../'))
    import runner

    runner.run(make_problem)

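The pickup-before-delivery ordering above hinges on task_group_priority, whose authoritative semantics live in csrc/task_group_priority. As we read it (an assumption on our part, not the shipped kernel), a task is masked while a group-mate with strictly lower priority is still unfinished; a small reference sketch:

import torch

def task_group_priority_reference(group, priority, done):
    # assumed semantics: mask task t while any task in the same group with a
    # strictly lower priority (e.g. its pickup) is not yet done
    batch_size, task_num = group.size()
    out = torch.zeros(batch_size, task_num, dtype=torch.bool)
    for b in range(batch_size):
        for t in range(task_num):
            same_group = group[b] == group[b, t]
            blocking = same_group & (priority[b] < priority[b, t]) & ~done[b]
            out[b, t] = bool(blocking.any())
    return out
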
examples/pdptw/pdptw.py
ADDED
@@ -0,0 +1,136 @@
import torch

from greedrl.feature import *
from greedrl.variable import *
from greedrl.function import *
from greedrl import Problem

features = [local_category('task_group'),
            global_category('task_priority', 2),
            variable_feature('distance_this_to_task'),
            variable_feature('distance_task_to_end')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_weight'),
             feature_variable('task_group'),
             feature_variable('task_priority'),
             feature_variable('task_due_time2', feature='task_due_time'),
             task_variable('task_due_time'),
             task_variable('task_service_time'),
             task_variable('task_due_time_penalty'),
             worker_variable('worker_basic_cost'),
             worker_variable('worker_distance_cost'),
             worker_variable('worker_due_time'),
             worker_variable('worker_weight_limit'),
             worker_used_resource('worker_used_weight', task_require='task_weight'),
             worker_used_resource('worker_used_time', 'distance_matrix', 'task_service_time', 'task_ready_time',
                                  'worker_ready_time'),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
             edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
             edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_worker_end(self):
        return task_group_split(self.task_group, self.task_demand_now <= 0)

    def mask_task(self):
        mask = self.task_demand_now <= 0
        mask |= task_group_priority(self.task_group, self.task_priority, mask)

        worker_used_time = self.worker_used_time[:, None] + self.distance_this_to_task
        mask |= (worker_used_time > self.task_due_time2) & (self.task_priority == 0)

        # capacity constraint
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
        return mask

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_start(self):
        return self.worker_basic_cost

    def step_worker_end(self):
        feasible = self.worker_used_time <= self.worker_due_time
        return self.distance_last_to_this * self.worker_distance_cost, feasible

    def step_task(self):
        worker_used_time = self.worker_used_time - self.task_service_time
        feasible = worker_used_time <= self.task_due_time
        feasible &= worker_used_time <= self.worker_due_time
        cost = self.distance_last_to_this * self.worker_distance_cost
        return torch.where(feasible, cost, cost + self.task_due_time_penalty), feasible


def make_problem(batch_count, batch_size=1, task_count=100):
    assert batch_size == 1

    N = task_count // 2  # number of orders; each order has a pickup task and a delivery task
    problem_list = []
    for i in range(batch_count):
        problem = Problem()
        problem.id = i

        problem.worker_weight_limit = torch.tensor([50], dtype=torch.float32)
        problem.worker_ready_time = torch.tensor([0], dtype=torch.float32)
        problem.worker_due_time = torch.tensor([1000000], dtype=torch.float32)
        problem.worker_basic_cost = torch.tensor([100], dtype=torch.float32)
        problem.worker_distance_cost = torch.tensor([1], dtype=torch.float32)

        task_demand = torch.randint(1, 10, (N,), dtype=torch.int32)
        problem.task_demand = torch.cat([task_demand, task_demand], 0)

        task_weight = torch.ones(N, dtype=torch.float32)
        problem.task_weight = torch.cat([task_weight, task_weight * -1], 0)

        task_group = torch.arange(N, dtype=torch.int32)
        problem.task_group = torch.cat([task_group, task_group], 0)

        task_priority = torch.zeros(N, dtype=torch.int32)
        problem.task_priority = torch.cat([task_priority, task_priority + 1], 0)

        task_ready_time = torch.zeros(N, dtype=torch.float32)
        problem.task_ready_time = torch.cat([task_ready_time, task_ready_time], 0)

        task_due_time = torch.randint(10000, 100000, (N,), dtype=torch.float32)
        problem.task_due_time = torch.cat([task_due_time, task_due_time * 2], 0)

        task_service_time = torch.zeros(N, dtype=torch.float32)
        problem.task_service_time = torch.cat([task_service_time, task_service_time])

        task_due_time_penalty = torch.ones(N, dtype=torch.float32)
        problem.task_due_time_penalty = torch.cat([task_due_time_penalty, task_due_time_penalty])

        loc = torch.rand(N + 1, 2, dtype=torch.float32)
        distance_matrix = torch.norm(loc[:, None, :] - loc[None, :, :], dim=2) * 1000
        distance_matrix = distance_matrix.to(torch.float32)
        # expand to depot + 2N tasks: the first N + 1 indices (depot and all
        # pickups) map to location 0, deliveries map to locations 1..N
        index = torch.cat([torch.zeros(N + 1, dtype=torch.int64), torch.arange(N, dtype=torch.int64) + 1])
        index1 = index[:, None]
        index2 = index[None, :]
        problem.distance_matrix = distance_matrix[index1, index2]

        problem.features = features
        problem.variables = variables
        problem.constraint = Constraint
        problem.objective = Objective

        problem_list.append(problem)

    return problem_list


if __name__ == '__main__':
    import sys
    import os.path as osp
    sys.path.append(osp.join(osp.dirname(__file__), '../'))
    import runner

    runner.run(make_problem)

examples/runner.py
ADDED
@@ -0,0 +1,38 @@
import time
import random
import argparse
import torch

from greedrl import Problem, Solution, Solver


def run(make_problem, mask_task_ratio=0.1):
    random.seed(123)
    torch.manual_seed(123)
    problem_list = make_problem(1)

    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--device', default=None, type=str)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--agent_file', default=None, type=str)
    parser.add_argument('--valid_steps', default=5, type=int)
    parser.add_argument('--max_steps', default=10000000, type=int)

    args, _ = parser.parse_known_args()
    for k, v in args.__dict__.items():
        print("arg: {} = {}".format(k, v))

    # rl train
    solver = Solver(device=args.device)
    solver.train(args.agent_file, problem_list, problem_list,
                 batch_size=args.batch_size, valid_steps=args.valid_steps, max_steps=args.max_steps)
    # predict
    solver = Solver(device=args.device)
    if args.agent_file is not None:
        solver.load_agent(args.agent_file)

    print("solve ...")
    start = time.time()
    for problem in problem_list:
        solver.solve(problem, batch_size=args.batch_size)
    print("time: {}s".format(time.time() - start))

examples/sdvrp/sdvrp.py
ADDED
@@ -0,0 +1,83 @@
import torch

from greedrl.feature import *
from greedrl.variable import *
from greedrl import Problem

features = [continuous_feature('task_demand'),
            continuous_feature('worker_weight_limit'),
            continuous_feature('distance_matrix'),
            variable_feature('distance_this_to_task'),
            variable_feature('distance_task_to_end')]

variables = [task_demand_now('task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_weight'),
             task_variable('task_weight_this', feature='task_weight'),
             worker_variable('worker_weight_limit'),
             worker_used_resource('worker_used_weight', task_require='task_weight'),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True)]


class Constraint:

    def do_task(self):
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        return torch.min(self.task_demand_this, worker_weight_limit // self.task_weight_this)

    def mask_task(self):
        # tasks that are already finished
        mask = self.task_demand <= 0
        # vehicle capacity constraint
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        # the vehicle must be able to carry at least one unit of demand
        mask |= self.task_weight > worker_weight_limit[:, None]
        return mask

    def finished(self):
        return torch.all(self.task_demand <= 0, 1)


class Objective:

    def step_worker_end(self):
        return self.distance_last_to_this

    def step_task(self):
        return self.distance_last_to_this


def make_problem(batch_count, batch_size=1, task_count=100):
    assert batch_size == 1

    NT = task_count
    problem_list = []
    for i in range(batch_count):
        problem = Problem()
        problem.id = i

        problem.worker_weight_limit = [50]

        problem.task_demand = torch.randint(1, 10, (NT,), dtype=torch.int64)

        # weight of one unit of task_demand
        problem.task_weight = torch.ones(NT, dtype=torch.int64)

        loc = torch.rand(NT + 1, 2, dtype=torch.float32)
        distance_matrix = torch.norm(loc[:, None, :] - loc[None, :, :], dim=2) * 1000
        problem.distance_matrix = distance_matrix.to(torch.int64)

        problem.variables = variables
        problem.constraint = Constraint
        problem.objective = Objective

        problem_list.append(problem)

    return problem_list


if __name__ == '__main__':
    import sys
    import os.path as osp
    sys.path.append(osp.join(osp.dirname(__file__), '../'))
    import runner

    runner.run(make_problem)

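The split-delivery twist is entirely in do_task above: the worker serves as many demand units as its remaining capacity allows, and the leftover stays for a later visit. A worked example with illustrative numbers:

import torch

remaining = torch.tensor([7])          # worker_weight_limit - worker_used_weight
task_demand_this = torch.tensor([5])   # units still requested by the chosen task
task_weight_this = torch.tensor([2])   # weight of one unit
done_now = torch.min(task_demand_this, remaining // task_weight_this)
print(done_now)  # tensor([3]) -> deliver 3 units now, 2 units remain for later
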
examples/tsp/tsp.py
ADDED
@@ -0,0 +1,74 @@
import torch

from greedrl.feature import *
from greedrl.variable import *
from greedrl import Problem

features = [continuous_feature('task_location'),
            variable_feature('distance_this_to_task'),
            variable_feature('distance_task_to_end')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
             edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
             edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True),
             edge_variable('distance_last_to_loop', feature='distance_matrix', last_to_loop=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_task(self):
        # tasks that are already finished
        mask = self.task_demand_now <= 0
        return mask

    def mask_worker_end(self):
        return torch.any(self.task_demand_now > 0, 1)

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_end(self):
        return self.distance_last_to_loop

    def step_task(self):
        return self.distance_last_to_this


def make_problem(batch_count, batch_size=1, task_count=100):
    NP = batch_size
    NT = task_count
    problem_list = []
    for i in range(batch_count):
        problem = Problem(True)

        problem.task_demand = torch.ones(NP, NT, dtype=torch.int32)

        loc = torch.rand(NP, NT + 1, 2, dtype=torch.float32)
        problem.distance_matrix = torch.norm(loc[:, :, None, :] - loc[:, None, :, :], dim=3)
        # zero out distances to and from the depot (node 0) in every instance
        problem.distance_matrix[:, 0, :] = 0
        problem.distance_matrix[:, :, 0] = 0

        problem.task_location = loc[:, 1:]

        problem.features = features
        problem.variables = variables
        problem.constraint = Constraint
        problem.objective = Objective

        problem_list.append(problem)
    return problem_list


if __name__ == '__main__':
    import sys
    import os.path as osp
    sys.path.append(osp.join(osp.dirname(__file__), '../'))
    import runner

    runner.run(make_problem)

examples/vrptw/vrptw.py
ADDED
@@ -0,0 +1,141 @@
import json

import torch

from greedrl import Problem
from greedrl.feature import *
from greedrl.variable import *

features = [continuous_feature('worker_weight_limit'),
            continuous_feature('worker_ready_time'),
            continuous_feature('worker_due_time'),
            continuous_feature('worker_basic_cost'),
            continuous_feature('worker_distance_cost'),
            continuous_feature('task_demand'),
            continuous_feature('task_weight'),
            continuous_feature('task_ready_time'),
            continuous_feature('task_due_time'),
            continuous_feature('task_service_time'),
            continuous_feature('distance_matrix')]

variables = [task_demand_now('task_demand_now', feature='task_demand'),
             task_demand_now('task_demand_this', feature='task_demand', only_this=True),
             feature_variable('task_weight'),
             feature_variable('task_due_time'),
             feature_variable('task_ready_time'),
             feature_variable('task_service_time'),
             worker_variable('worker_weight_limit'),
             worker_variable('worker_due_time'),
             worker_variable('worker_basic_cost'),
             worker_variable('worker_distance_cost'),
             worker_used_resource('worker_used_weight', task_require='task_weight'),
             worker_used_resource('worker_used_time', 'distance_matrix', 'task_service_time', 'task_ready_time',
                                  'worker_ready_time'),
             edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
             edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
             edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]


class Constraint:

    def do_task(self):
        return self.task_demand_this

    def mask_task(self):
        # tasks that are already finished
        mask = self.task_demand_now <= 0
        # vehicle capacity constraint
        worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
        mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]

        worker_used_time = self.worker_used_time[:, None] + self.distance_this_to_task
        mask |= worker_used_time > self.task_due_time

        worker_used_time = torch.max(worker_used_time, self.task_ready_time)
        worker_used_time += self.task_service_time
        worker_used_time += self.distance_task_to_end
        mask |= worker_used_time > self.worker_due_time[:, None]

        return mask

    def finished(self):
        return torch.all(self.task_demand_now <= 0, 1)


class Objective:

    def step_worker_start(self):
        return self.worker_basic_cost

    def step_worker_end(self):
        return self.distance_last_to_this * self.worker_distance_cost

    def step_task(self):
        return self.distance_last_to_this * self.worker_distance_cost


def make_problem_from_json(data):
    if isinstance(data, str):
        data = json.loads(data)

    problem = Problem()
    problem.worker_weight_limit = torch.tensor(data['worker_weight_limit'], dtype=torch.float32)
    problem.worker_ready_time = torch.tensor(data['worker_ready_time'], dtype=torch.float32)
    problem.worker_due_time = torch.tensor(data['worker_due_time'], dtype=torch.float32)
    problem.worker_basic_cost = torch.tensor(data['worker_basic_cost'], dtype=torch.float32)
    problem.worker_distance_cost = torch.tensor(data['worker_distance_cost'], dtype=torch.float32)

    problem.task_demand = torch.tensor(data['task_demand'], dtype=torch.int32)
    problem.task_weight = torch.tensor(data['task_weight'], dtype=torch.float32)
    problem.task_ready_time = torch.tensor(data['task_ready_time'], dtype=torch.float32)
    problem.task_due_time = torch.tensor(data['task_due_time'], dtype=torch.float32)
    problem.task_service_time = torch.tensor(data['task_service_time'], dtype=torch.float32)

    problem.distance_matrix = torch.tensor(data['distance_matrix'], dtype=torch.float32)

    problem.features = features
    problem.variables = variables
    problem.constraint = Constraint
    problem.objective = Objective

    return problem


def make_problem(batch_count, batch_size=1, task_count=100):
    assert batch_size == 1

    NT = task_count
    problem_list = []
    for i in range(batch_count):
        problem = Problem()
        problem.id = i

        problem.worker_weight_limit = torch.tensor([50], dtype=torch.float32)
        problem.worker_ready_time = torch.tensor([0], dtype=torch.float32)
        problem.worker_due_time = torch.tensor([1000000], dtype=torch.float32)
        problem.worker_basic_cost = torch.tensor([100], dtype=torch.float32)
        problem.worker_distance_cost = torch.tensor([1], dtype=torch.float32)

        problem.task_demand = torch.randint(1, 10, (NT,), dtype=torch.int32)
        problem.task_weight = torch.ones(NT, dtype=torch.float32)
        problem.task_ready_time = torch.zeros(NT, dtype=torch.float32)
        problem.task_due_time = torch.randint(10000, 100000, (NT,), dtype=torch.float32)
        problem.task_service_time = torch.zeros(NT, dtype=torch.float32)

        loc = torch.rand(NT + 1, 2, dtype=torch.float32)
        problem.distance_matrix = torch.norm(loc[:, None, :] - loc[None, :, :], dim=2) * 1000

        problem.features = features
        problem.variables = variables
        problem.constraint = Constraint
        problem.objective = Objective

        problem_list.append(problem)

    return problem_list


if __name__ == '__main__':
    import sys
    import os.path as osp
    sys.path.append(osp.join(osp.dirname(__file__), '../'))
    import runner

    runner.run(make_problem)

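The two time-window checks in mask_task above are easiest to see on numbers (illustrative values only): a task is masked when the worker would arrive after the task's due time, and separately when serving it would push the return past the worker's own due time. The first check in isolation:

import torch

worker_used_time = torch.tensor([10.0])
distance_this_to_task = torch.tensor([[5.0, 30.0]])
task_due_time = torch.tensor([20.0, 25.0])

arrival = worker_used_time[:, None] + distance_this_to_task
print(arrival > task_due_time)  # tensor([[False,  True]]) -> second task arrives too late
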
greedrl/.gitignore
ADDED
@@ -0,0 +1,2 @@
*.c
version.py

greedrl/__init__.py
ADDED
@@ -0,0 +1,8 @@
import sys

from .solver import Problem, Solution, Solver
from .const import GRL_WORKER_START, GRL_WORKER_END, GRL_TASK, GRL_FINISH


greedrl = sys.modules[__name__]

greedrl/agent.py
ADDED
@@ -0,0 +1,203 @@
import torch

from torch import nn
from collections import OrderedDict
from torch.utils.checkpoint import checkpoint
from .feature import *
from .pyenv import PyEnv
from .encode import Encode
from .decode import Decode


class Agent(nn.Module):

    def __init__(self, nn_args):
        super(Agent, self).__init__()

        self.nn_args = nn_args
        self.vars_dim = sum(nn_args['variable_dim'].values())
        self.steps_ratio = nn_args.setdefault('decode_steps_ratio', 1.0)

        logit_clips = nn_args.setdefault('decode_logit_clips', 10.0)
        if isinstance(logit_clips, str):
            self.logit_clips = [float(v) for v in logit_clips.split(',')]
        else:
            self.logit_clips = [float(logit_clips)]

        self.nn_encode = Encode(nn_args)
        self.nn_decode = Decode(nn_args)

    def nn_args_dict(self):
        return self.nn_args

    def forward(self, problem, batch_size, greedy=False, solution=None, memopt=0):
        X, K, V = self.nn_encode(problem.feats, problem.batch_size,
                                 problem.worker_num, problem.task_num, memopt)

        return self.interact(problem, X, K, V, batch_size, greedy, solution, memopt)

    def interact(self, problem, X, K, V, batch_size, greedy, solution, memopt):
        NP = problem.batch_size
        NW = problem.worker_num
        NT = problem.task_num

        sample_num = batch_size // NP
        assert sample_num > 0 and batch_size % NP == 0

        MyEnv = problem.environment
        if MyEnv is None:
            env = PyEnv(problem, batch_size, sample_num, self.nn_args)
        else:
            env = MyEnv(str(problem.device), problem.feats, batch_size,
                        sample_num, problem.worker_num, problem.task_num)

        query = X.new_zeros(batch_size, X.size(-1))
        state1 = X.new_zeros(batch_size, X.size(-1))
        state2 = X.new_zeros(batch_size, X.size(-1))

        p_list = []
        NULL = X.new_ones(0)
        p_index = torch.div(torch.arange(batch_size, device=X.device), sample_num, rounding_mode='trunc')  # torch.arange(batch_size, device=X.device) // sample_num
        if solution is not None:
            solution = solution[:, :, 0:2].to(torch.int64).permute(1, 0, 2)
            assert torch.all(solution >= 0) and solution.size(1) == batch_size
            offset = torch.tensor([0, NW, NW + NW, NW + NW + NT], device=X.device)
            chosen_list = solution[:, :, 1] + offset[solution[:, :, 0]]

            mode = 0
            sample_p = torch.rand(batch_size, device=X.device)
            for chosen in chosen_list:
                env_time = env.time()
                clip = self.logit_clips[min(env_time, len(self.logit_clips) - 1)]
                varfeat = env.make_feat() if self.vars_dim > 0 else NULL
                state1, state2, chosen_p = self.decode(X, K, V, query, state1, state2,
                                                       varfeat, env.mask(), chosen, sample_p, clip, mode, memopt)
                query = X[p_index, chosen]
                p_list.append(chosen_p)
                env.step(chosen)

            assert env.all_finished(), 'not all finished!'
        else:
            mode = 1 if greedy else 2
            min_env_time = int(self.steps_ratio * NT)
            R = torch.rand(NT * 2, batch_size, device=X.device)
            while True:
                env_time = env.time()
                if env_time > min_env_time and env_time % 3 == 0 and env.all_finished():
                    break

                clip = self.logit_clips[min(env_time, len(self.logit_clips) - 1)]
                sample_p = R[env_time % R.size(0)]
                chosen = X.new_empty(batch_size, dtype=torch.int64)
                varfeat = env.make_feat() if self.vars_dim > 0 else NULL
                state1, state2, chosen_p = self.decode(X, K, V, query, state1, state2,
                                                       varfeat, env.mask(), chosen, sample_p, clip, mode, memopt)
                query = X[p_index, chosen]
                p_list.append(chosen_p)
                env.step(chosen)

        env.finalize()
        return env, torch.stack(p_list, 1)

    def decode(self, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p, clip, mode, memopt):
        run_fn = self.decode_fn(clip, mode, memopt)
        if self.training and memopt > 3:
            return checkpoint(run_fn, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p)
        else:
            return run_fn(X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p)

    def decode_fn(self, clip, mode, memopt):
        memopt = 0 if memopt > 3 else memopt

        def run_fn(X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p):
            return self.nn_decode(X, K, V, query, state1, state2,
                                  varfeat, mask, chosen, sample_p, clip, mode, memopt)

        return run_fn


def parse_nn_args(problem, nn_args):
    worker_dim = OrderedDict()
    task_dim = OrderedDict()
    edge_dim = OrderedDict()
    variable_dim = OrderedDict()
    embed_dict = OrderedDict()

    def set_dim_by_name(name, k, dim):
        if name.startswith("worker_task_"):
            edge_dim[k] = dim
        elif name.startswith("worker_"):
            worker_dim[k] = dim
        elif name.startswith("task_"):
            task_dim[k] = dim
        elif name.endswith("_matrix"):
            edge_dim[k] = dim
        else:
            raise Exception("attribute can't be feature: {}".format(k))

    feature_dict = make_feat_dict(problem)
    variables = [var(problem, problem.batch_size, 1) for var in problem.variables]
    variable_dict = dict([(var.name, var) for var in variables])
    for k, f in feature_dict.items():
        if isinstance(f, VariableFeature):
            var = variable_dict[f.name]
            assert hasattr(var, 'make_feat'), \
                "{} can't be a variable feature, name: {}".format(type(var).__name__, k)
            v = var.make_feat()
            if v.dim() == 2:
                variable_dim[k] = 1
            else:
                variable_dim[k] = v.size(-1)
        elif isinstance(f, SparseLocalFeature):
            edge_dim[k] = 1
            set_dim_by_name(f.value, k, 1)
        elif isinstance(f, LocalFeature):
            edge_dim[k] = 1
            set_dim_by_name(f.name, k, 1)
        elif isinstance(f, LocalCategory):
            edge_dim[k] = 1
        elif isinstance(f, GlobalCategory):
            set_dim_by_name(f.name, k, nn_args.setdefault('encode_hidden_dim', 128))
            embed_dict[k] = f.size
        elif isinstance(f, ContinuousFeature):
            v = problem.feats[k]
            if k.startswith("worker_task_") or k.endswith("_matrix"):
                simple_dim = 3
            else:
                simple_dim = 2

            if v.dim() == simple_dim:
                set_dim_by_name(f.name, k, 1)
            else:
                set_dim_by_name(f.name, k, v.size(-1))
        else:
            raise Exception("unsupported feature type: {}".format(type(f)))

    nn_args['worker_dim'] = worker_dim
    nn_args['task_dim'] = task_dim
    nn_args['edge_dim'] = edge_dim
    nn_args['variable_dim'] = variable_dim
    nn_args['embed_dict'] = embed_dict
    nn_args['feature_dict'] = feature_dict
    return nn_args


def make_feat_dict(problem):
    feature_dict = OrderedDict()

    def add(k, f):
        _f = feature_dict.get(k)
        if _f is None or _f == f:
            feature_dict[k] = f
        else:
            raise Exception("duplicated feature, name: {}, feature1: {}, feature2: {}".format(k, _f, f))

    for f in problem.features:
        if isinstance(f, VariableFeature):
            add(':'.join(['var', f.name]), f)
        elif isinstance(f, SparseLocalFeature):
            add(':'.join([f.index, f.value]), f)
        else:
            add(f.name, f)

    return feature_dict
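The routing in `set_dim_by_name` is purely name-based: the prefix or suffix of a feature attribute decides whether it feeds the worker, task, or edge encoder. A standalone sketch of that dispatch (the feature names here are hypothetical illustrations, not a fixed schema):

def route(name):
    if name.startswith("worker_task_"):
        return "edge"
    elif name.startswith("worker_"):
        return "worker"
    elif name.startswith("task_"):
        return "task"
    elif name.endswith("_matrix"):
        return "edge"
    raise Exception("attribute can't be feature: {}".format(name))

assert route("worker_task_distance") == "edge"
assert route("distance_matrix") == "edge"
assert route("worker_capacity") == "worker"
assert route("task_demand") == "task"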
greedrl/const.py
ADDED
@@ -0,0 +1,7 @@
GRL_WORKER_START = 0
GRL_WORKER_END = 1
GRL_TASK = 2
GRL_FINISH = 3
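These codes tag the first column of each (type, index, done) row that the environment appends to `worker_task_sequence`. A small hypothetical helper for pretty-printing such rows:

from greedrl.const import GRL_WORKER_START, GRL_WORKER_END, GRL_TASK, GRL_FINISH

NAMES = {GRL_WORKER_START: 'worker_start',
         GRL_WORKER_END: 'worker_end',
         GRL_TASK: 'task',
         GRL_FINISH: 'finish'}

def describe(row):
    kind, index, done = row
    return "{}[{}] done={}".format(NAMES[kind], index, done)

print(describe((GRL_TASK, 5, 1)))  # task[5] done=1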
greedrl/decode.py
ADDED
@@ -0,0 +1,196 @@
import math
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.checkpoint import checkpoint


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, hidden_dim):
        super(MultiHeadAttention, self).__init__()

        assert hidden_dim % heads == 0

        self.heads = heads
        head_dim = hidden_dim // heads
        self.alpha = 1 / math.sqrt(head_dim)

        self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
        self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))

        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, K, V, mask):
        batch_size, query_num, hidden_dim = q.size()

        size = (self.heads, batch_size, query_num, -1)

        q = q.reshape(-1, hidden_dim)
        Q = torch.matmul(q, self.nn_Q).view(size)

        value_num = V.size(2)
        heads_batch = self.heads * batch_size
        Q = Q.view(heads_batch, query_num, -1)
        K = K.view(heads_batch, value_num, -1).transpose(1, 2)

        S = masked_tensor(mask, self.heads)
        S = S.view(heads_batch, query_num, value_num)
        S.baddbmm_(Q, K, alpha=self.alpha)
        S = S.view(self.heads, batch_size, query_num, value_num)

        S = F.softmax(S, dim=-1)

        x = torch.matmul(S, V).permute(1, 2, 0, 3)
        x = x.reshape(batch_size, query_num, -1)
        x = torch.matmul(x, self.nn_O)
        return x


class Decode(nn.Module):

    def __init__(self, nn_args):
        super(Decode, self).__init__()

        self.nn_args = nn_args

        heads = nn_args['decode_atten_heads']
        hidden_dim = nn_args['decode_hidden_dim']

        self.heads = heads
        self.alpha = 1 / math.sqrt(hidden_dim)

        if heads > 0:
            assert hidden_dim % heads == 0
            head_dim = hidden_dim // heads
            self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
            self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
            self.nn_mha = MultiHeadAttention(heads, hidden_dim)

        decode_rnn = nn_args.setdefault('decode_rnn', 'LSTM')
        assert decode_rnn in ('GRU', 'LSTM', 'NONE')
        if decode_rnn == 'GRU':
            self.nn_rnn_cell = nn.GRUCell(hidden_dim, hidden_dim)
        elif decode_rnn == 'LSTM':
            self.nn_rnn_cell = nn.LSTMCell(hidden_dim, hidden_dim)
        else:
            self.nn_rnn_cell = None

        self.vars_dim = sum(nn_args['variable_dim'].values())
        if self.vars_dim > 0:
            atten_type = nn_args.setdefault('decode_atten_type', 'add')
            assert atten_type == 'add', "must be additive attention when vars_dim > 0, got: {}".format(atten_type)
            self.nn_A = nn.Parameter(torch.Tensor(self.vars_dim, hidden_dim))
            self.nn_B = nn.Parameter(torch.Tensor(hidden_dim))
        else:
            atten_type = nn_args.setdefault('decode_atten_type', 'prod')

        if atten_type == 'add':
            self.nn_W = nn.Parameter(torch.Tensor(hidden_dim))
        else:
            self.nn_W = None

        for param in self.parameters():
            stdv = 1 / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p, clip, mode, memopt=0):
        if self.training and memopt > 2:
            state1, state2 = checkpoint(self.rnn_step, query, state1, state2)
        else:
            state1, state2 = self.rnn_step(query, state1, state2)

        query = state1
        NP = X.size(0)
        NR = query.size(0) // NP
        batch_size = query.size(0)
        if self.heads > 0:
            query = query.view(NP, NR, -1)
            if self.training and memopt > 1:
                query = checkpoint(self.nn_mha, query, K, V, mask)
            else:
                query = self.nn_mha(query, K, V, mask)

            query = query.view(batch_size, -1)

        if self.nn_W is None:
            query = query.view(NP, NR, -1)
            logit = masked_tensor(mask, 1)
            logit = logit.view(NP, NR, -1)
            X = X.permute(0, 2, 1)
            logit.baddbmm_(query, X, alpha=self.alpha)
            logit = logit.view(batch_size, -1)
        else:
            if self.training and self.vars_dim > 0 and memopt > 0:
                logit = checkpoint(self.atten, query, X, varfeat, mask)
            else:
                logit = self.atten(query, X, varfeat, mask)

        chosen_p = choose(logit, chosen, sample_p, clip, mode)
        return state1, state2, chosen_p

    def rnn_step(self, query, state1, state2):
        if isinstance(self.nn_rnn_cell, nn.GRUCell):
            state1 = self.nn_rnn_cell(query, state1)
        elif isinstance(self.nn_rnn_cell, nn.LSTMCell):
            state1, state2 = self.nn_rnn_cell(query, (state1, state2))
        return state1, state2

    def atten(self, query, keyvalue, varfeat, mask):
        if self.vars_dim > 0:
            varfeat = vfaddmm(varfeat, mask, self.nn_A, self.nn_B)
        return atten(query, keyvalue, varfeat, mask, self.nn_W)


def choose(logit, chosen, sample_p, clip, mode):
    mask = logit == -math.inf
    logit = torch.tanh(logit) * clip
    logit[mask] = -math.inf

    if mode == 0:
        pass
    elif mode == 1:
        chosen[:] = logit.argmax(1)
    elif mode == 2:
        p = logit.exp()
        chosen[:] = torch.multinomial(p, 1).squeeze(1)
    else:
        raise Exception()

    logp = logit.log_softmax(1)
    logp = logp.gather(1, chosen[:, None])
    logp = logp.squeeze(1)
    return logp


def atten(query, keyvalue, varfeat, mask, weight):
    batch_size = query.size(0)
    NP, NK, ND = keyvalue.size()

    query = query.view(NP, -1, 1, ND)
    varfeat = varfeat.view(NP, -1, NK, ND)
    keyvalue = keyvalue[:, None, :, :]
    keyvalue = keyvalue + varfeat + query
    keyvalue = torch.tanh(keyvalue)
    keyvalue = keyvalue.view(-1, ND)

    logit = masked_tensor(mask, 1).view(-1)
    logit.addmv_(keyvalue, weight)
    return logit.view(batch_size, -1)


def masked_tensor(mask, heads):
    size = list(mask.size())
    size.insert(0, heads)
    mask = mask[None].expand(size)
    result = mask.new_zeros(size, dtype=torch.float32)
    result[mask] = -math.inf
    return result


def vfaddmm(varfeat, mask, A, B):
    varfeat = varfeat.permute(0, 2, 1)
    return F.linear(varfeat, A.permute(1, 0), B)
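`choose` first bounds the logits with a tanh clip, then re-applies the infeasibility mask (tanh maps -inf to -1, so without the second masking step an infeasible action would become selectable). A self-contained sketch of the clip and the two decoding modes, greedy argmax versus multinomial sampling, using random logits and a hypothetical infeasible action at column 0:

import math
import torch

logit = torch.randn(2, 5)
logit[:, 0] = -math.inf               # masked (infeasible) action
clip = 10.0

mask = logit == -math.inf
clipped = torch.tanh(logit) * clip    # bounds logits to [-clip, clip]
clipped[mask] = -math.inf             # re-apply the mask after tanh

greedy = clipped.argmax(1)                                # mode == 1
sampled = torch.multinomial(clipped.exp(), 1).squeeze(1)  # mode == 2
assert torch.all(greedy != 0) and torch.all(sampled != 0)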
greedrl/dense.py
ADDED
@@ -0,0 +1,31 @@
from torch import nn

from .utils import get_act
from .norm import Norm1D, Norm2D


class Dense(nn.Module):

    def __init__(self, input_dim, output_dim, bias=True, norm1d='none', norm2d='none', act='none'):
        super(Dense, self).__init__()
        assert norm1d == 'none' or norm2d == 'none', "one of [norm1d, norm2d] must be none"

        if norm1d != 'none':
            self.nn_norm = Norm1D(input_dim, norm1d)
        elif norm2d != 'none':
            self.nn_norm = Norm2D(input_dim, norm2d)
        else:
            self.nn_norm = None

        self.nn_act = get_act(act)
        self.nn_linear = nn.Linear(input_dim, output_dim, bias)

    def weight(self):
        return self.nn_linear.weight

    def forward(self, x):
        if self.nn_norm is not None:
            x = self.nn_norm(x)
        x = self.nn_act(x)
        x = self.nn_linear(x)
        return x
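Dense applies an optional normalization, then the activation, then the linear projection (a pre-activation layout). `get_act` lives in greedrl/utils.py, which is not part of this hunk, so the following minimal stand-in assumes `get_act('relu')` maps to nn.ReLU:

import torch
from torch import nn

class PreActDense(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.norm = nn.InstanceNorm1d(input_dim, eps=1e-10)
        self.act = nn.ReLU()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):  # x: (batch, seq, input_dim)
        x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.act(x)
        return self.linear(x)

out = PreActDense(8, 16)(torch.rand(2, 5, 8))
assert out.shape == (2, 5, 16)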
greedrl/encode.py
ADDED
@@ -0,0 +1,349 @@
import math
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils.checkpoint import checkpoint
from .norm import Norm1D, Norm2D
from .dense import Dense
from .utils import repeat
from .feature import *


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, hidden_dim):
        super(MultiHeadAttention, self).__init__()

        assert hidden_dim % heads == 0

        self.heads = heads
        head_dim = hidden_dim // heads
        self.alpha = 1 / math.sqrt(head_dim)

        self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
        self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
        self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
        self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))

        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, x, edge):
        batch_size, item_num, hidden_dim = x.size()
        size = (self.heads, batch_size, item_num, -1)

        x = x.reshape(-1, hidden_dim)
        Q = torch.matmul(x, self.nn_Q).view(size)
        K = torch.matmul(x, self.nn_K).view(size)
        V = torch.matmul(x, self.nn_V).view(size)

        heads_batch = self.heads * batch_size
        Q = Q.view(heads_batch, item_num, -1)
        K = K.view(heads_batch, item_num, -1).transpose(1, 2)

        if edge is not None:
            S = edge.view(heads_batch, item_num, item_num)
            S = S.baddbmm(Q, K, alpha=self.alpha)
        else:
            S = Q.new_zeros(heads_batch, item_num, item_num)
            S = S.baddbmm_(Q, K, alpha=self.alpha)

        S = S.view(self.heads, batch_size, item_num, item_num)

        S = F.softmax(S, dim=-1)

        x = torch.matmul(S, V).permute(1, 2, 0, 3)
        x = x.reshape(batch_size, item_num, -1)
        x = torch.matmul(x, self.nn_O)
        return x


class Encode(nn.Module):
    def __init__(self, nn_args):
        super(Encode, self).__init__()

        self.nn_args = nn_args
        self.worker_dim = nn_args['worker_dim']
        self.task_dim = nn_args['task_dim']
        self.edge_dim = nn_args['edge_dim']

        self.embed_dict = nn_args['embed_dict']
        self.feature_dict = nn_args['feature_dict']

        layers = nn_args.setdefault('encode_layers', 3)
        heads = nn_args.setdefault('encode_atten_heads', 8)
        norm = nn_args.setdefault('encode_norm', 'instance')
        hidden_dim = nn_args.setdefault('encode_hidden_dim', 128)
        output_dim = nn_args.setdefault('decode_hidden_dim', 128)
        output_heads = nn_args.setdefault('decode_atten_heads', 0)

        self.heads = heads
        self.layers = layers

        worker_dim = max(1, sum(self.worker_dim.values()))
        task_dim = max(1, sum(self.task_dim.values()))

        self.nn_dense_worker_start = Dense(worker_dim, hidden_dim)
        self.nn_dense_worker_end = Dense(worker_dim, hidden_dim)
        self.nn_dense_task = Dense(task_dim, hidden_dim)

        self.nn_norm_worker_task = Norm1D(hidden_dim, norm, True)

        if len(self.edge_dim) > 0:
            edge_dim = sum(self.edge_dim.values())
            self.nn_dense_edge = Dense(edge_dim, heads)
            self.nn_norm_edge = Norm2D(heads, norm, True)

        nn_embed_dict = {}
        for k, v in self.embed_dict.items():
            nn_embed_dict[k] = nn.Embedding(v, hidden_dim)
        self.nn_embed_dict = nn.ModuleDict(nn_embed_dict)

        self.nn_attens = nn.ModuleList()
        self.nn_denses = nn.ModuleList()
        self.nn_norms1 = nn.ModuleList()
        self.nn_norms2 = nn.ModuleList()
        for i in range(layers):
            self.nn_attens.append(MultiHeadAttention(heads, hidden_dim))
            self.nn_denses.append(nn.Sequential(
                Dense(hidden_dim, hidden_dim * 4),
                Dense(hidden_dim * 4, hidden_dim, act='relu'),
            ))
            self.nn_norms1.append(Norm1D(hidden_dim, norm, True))
            self.nn_norms2.append(Norm1D(hidden_dim, norm, True))

        self.nn_finish = nn.Parameter(torch.Tensor(1, 1, hidden_dim))

        if output_dim != hidden_dim:
            self.nn_X = nn.Parameter(torch.Tensor(hidden_dim, output_dim))
        else:
            self.nn_X = None

        if output_heads > 0:
            assert output_dim % output_heads == 0
            head_dim = output_dim // output_heads
            self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
            self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
        else:
            self.nn_K = None
            self.nn_V = None

        for param in self.parameters():
            stdv = 1 / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, problem, batch_size, worker_num, task_num, memopt=0):
        worker_start, worker_end = self.encode_worker(problem, batch_size, worker_num)
        task = self.encode_task(problem, batch_size, task_num)
        X = torch.cat([worker_start, worker_end, task], 1)
        X = self.nn_norm_worker_task(X)

        if len(self.edge_dim) > 0:
            edge = self.encode_edge(problem, batch_size, worker_num, task_num)
            edge = self.nn_norm_edge(edge)
            edge = edge.permute(3, 0, 1, 2).contiguous()
        else:
            edge = None

        # transformer encoding
        for i in range(self.layers):
            X = self.encode_layer(X, edge, i, memopt)

        finish = repeat(self.nn_finish, X.size(0))
        X = torch.cat([X, finish], 1)
        if self.nn_X is not None:
            X = torch.matmul(X, self.nn_X)

        if self.nn_K is not None:
            batch_size, item_num, hidden_dim = X.size()
            size = (self.heads, batch_size, item_num, -1)
            X2 = X.reshape(-1, hidden_dim)
            K = torch.matmul(X2, self.nn_K).view(size)
            V = torch.matmul(X2, self.nn_V).view(size)
        else:
            K = torch.ones(0)
            V = torch.ones(0)
        return X, K, V

    def encode_layer(self, X, edge, i, memopt):
        run_fn = self.encode_layer_fn(i, memopt)
        if self.training and memopt > 6:
            return checkpoint(run_fn, X, edge)
        else:
            return run_fn(X, edge)

    def encode_layer_fn(self, i, memopt):
        def run_fn(X, edge):
            if self.training and memopt == 6:
                X = X + checkpoint(self.nn_attens[i], X, edge)
            else:
                X = X + self.nn_attens[i](X, edge)
            X = self.nn_norms1[i](X)

            X = X + self.nn_denses[i](X)
            X = self.nn_norms2[i](X)
            return X

        return run_fn

    def encode_worker(self, problem, batch_size, worker_num):
        feature_list = []
        for k, dim in self.worker_dim.items():
            f = self.feature_dict.get(k)
            if isinstance(f, GlobalCategory):
                v = problem[f.name]
                v = self.nn_embed_dict[k](v.long())
            elif isinstance(f, ContinuousFeature):
                v = problem[f.name]
            else:
                raise Exception("unsupported feature type: {}".format(type(f)))

            if v.dim() == 2:
                v = v[:, :, None]

            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))

            feature_list.append(v.float())

        if feature_list:
            x = torch.cat(feature_list, 2)
        else:
            x = self.nn_finish.new_ones(batch_size, worker_num, 1)
        return self.nn_dense_worker_start(x), self.nn_dense_worker_end(x)

    def encode_task(self, problem, batch_size, task_num):
        feature_list = []
        for k, dim in self.task_dim.items():
            f = self.feature_dict.get(k)
            if isinstance(f, SparseLocalFeature):
                v = problem[f.value]
                assert v.dim() == 3, \
                    "sparse local feature's dimension must be 3, feature: {}".format(k)
                v = v.clamp(0, 1).sum(2, dtype=v.dtype)
            elif isinstance(f, GlobalCategory):
                v = problem[f.name]
                v = self.nn_embed_dict[k](v.long())
            elif isinstance(f, LocalFeature):
                v = problem[f.name]
                assert v.dim() == 3, \
                    "local feature's dimension must be 3, feature: {}".format(k)
                v = v.clamp(0, 1).sum(2, dtype=v.dtype)
            elif isinstance(f, ContinuousFeature):
                v = problem[f.name]
            else:
                raise Exception("unsupported feature type: {}".format(type(f)))

            if v.dim() == 2:
                v = v[:, :, None]

            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))

            feature_list.append(v.float())

        if feature_list:
            x = torch.cat(feature_list, 2)
        else:
            x = self.nn_finish.new_ones(batch_size, task_num, 1)
        return self.nn_dense_task(x)

    def encode_edge(self, problem, batch_size, worker_num, task_num):
        NP = batch_size
        NW = worker_num
        NT = task_num
        NWW = NW + NW
        feature_list = []
        for k, dim in self.edge_dim.items():
            f = self.feature_dict.get(k)
            if isinstance(f, LocalCategory):
                assert f.name.startswith("task_")

                v = problem[k]
                v1 = v[:, :, None]
                v2 = v[:, None, :]

                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=v.dtype, device=v.device)
                v[:, NWW:, NWW:] = ((v1 == v2) & (v1 >= 0))
            elif isinstance(f, LocalFeature):
                assert f.name.startswith("task_")

                v = problem[k].float()
                dot_product = torch.matmul(v, v.transpose(-1, -2))
                v_norm = v.norm(dim=2) + 1e-10
                v1_norm = v_norm[:, :, None]
                v2_norm = v_norm[:, None, :]

                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=v.dtype, device=v.device)
                v[:, NWW:, NWW:] = dot_product / v1_norm / v2_norm
            elif isinstance(f, SparseLocalFeature):
                assert NP == 1
                assert f.index.startswith("task_")
                assert f.value.startswith("task_")

                index = problem[f.index]
                value = problem[f.value].float()

                NV = index.max().item() + 1
                spv = value.reshape(-1).tolist()
                spi = index.reshape(-1).tolist()

                device = value.device
                spj = torch.arange(NT, device=device)
                spj = spj[:, None].expand_as(index)
                spj = spj.reshape(-1).tolist()

                value1 = torch.sparse_coo_tensor([spj, spi], spv, (NT, NV), device=device)
                value2 = torch.sparse_coo_tensor([spi, spj], spv, (NV, NT), device=device)

                value1 = value1.coalesce()
                value2 = value2.coalesce()
                cosine = torch.sparse.mm(value1, value2).to_dense()

                norm = value.norm(dim=-1).reshape(-1)
                norm1 = norm[:, None].expand(-1, NT)
                norm2 = norm[None, :].expand(NT, -1)
                cosine = cosine / (norm1 * norm2 + 1e-10)

                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=value.dtype, device=value.device)
                v[:, NWW:, NWW:] = cosine

            elif isinstance(f, ContinuousFeature):
                if f.name.endswith("_matrix"):
                    v = problem[k]
                elif f.name.startswith("worker_task_"):
                    v = problem[k]
                    if v.dim() == 3:
                        new_v = torch.zeros(NP, NWW + NT, NWW + NT,
                                            dtype=v.dtype, device=v.device)
                    else:
                        new_v = torch.zeros(NP, NWW + NT, NWW + NT, v.size(3),
                                            dtype=v.dtype, device=v.device)
                    problem_index = torch.arange(NP, device=v.device)[:, None, None]
                    worker_index = torch.arange(NW, device=v.device)[None, :, None]
                    task_index = torch.arange(NT, device=v.device)[None, None, :] + NW + NW
                    new_v[problem_index, worker_index, task_index] = v
                    new_v[problem_index, task_index, worker_index] = v
                    new_v[problem_index, worker_index + NW, task_index] = v
                    new_v[problem_index, task_index, worker_index + NW] = v
                    v = new_v
                else:
                    raise Exception("feature: {}".format(f.name))
            else:
                raise Exception("feature: {}, type: {}".format(k, type(f)))

            if v.dim() == 3:
                v = v[:, :, :, None]

            assert dim == v.size(-1), \
                "feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))

            feature_list.append(v.float())

        x = torch.cat(feature_list, 3)
        return self.nn_dense_edge(x)
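The encoder's attention is edge-biased: a projected per-head edge tensor seeds the score matrix, and baddbmm adds the scaled Q @ K^T on top, so pairwise edge features (distances, group links, cosine similarities) bias every layer. A standalone sketch of just that score computation:

import math
import torch

heads_batch, n, d = 6, 4, 8
Q = torch.rand(heads_batch, n, d)
K = torch.rand(heads_batch, n, d).transpose(1, 2)
edge = torch.rand(heads_batch, n, n)  # per-head edge bias

S = edge.baddbmm(Q, K, alpha=1 / math.sqrt(d))
assert torch.allclose(S, edge + torch.bmm(Q, K) / math.sqrt(d))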
greedrl/feature.py
ADDED
@@ -0,0 +1,63 @@
def continuous_feature(name):
    return ContinuousFeature(name)


class ContinuousFeature:
    def __init__(self, name):
        self.name = name


def global_category(name, size):
    return GlobalCategory(name, size)


class GlobalCategory:
    def __init__(self, name, size):
        self.name = name
        self.size = size


def local_category(name):
    return LocalCategory(name)


class LocalCategory:
    def __init__(self, name):
        assert name.startswith('task_'), \
            "only task feature supported: {}".format(name)
        self.name = name


def local_feature(name):
    return LocalFeature(name)


class LocalFeature:
    def __init__(self, name):
        assert name.startswith('task_'), \
            "only task feature supported: {}".format(name)
        self.name = name


def sparse_local_feature(index, value):
    return SparseLocalFeature(index, value)


class SparseLocalFeature:
    def __init__(self, index, value):
        assert index.startswith('task_'), \
            "only task feature supported for index: {}".format(index)
        assert value.startswith('task_'), \
            "only task feature supported for value: {}".format(value)

        self.index = index
        self.value = value


def variable_feature(name):
    return VariableFeature(name)


class VariableFeature:
    def __init__(self, name):
        self.name = name
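A hedged sketch of declaring features for a CVRP-like problem with these factories; the attribute names (task_demand, distance_matrix, task_group, time_now) are illustrative, not a fixed schema, and a VariableFeature name must match one of the problem's variables:

from greedrl.feature import continuous_feature, variable_feature, local_category

features = [
    continuous_feature('task_demand'),      # per-task scalar -> task encoder
    continuous_feature('distance_matrix'),  # pairwise -> edge encoder
    local_category('task_group'),           # same-group tasks linked in edges
    variable_feature('time_now'),           # dynamic, re-read at every decode step
]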
greedrl/function.py
ADDED
@@ -0,0 +1,5 @@
import greedrl_c
from greedrl_c import task_group_priority
from greedrl_c import task_group_split
greedrl/norm.py
ADDED
@@ -0,0 +1,25 @@
import torch

from torch import nn


class Norm1D(nn.Module):

    def __init__(self, dim, ntype='batch', affine=False):
        super(Norm1D, self).__init__()
        clazz_dict = {'batch': nn.BatchNorm1d, 'instance': nn.InstanceNorm1d}
        self.nn_norm = clazz_dict[ntype](dim, eps=1e-10, affine=affine)

    def forward(self, x):
        return self.nn_norm(x.permute(0, 2, 1)).permute(0, 2, 1)


class Norm2D(nn.Module):

    def __init__(self, dim, ntype='batch', affine=False):
        super(Norm2D, self).__init__()
        clazz_dict = {'batch': nn.BatchNorm2d, 'instance': nn.InstanceNorm2d}
        self.nn_norm = clazz_dict[ntype](dim, eps=1e-10, affine=affine)

    def forward(self, x):
        return self.nn_norm(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
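Why the permutes: nn.BatchNorm1d and nn.InstanceNorm1d normalize a (batch, channels, length) layout, while the encoder carries tensors as (batch, items, hidden), so Norm1D swaps to channels-first and back. A minimal check:

import torch
from torch import nn

x = torch.rand(2, 10, 128)  # (batch, items, hidden)
norm = nn.InstanceNorm1d(128, eps=1e-10)
y = norm(x.permute(0, 2, 1)).permute(0, 2, 1)
assert y.shape == x.shape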
greedrl/pyenv.py
ADDED
@@ -0,0 +1,383 @@
import torch
import json
import math

from collections import OrderedDict
from .const import *
from .utils import to_list
from .norm import Norm1D, Norm2D
from .variable import AttributeVariable, WorkerTaskSequence


class PyEnv(object):

    def __init__(self, problem, batch_size, sample_num, nn_args):
        super(PyEnv, self).__init__()

        self._problem = problem
        self._batch_size = batch_size
        self._sample_num = sample_num
        self._debug = -1

        self._NW = problem.worker_num
        self._NWW = problem.worker_num * 2
        self._NT = problem.task_num
        self._NWWT = self._NWW + self._NT

        self._feats_dict = nn_args['feature_dict']
        self._vars_dim = nn_args['variable_dim']

        self._vars_dict = {}
        self._vars = [var(problem, batch_size, sample_num) for var in problem.variables]
        for variable in self._vars:
            save_variable_version(variable)
            assert variable.name not in self._vars_dict, \
                "duplicated variable, name: {}".format(variable.name)
            self._vars_dict[variable.name] = variable

        self._constraint = problem.constraint()
        self._objective = problem.objective()

        self._worker_index = torch.full((self._batch_size,), -1,
                                        dtype=torch.int64,
                                        device=problem.device)

        self._batch_index = torch.arange(self._batch_size,
                                         dtype=torch.int64,
                                         device=problem.device)

        self._problem_index = torch.div(self._batch_index, sample_num, rounding_mode='trunc')  # self._batch_index // sample_num

        self._feasible = torch.ones(self._batch_size,
                                    dtype=torch.bool,
                                    device=problem.device)

        self._cost = torch.zeros(self._batch_size, self._NT * 2,
                                 dtype=torch.float32,
                                 device=problem.device)

        self._mask = torch.zeros(self._batch_size,
                                 self._NWWT + 1,
                                 dtype=torch.bool,
                                 device=problem.device)

        self._worker_task_sequence = torch.full((self._batch_size, self._NT * 2, 3), -1,
                                                dtype=torch.int64,
                                                device=problem.device)
        self._step = 0
        self.register_variables(self._constraint)
        self._finished = self._constraint.finished()

        if hasattr(self._constraint, 'mask_worker_start'):
            self.register_variables(self._constraint)
            mask_start = self._constraint.mask_worker_start()
        else:
            mask_start = False

        self._mask[:, :self._NW] = mask_start
        self._mask[:, self._NW:] = True

        if self._debug >= 0:
            print("\n$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
            print("new env")
            print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n")

    def time(self):
        return self._step

    def step(self, chosen):
        with torch.no_grad():
            self._do_step(chosen)

    def _do_step(self, chosen):
        if self._debug >= 0:
            print("----------------------------------------------------------------------")
            feasible = self._feasible & ~self._mask[self._problem_index, chosen]
            print("feasible={}".format(feasible[self._debug].tolist()))

        is_start = (chosen >= 0) & (chosen < self._NW)
        if torch.any(is_start):
            b_index = self._batch_index[is_start]
            p_index = self._problem_index[is_start]
            w_index = chosen[is_start]
            self.step_worker_start(b_index, p_index, w_index)

        is_end = (chosen >= self._NW) & (chosen < self._NWW)
        if torch.any(is_end):
            b_index = self._batch_index[is_end]
            p_index = self._problem_index[is_end]
            w_index = chosen[is_end] - self._NW
            self.step_worker_end(b_index, p_index, w_index)

        is_task = (chosen >= self._NWW) & (chosen < self._NWWT)
        if torch.any(is_task):
            b_index = self._batch_index[is_task]
            p_index = self._problem_index[is_task]
            t_index = chosen[is_task] - self._NWW
            step_task_b_index = b_index
            self.step_task(b_index, p_index, t_index)
        else:
            step_task_b_index = None

        is_finish = chosen == self._NWWT
        if torch.any(is_finish):
            b_index = self._batch_index[is_finish]
            self._worker_task_sequence[b_index, self._step, 0] = GRL_FINISH
            self._worker_task_sequence[b_index, self._step, 1] = 0
            self._worker_task_sequence[b_index, self._step, 2] = -1

        self.update_mask(step_task_b_index)

        for var in self._vars:
            check_variable_version(var)

        if self._debug >= 0:
            print("worker_task_sequence[{}]={}".format(self._step,
                                                       self._worker_task_sequence[self._debug, self._step].tolist()))
            for var in self._vars:
                if var.value is None:
                    print("{}={}".format(var.name, None))
                elif isinstance(var, AttributeVariable):
                    print("{}={}".format(var.name, to_list(var.value)))
                else:
                    print("{}={}".format(var.name, to_list(var.value[self._debug])))

        self._step += 1
        if self._step >= self._cost.size(1):
            cost = torch.zeros(self._batch_size, self._step + self._NT,
                               dtype=torch.float32,
                               device=chosen.device)
            cost[:, 0:self._step] = self._cost
            self._cost = cost

            worker_task_sequence = torch.full((self._batch_size, self._step + self._NT, 3), -1,
                                              dtype=torch.int64,
                                              device=chosen.device)
            worker_task_sequence[:, 0:self._step, :] = self._worker_task_sequence
            self._worker_task_sequence = worker_task_sequence

    def step_worker_start(self, b_index, p_index, w_index):
        self._worker_task_sequence[b_index, self._step, 0] = GRL_WORKER_START
        self._worker_task_sequence[b_index, self._step, 1] = w_index
        self._worker_task_sequence[b_index, self._step, 2] = -1
        for var in self._vars:
            if hasattr(var, 'step_worker_start'):
                var.step_worker_start(b_index, p_index, w_index)
                save_variable_version(var)

        if hasattr(self._objective, 'step_worker_start'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_worker_start(), b_index)

        self._worker_index[b_index] = w_index
        self._mask[b_index, :self._NWW] = True
        self._mask[b_index, self._NWW:] = False

    def step_worker_end(self, b_index, p_index, w_index):
        self._worker_task_sequence[b_index, self._step, 0] = GRL_WORKER_END
        self._worker_task_sequence[b_index, self._step, 1] = w_index
        self._worker_task_sequence[b_index, self._step, 2] = -1

        for var in self._vars:
            if hasattr(var, 'step_worker_end'):
                var.step_worker_end(b_index, p_index, w_index)
                save_variable_version(var)

        if hasattr(self._objective, 'step_worker_end'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_worker_end(), b_index)

        self._worker_index[b_index] = -1

        self.register_variables(self._constraint, b_index)
        self._finished[b_index] |= self._constraint.finished()
        if hasattr(self._constraint, 'mask_worker_start'):
            mask_start = self._constraint.mask_worker_start()
        else:
            mask_start = False

        self._mask[b_index, :self._NW] = mask_start
        self._mask[b_index, self._NW:] = True

    def step_task(self, b_index, p_index, t_index):
        self._worker_task_sequence[b_index, self._step, 0] = GRL_TASK
        self._worker_task_sequence[b_index, self._step, 1] = t_index

        for var in self._vars:
            if not hasattr(var, 'step_task'):
                continue
            elif var.step_task.__code__.co_argcount == 4:
                var.step_task(b_index, p_index, t_index)
            else:
                var.step_task(b_index, p_index, t_index, None)
            save_variable_version(var)

        if hasattr(self._constraint, 'do_task'):
            self.register_variables(self._constraint, b_index)
            done = self._constraint.do_task()
            self._worker_task_sequence[b_index, self._step, 2] = done.long()

            for var in self._vars:
                if not hasattr(var, 'step_task'):
                    continue
                elif var.step_task.__code__.co_argcount == 4:
                    pass
                else:
                    check_variable_version(var)
                    var.step_task(b_index, p_index, t_index, done)
                    save_variable_version(var)
        else:
            done = None

        if hasattr(self._objective, 'step_task'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_task(), b_index)

        if hasattr(self._constraint, 'mask_worker_end'):
            self.register_variables(self._constraint, b_index)
            mask_end = self._constraint.mask_worker_end()
        else:
            mask_end = False

        w_index = self._NW + self._worker_index[b_index]
        self._mask[b_index, w_index] = mask_end
        self._mask[b_index, self._NWW:] = False
        return done

    def update_cost(self, cost, b_index=None):
        if isinstance(cost, tuple):
            cost, feasible = cost
            if b_index is None:
                self._feasible &= feasible
            else:
                self._feasible[b_index] &= feasible

        if isinstance(cost, torch.Tensor):
            cost = cost.float()
        else:
            assert type(cost) in (int, float), "unexpected cost's type: {}".format(type(cost))

        if b_index is None:
            self._cost[:, self._step] = cost
        else:
            self._cost[b_index, self._step] = cost

    def update_mask(self, step_task_b_index):
        self._mask |= self._finished[:, None]
        self._mask[:, -1] = ~self._finished
        self.register_variables(self._constraint)
        self._mask[:, self._NWW:self._NWWT] |= self._constraint.mask_task()

        if step_task_b_index is not None:
            b_index = step_task_b_index
            w_index = self._NW + self._worker_index[b_index]
            task_mask = self._mask[b_index, self._NWW:self._NWWT]
            self._mask[b_index, w_index] &= ~torch.all(task_mask, 1)

    def batch_size(self):
        return self._batch_size

    def sample_num(self):
        return self._sample_num

    def mask(self):
        return self._mask.clone()

    def cost(self):
        return self._cost[:, 0:self._step]

    def feasible(self):
        return self._feasible

    def worker_task_sequence(self):
        return self._worker_task_sequence[:, 0:self._step]

    def var(self, name):
        return self._vars_dict[name].value

    def register_variables(self, obj, b_index=None, finished=False):
        for var in self._vars:
            if var.value is None or b_index is None \
                    or isinstance(var, AttributeVariable):
                value = var.value
            else:
                value = var.value[b_index]
            obj.__dict__[var.name] = value

            if not hasattr(var, 'ext_values'):
                continue

            for k, v in var.ext_values.items():
                k = var.name + '_' + k
                obj.__dict__[k] = v[b_index]

    def finished(self):
        return self._finished

    def all_finished(self):
        return torch.all(self.finished())

    def finalize(self):
        self._worker_task_sequence[:, self._step, 0] = GRL_FINISH
        self._worker_task_sequence[:, self._step, 1] = 0
        self._worker_task_sequence[:, self._step, 2] = -1

        for var in self._vars:
            if hasattr(var, 'step_finish'):
                var.step_finish(self.worker_task_sequence())

        if hasattr(self._objective, 'step_finish'):
            self.register_variables(self._objective, finished=True)
            self.update_cost(self._objective.step_finish())

        self._step += 1

    def make_feat(self):
        with torch.no_grad():
            return self.do_make_feat()

    def do_make_feat(self):
        if not self._vars_dim:
            return None

        feature_list = []
        for k, dim in self._vars_dim.items():
            f = self._feats_dict[k]
            var = self._vars_dict[f.name]
            v = var.make_feat()
            if v.dim() == 2:
                v = v[:, :, None]

            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
            feature_list.append(v.float())

        v = torch.cat(feature_list, 2)
        u = v.new_zeros(v.size(0), self._NWW, v.size(2))
        f = v.new_zeros(v.size(0), 1, v.size(2))
        v = torch.cat([u, v, f], 1).permute(0, 2, 1)

        v[self._mask[:, None, :].expand(v.size())] = 0

        norm = v.new_ones(self._mask.size())
        norm[self._mask] = 0
        norm = norm.sum(1) + 1e-10
        norm = norm[:, None, None]

        avg = v.sum(-1, keepdim=True) / norm
        v = v - avg

        std = v.norm(dim=-1, keepdim=True) / norm + 1e-10
        v = v / std
        return v.contiguous()


def save_variable_version(var):
    if isinstance(var.value, torch.Tensor):
        var.__version__ = var.value._version


def check_variable_version(var):
    if isinstance(var.value, torch.Tensor):
        assert var.__version__ == var.value._version, \
            "variable's value is modified, name: {}".format(var.name)
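The action mask laid out by PyEnv concatenates four blocks per instance: worker-start actions, worker-end actions, tasks, and a final finish column. A small sketch of the index arithmetic for a hypothetical instance with 2 workers and 5 tasks:

NW, NT = 2, 5
NWW = NW * 2           # worker starts [0, NW), worker ends [NW, NWW)
NWWT = NWW + NT        # tasks occupy [NWW, NWWT)
mask_width = NWWT + 1  # final column is the 'finish' action

assert mask_width == 10
assert list(range(NWW, NWWT)) == [4, 5, 6, 7, 8]  # task columns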
greedrl/solver.py
ADDED
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import math
|
4 |
+
import copy
|
5 |
+
import time
|
6 |
+
import queue
|
7 |
+
import inspect
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.distributed as dist
|
12 |
+
|
13 |
+
from .agent import Agent, parse_nn_args
|
14 |
+
from .utils import repeat, get_default_device, cutime_stats
|
15 |
+
from .variable import TaskDemandNow
|
16 |
+
|
17 |
+
from torch.nn.utils import clip_grad_norm_, parameters_to_vector, vector_to_parameters
|
18 |
+
from torch.utils.data import Dataset, IterableDataset, DataLoader
|
19 |
+
from torch.optim.lr_scheduler import MultiStepLR
|
20 |
+
|
21 |
+
|
22 |
+
class Problem(object):
|
23 |
+
def __init__(self, isbatch=False):
|
24 |
+
self.isbatch = isbatch
|
25 |
+
self.features = []
|
26 |
+
self.environment = None
|
27 |
+
|
28 |
+
def pin_memory(self):
|
29 |
+
for k, v in self.feats.items():
|
30 |
+
self.feats[k] = v.pin_memory()
|
31 |
+
return self
|
32 |
+
|
33 |
+
def __getattr__(self, name):
|
34 |
+
if name not in ('solution'):
|
35 |
+
raise AttributeError()
|
36 |
+
return self.feats.get(name)


class Solution(object):
    def __init__(self, cost=None):
        self.cost = cost
        self.worker_task_sequence = None


class WrapDataset(Dataset):
    def __init__(self, dataset, solver):
        self._dataset = [solver.to_batch(p) for p in dataset]

    def __getitem__(self, index):
        return self._dataset[index]

    def __len__(self):
        return len(self._dataset)


class WrapIterator:
    def __init__(self, iterator, solver):
        self._iterator = iterator
        self._solver = solver

    def __next__(self):
        p = next(self._iterator)
        p = self._solver.to_batch(p, False)
        return p


class WrapIterableDataset(IterableDataset):
    def __init__(self, dataset, solver):
        self._dataset = dataset
        self._solver = solver

    def __iter__(self):
        return WrapIterator(iter(self._dataset), self._solver)


class CyclicIterator:
    def __init__(self, iterable):
        self._iterable = iterable
        self._iterator = iter(iterable)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = iter(self._iterable)
            return next(self._iterator)


class BufferedIterator:
    def __init__(self, iterator, size, reuse):
        self._iterator = iterator
        self._reuse = reuse
        self._queue = queue.Queue(size)
        self._buffer = []
        self._iter_step = 0

    def __next__(self):
        # fetch a fresh problem while the buffer is filling, then only every
        # `reuse` steps; the queue recycles buffer slots in FIFO order
        if not self._queue.full() or self._iter_step % self._reuse == 0:
            problem = next(self._iterator)
            if self._queue.full():
                index = self._queue.get()
                self._buffer[index] = problem
            else:
                index = len(self._buffer)
                self._buffer.append(problem)
            self._queue.put(index)
        self._iter_step += 1
        index = torch.randint(0, len(self._buffer), (1,)).item()
        return self._buffer[index]


class Solver(object):
    def __init__(self, device=None, nn_args=None):

        if device is None:
            self.device = get_default_device()
        elif device == 'cuda':
            self.device = get_default_device()
            assert self.device.type == 'cuda', 'no cuda device available!'
        else:
            self.device = torch.device(device)

        if nn_args is None:
            nn_args = {}
        self.nn_args = nn_args

        self.agent = None

    def parse_nn_args(self, problem):
        parse_nn_args(problem, self.nn_args)

    def new_agent(self):
        return Agent(self.nn_args)

    def train(self, agent_filename, train_dataset, valid_dataset, **kwargs):
        if dist.is_initialized():
            torch.manual_seed(torch.initial_seed() + dist.get_rank() * 20000)

        train_dataset_workers = kwargs.pop('train_dataset_workers', 1)
        train_dataset_buffers = kwargs.pop('train_dataset_buffers', 2)
        valid_dataset_workers = kwargs.pop('valid_dataset_workers', 1)
        valid_dataset_buffers = kwargs.pop('valid_dataset_buffers', 2)

        train_dataset = self.wrap_dataset(train_dataset, train_dataset_workers,
                                          train_dataset_buffers, torch.initial_seed() + 1)
        valid_dataset = self.wrap_dataset(valid_dataset, valid_dataset_workers,
                                          valid_dataset_buffers, torch.initial_seed() + 10001)

        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)
        else:
            self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)

    def do_train(self, agent_filename, train_dataset, valid_dataset, reuse_buffer=0, reuse_times=1, on_policy=True,
                 advpow=1, batch_size=512, topk_size=1, init_lr=0.0001, sched_lr=(int(1e10),), gamma_lr=0.5,
                 warmup_steps=100, log_steps=-1, optim_steps=1, valid_steps=100, max_steps=int(1e10), memopt=1):

        for arg in inspect.getfullargspec(self.do_train)[0][1:]:
            if arg not in ('train_dataset', 'valid_dataset'):
                print("train_args: {} = {}".format(arg, locals()[arg]))

        if log_steps < 0:
            log_steps = valid_steps

        train_dataset = CyclicIterator(train_dataset)
        if reuse_buffer > 0:
            train_dataset = BufferedIterator(train_dataset, reuse_buffer, reuse_times)

        valid_dataset = list(valid_dataset)

        if dist.is_initialized() and dist.get_rank() != 0:
            dist.barrier()

        if agent_filename is not None and os.path.exists(agent_filename):
            saved_state = torch.load(agent_filename, map_location='cpu')
            self.nn_args = saved_state['nn_args']
        else:
            saved_state = None
            self.parse_nn_args(valid_dataset[0])

        step = 0
        start_step = 0
        self.agent = self.new_agent().train()
        self.agent.to(self.device)
        self.print_nn_args()

        best_agent = copy.deepcopy(self.agent).eval()
        min_valid_cost = math.inf

        optimizer = torch.optim.Adam(self.agent.parameters(), lr=init_lr)
        scheduler = MultiStepLR(optimizer, milestones=sched_lr, gamma=gamma_lr)

        def do_save_state(rng_state, cuda_rng_state):
            if agent_filename is not None:
                save_data = {'step': step, 'rng_state': rng_state}
                if cuda_rng_state is not None:
                    save_data['cuda_rng_state'] = cuda_rng_state
                save_data['nn_args'] = self.agent.nn_args_dict()
                save_data['agent_state'] = self.agent.state_dict()
                save_data['best_agent_state'] = best_agent.state_dict()
                save_data['optimizer_state'] = optimizer.state_dict()
                save_data['scheduler_state'] = scheduler.state_dict()
                torch.save(save_data, agent_filename)

        def valid_sched_save(step):
            if dist.is_initialized():
                params = parameters_to_vector(self.agent.parameters())
                params_clone = params.clone()
                dist.broadcast(params_clone, 0)
                assert torch.all(params == params_clone)

            rng_state = torch.get_rng_state()
            cuda_rng_state = None
            if self.device.type == 'cuda':
                cuda_rng_state = torch.cuda.get_rng_state(self.device)

            print("{} - step={}, validate...".format(time.strftime("%Y-%m-%d %H:%M:%S"), step))
            sys.stdout.flush()

            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)
            start_time = time.time()
            valid_result = self.validate(valid_dataset, batch_size)
            avg_cost1, avg_cost2, avg_feasible = valid_result
            if self.device.type == 'cuda':
                torch.cuda.synchronize(self.device)

            duration = time.time() - start_time

            if step > 0:
                scheduler.step()

            if not dist.is_initialized() or dist.get_rank() == 0:
                do_save_state(rng_state, cuda_rng_state)

            strftime = time.strftime("%Y-%m-%d %H:%M:%S")
            print("{} - step={}, cost=[{:.6g}, {:.6g}], feasible={:.0%}".format(
                strftime, step, avg_cost1, avg_cost2, avg_feasible))
            print("{} - step={}, min_valid_cost={:.6g}, time={:.3f}s".format(
                strftime, step, min(min_valid_cost, avg_cost2), duration))
            print("---------------------------------------------------------------------------------------")
            sys.stdout.flush()
            return avg_cost2

        if saved_state is not None:
            start_step = saved_state['step']

            if not dist.is_initialized() or dist.get_rank() == 0:
                torch.set_rng_state(saved_state['rng_state'])
                if torch.cuda.is_available():
                    torch.cuda.set_rng_state(saved_state['cuda_rng_state'], self.device)

            best_agent.load_state_dict(saved_state['best_agent_state'])
            self.agent.load_state_dict(saved_state['best_agent_state'])

            # if 'agent_state' in saved_state:
            #     self.agent.load_state_dict(saved_state['agent_state'])
            # else:
            #     self.agent.load_state_dict(saved_state['best_agent_state'])

            if 'optimizer_state' in saved_state:
                optimizer.load_state_dict(saved_state['optimizer_state'])
            if 'scheduler_state' in saved_state:
                scheduler.load_state_dict(saved_state['scheduler_state'])
        else:
            if dist.is_initialized() and dist.get_rank() == 0:
                rng_state = torch.get_rng_state()
                cuda_rng_state = None
                if self.device.type == 'cuda':
                    cuda_rng_state = torch.cuda.get_rng_state(self.device)
                do_save_state(rng_state, cuda_rng_state)

        if dist.is_initialized() and dist.get_rank() == 0:
            dist.barrier()

        for step in range(start_step, max_steps):
            if step % valid_steps == 0:
                valid_cost = valid_sched_save(step)
                if valid_cost < min_valid_cost:
                    best_agent.load_state_dict(self.agent.state_dict())
                    min_valid_cost = valid_cost

            start_time = time.time()

            # problem
            with torch.no_grad():
                problem = next(train_dataset)
                if step < warmup_steps:
                    batch_size_now = batch_size // 2
                else:
                    batch_size_now = batch_size
                problem = self.to_device(problem)

            if not on_policy:
                data_agent = best_agent
            else:
                data_agent = self.agent

            data_agent.eval()

            # solution
            if topk_size > 1:
                with torch.no_grad():
                    batch_size_topk = batch_size_now * topk_size
                    env, logp = data_agent(problem, batch_size_topk)
                    cost = env.cost().sum(1).float()
                    solution = env.worker_task_sequence()

                    NP = problem.batch_size
                    NK = batch_size_now // NP
                    NS = solution.size(1)

                    # keep the NK lowest-cost rollouts per problem
                    cost = cost.view(NP, -1)
                    cost, kidx = cost.topk(NK, 1, False, False)
                    cost = cost.view(-1)
                    kidx = kidx[:, :, None, None].expand(-1, -1, NS, 3)
                    solution = solution.view(NP, -1, NS, 3)
                    solution = solution.gather(1, kidx).view(-1, NS, 3)

            elif not on_policy:
                with torch.no_grad():
                    env, logp = data_agent(problem, batch_size_now)
                    cost = env.cost().sum(1).float()
                    solution = env.worker_task_sequence()
            else:
                self.agent.train()
                env, logp = self.agent(problem, batch_size_now, memopt=memopt)
                cost = env.cost().sum(1).float()
                solution = env.worker_task_sequence()

            self.agent.train()

            # advantage
            with torch.no_grad():
                NP = problem.batch_size
                if topk_size > 1:
                    baseline = cost.view(NP, -1).max(1)[0]
                else:
                    baseline = cost.view(NP, -1).mean(1)
                baseline = repeat(baseline, cost.size(0) // NP)
                adv = (cost - baseline)[:, None]
                adv_norm = adv.norm()
                if adv_norm > 0:
                    adv = adv / adv_norm * adv.size(0)
                adv = adv.sign() * adv.abs().pow(advpow)

            # backward
            if topk_size > 1 or not on_policy:
                env, logp = self.agent(problem, batch_size_now, solution=solution, memopt=memopt)

            loss = adv * logp
            loss = loss.mean()
            loss.backward()

            if step % optim_steps == 0:
                if dist.is_initialized():
                    params = filter(lambda a: a.grad is not None, self.agent.parameters())
                    grad_list = [param.grad for param in params]
                    grad_vector = parameters_to_vector(grad_list)
                    dist.all_reduce(grad_vector, op=dist.ReduceOp.SUM)
                    vector_to_parameters(grad_vector, grad_list)

                grad_norm = clip_grad_norm_(self.agent.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()

            if step % log_steps == 0:
                strftime = time.strftime("%Y-%m-%d %H:%M:%S")
                lr = optimizer.param_groups[0]['lr']
                duration = time.time() - start_time
                with torch.no_grad():
                    p = logp.to(torch.float64).sum(1).exp().mean()
                print("{} - step={}, grad={:.6g}, lr={:.6g}, p={:.6g}".format(
                    strftime, step, grad_norm, lr, p))

                print("{} - step={}, cost={:.6g}, time={:.3f}s".format(strftime, step, cost.mean(), duration))
                print("---------------------------------------------------------------------------------------")
                sys.stdout.flush()

        valid_sched_save(step)

    def solve(self, problem, greedy=False, batch_size=512):
        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                return self.do_solve(problem, greedy, batch_size)
        else:
            return self.do_solve(problem, greedy, batch_size)

    def do_solve(self, problem, greedy, batch_size):
        isbatch = problem.isbatch
        problem = self.to_batch(problem)
        problem = self.to_device(problem)

        if self.agent is None:
            self.parse_nn_args(problem)
            self.agent = self.new_agent()
            self.agent.to(self.device)

        self.agent.eval()

        with torch.no_grad():
            env, prob = self.agent(problem, batch_size, greedy, problem.solution)

        NP = problem.batch_size
        NR = prob.size(0) // NP

        prob = prob.view(NP, NR, -1)
        cost = env.cost().sum(1).view(NP, NR)
        feasible = env.feasible().view(NP, NR)
        size = list(env.worker_task_sequence().size())
        size = [NP, NR] + size[1:]
        worker_task_sequence = env.worker_task_sequence().view(size)

        # penalize infeasible rollouts by base_cost so min() prefers feasible ones,
        # then undo the penalty where no feasible rollout exists for a problem
        p_index = torch.arange(NP)
        base_cost = cost.max() + 1
        cost[~feasible] += base_cost
        cost, s_index = cost.min(1)
        feasible = feasible[p_index, s_index]
        cost[~feasible] -= base_cost
        probability = prob[p_index, s_index].exp()
        worker_task_sequence = worker_task_sequence[p_index, s_index]

        if isbatch:
            solution = Solution(cost)
            solution.feasible = feasible
            solution.probability = probability
            solution.worker_task_sequence = worker_task_sequence
        else:
            solution = Solution(cost.item())
            solution.feasible = feasible.item()
            solution.probability = probability.squeeze(0)
            solution.worker_task_sequence = worker_task_sequence.squeeze(0)

        return solution

    def load_agent(self, filename, strict=True):
        if self.device.type == 'cuda':
            with torch.cuda.device(cuda_or_none(self.device)):
                self.do_load_agent(filename, strict)
        else:
            self.do_load_agent(filename, strict)

    def do_load_agent(self, filename, strict=True):
        saved_state = torch.load(filename, map_location='cpu')
        self.nn_args = saved_state['nn_args']

        self.agent = self.new_agent()
        self.agent.to(self.device)
        self.agent.load_state_dict(saved_state['best_agent_state'], strict)
        self.print_nn_args()

    def to_batch(self, problem, pin_memory=True):
        assert not hasattr(problem, 'feats')

        NW = 1
        NT = 1
        NP = 1
        isbatch = problem.isbatch
        for k, v in problem.__dict__.items():
            if k.startswith("worker_"):
                NW = len(v[0]) if isbatch else len(v)
            elif k.startswith("task_"):
                NP = len(v) if isbatch else 1
                NT = len(v[0]) if isbatch else len(v)
        NWW = NW * 2

        new_problem = Problem(True)
        new_problem.feats = {}
        new_problem.device = 'cpu'

        new_problem.batch_size = NP
        new_problem.worker_num = NW
        new_problem.task_num = NT

        new_problem.features = problem.features

        if type(self) == Solver:
            new_problem.variables = problem.variables
            new_problem.constraint = problem.constraint
            new_problem.objective = problem.objective
            new_problem.environment = problem.environment
        else:
            new_problem.variables = []
            new_problem.constraints = problem.constraints
            new_problem.oa_estimate_tasks = problem.oa_estimate_tasks
            new_problem.oa_multiple_steps = problem.oa_multiple_steps

        edge_size_list = ((NWW + NT, NWW + NT), (NW + NT, NW + NT))

        def check_size(f, k, v):
            assert f, "size error, feature: {}, size: {}".format(k, tuple(v.size()))

        for k, v in problem.__dict__.items():
            if k == 'solution' and v is not None:
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() == 3 and v.size(-1) == 3, k, v)
            elif k.startswith("worker_task_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (3, 4) and v.size()[1:3] == (NW, NT), k, v)
            elif k.startswith("worker_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (2, 3) and v.size(1) == NW, k, v)
            elif k.startswith("task_"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (2, 3) and v.size(1) == NT, k, v)
            elif k.endswith("_matrix"):
                v = to_tensor(k, v, isbatch)
                check_size(v.dim() in (3, 4) and v.size()[1:3] in edge_size_list, k, v)
                # expand an (NW+NT)-node matrix to (2*NW+NT) nodes, duplicating the
                # worker rows/columns so worker start and end get separate indices
                if v.size()[1:3] == (NW + NT, NW + NT):
                    worker_index = torch.arange(NW)
                    task_index = torch.arange(NT) + NW
                    index = torch.cat([worker_index, worker_index, task_index])
                    index1 = index[:, None]
                    index2 = index[None, :]
                    v = v[:, index1, index2]
            elif isinstance(v, np.ndarray):
                v = torch.tensor(v)

            if isinstance(v, torch.Tensor):
                new_problem.feats[k] = v

        if pin_memory and self.device.type == 'cuda':
            new_problem.pin_memory()

        return new_problem

    def to_device(self, problem):

        assert hasattr(problem, 'feats')

        new_problem = copy.copy(problem)
        new_problem.device = self.device
        new_problem.feats = {}

        non_blocking = self.device.type == 'cuda'
        for k, v in problem.feats.items():
            v = v.to(self.device, non_blocking=non_blocking)
            new_problem.feats[k] = v

        return new_problem

    def validate(self, problem_list, batch_size):
        self.agent.eval()
        with torch.no_grad():
            valid_result = self.do_validate(problem_list, batch_size)

        self.agent.train()
        return valid_result

    def do_validate(self, problem_list, batch_size):
        total_cost1 = 0
        total_cost2 = 0
        total_feasible = 0
        total_problem = 0
        for problem in problem_list:
            problem = self.to_device(problem)
            env, _ = self.agent(problem, batch_size)

            NP = problem.batch_size
            cost = env.cost().sum(1).view(NP, -1)
            cost1, _ = cost.min(1)
            cost2 = cost.mean(1)
            feasible = env.feasible().view(NP, -1)
            feasible = torch.any(feasible, 1)

            total_cost1 += cost1.sum().item()
            total_cost2 += cost2.sum().item()
            total_feasible += feasible.int().sum().item()
            total_problem += NP

        if dist.is_initialized():
            data = [total_cost1, total_cost2, total_feasible, total_problem]
            data = torch.tensor(data, device=self.device)
            dist.all_reduce(data, op=dist.ReduceOp.SUM)
            total_cost1, total_cost2, total_feasible, total_problem = data.tolist()

        avg_cost1 = total_cost1 / total_problem
        avg_cost2 = total_cost2 / total_problem
        avg_feasible = total_feasible / total_problem

        return avg_cost1, avg_cost2, avg_feasible

    def wrap_dataset(self, dataset, workers, buffers, seed):
        if isinstance(dataset, IterableDataset):
            dataset = WrapIterableDataset(dataset, self)
            dataset = DataLoader(dataset, batch_size=None, pin_memory=True,
                                 num_workers=workers, prefetch_factor=buffers,
                                 worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id))
        else:
            if self.device.type == 'cuda':
                with torch.cuda.device(cuda_or_none(self.device)):
                    dataset = WrapDataset(dataset, self)
                    dataset = DataLoader(dataset, batch_size=None, pin_memory=True, shuffle=True)
            else:
                dataset = WrapDataset(dataset, self)
                dataset = DataLoader(dataset, batch_size=None, pin_memory=True, shuffle=True)

        return dataset

    def print_nn_args(self):
        for key, value in self.nn_args.items():
            if type(value) in [int, float, str, bool]:
                print("nn_args: {} = {}".format(key, value))
        sys.stdout.flush()


def to_tensor(key, value, isbatch):
    if isinstance(value, torch.Tensor):
        tensor = value.to('cpu')
    else:
        tensor = torch.tensor(value, device='cpu')

    if not isbatch:
        tensor = tensor[None]

    return tensor


def cuda_or_none(device):
    return device if device.type == 'cuda' else None
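A minimal usage sketch of the Solver API above, assuming a problem built the way the bundled examples do (make_problem here stands in for e.g. examples/cvrp/cvrp.make_problem; the checkpoint name is illustrative):

from greedrl import Solver

solver = Solver(device=None)  # None picks the CUDA device with the most free memory, else CPU
problems = make_problem(10)   # assumed: a list of Problem instances

# train() writes nn_args, agent/best_agent/optimizer/scheduler state and RNG state to agent.pt
solver.train('agent.pt', problems, problems, batch_size=32, max_steps=100)

solver.load_agent('agent.pt')
solution = solver.solve(problems[0], greedy=False, batch_size=8)  # best of 8 sampled rollouts
print(solution.cost, solution.feasible, solution.worker_task_sequence.shape)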
greedrl/utils.py
ADDED
@@ -0,0 +1,65 @@
import os
import time
import torch

act_dict = {}
act_dict['none'] = lambda x: x
act_dict['relu'] = torch.relu
act_dict['tanh'] = torch.tanh
act_dict['sigmoid'] = torch.sigmoid


def get_act(act):
    return act_dict[act]


def to_list(var):
    if isinstance(var, dict):
        return {k: to_list(v) for k, v in var.items()}
    elif isinstance(var, list):
        return [to_list(v) for v in var]
    elif isinstance(var, tuple):
        return tuple(to_list(v) for v in var)
    elif isinstance(var, torch.Tensor):
        return var.tolist()
    else:
        return var


def repeat(tensor, size, dim=0):
    return tensor.repeat_interleave(size, dim)


def get_default_device():
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # pick the CUDA device with the most free memory, parsed from nvidia-smi
    cmd = 'nvidia-smi -q -d Memory | grep -A4 GPU | grep Free'
    with os.popen(cmd) as result:
        max_free_mem = 0
        max_cuda_index = -1
        for i, line in enumerate(result):
            free_mem = int(line.strip().split()[2])
            if free_mem > max_free_mem:
                max_free_mem = free_mem
                max_cuda_index = i

    return torch.device("cuda:{}".format(max_cuda_index))


def cumem_stats(device, msg):
    torch.cuda.empty_cache()
    print("{}, device:{}, memory_allocated: {:.3f}G".format(
        msg, device, torch.cuda.memory_allocated(device) / (1024 * 1024 * 1024)))


cutime_stats_time = None


def cutime_stats(device, msg=''):
    global cutime_stats_time
    torch.cuda.synchronize(device)
    if cutime_stats_time is not None:
        print("{} time: {:.6f}s".format(msg, time.time() - cutime_stats_time))

    cutime_stats_time = time.time()
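cutime_stats stores its last timestamp in a module-level global, so a pair of calls brackets and times a stretch of CUDA work. A small sketch (the matmul is just a stand-in workload):

import torch
from greedrl.utils import cutime_stats

device = torch.device('cuda:0')
x = torch.randn(1024, 1024, device=device)
cutime_stats(device)            # first call only arms the timer
y = x.matmul(x)                 # some CUDA work
cutime_stats(device, 'matmul')  # synchronizes, then prints "matmul time: ...s"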
greedrl/variable.py
ADDED
@@ -0,0 +1,478 @@
import torch
import functools

from .utils import repeat


class VarMeta(object):
    # factory descriptor: stores the declaration kwargs and instantiates the
    # concrete variable class once per batch
    def __init__(self, clazz, **kwargs):
        self.clazz = clazz
        self._kwargs = kwargs
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __call__(self, problem, batch_size, sample_num):
        kwargs = self._kwargs.copy()
        kwargs['problem'] = problem.feats
        kwargs['batch_size'] = batch_size
        kwargs['sample_num'] = sample_num
        kwargs['worker_num'] = problem.worker_num
        kwargs['task_num'] = problem.task_num
        return self.clazz(**kwargs)


def attribute_variable(name, attribute=None):
    return VarMeta(AttributeVariable, name=name, attribute=attribute)


class AttributeVariable:
    def __init__(self, name, attribute, problem, batch_size, sample_num, worker_num, task_num):
        if attribute is None:
            attribute = name

        self.name = name
        self.value = problem[attribute]


def feature_variable(name, feature=None):
    return VarMeta(FeatureVariable, name=name, feature=feature)


class FeatureVariable:
    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        if feature is None:
            feature = name

        assert feature == 'id' or feature.startswith("worker_") or feature.startswith("task_")

        self.name = name
        self.feature = problem[feature]
        self.value = repeat(self.feature, sample_num)


def task_variable(name, feature=None):
    return VarMeta(TaskVariable, name=name, feature=feature)


class TaskVariable:
    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        if feature is None:
            feature = name

        assert feature.startswith("task_")

        self.name = name
        self.feature = problem[feature]

        size = list(self.feature.size())
        size[0] = batch_size
        del size[1]
        self.value = self.feature.new_zeros(size)

    def step_task(self, b_index, p_index, t_index):
        self.value[b_index] = self.feature[p_index, t_index]


def worker_variable(name, feature=None):
    return VarMeta(WorkerVariable, name=name, feature=feature)


class WorkerVariable:
    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        if feature is None:
            feature = name

        assert feature.startswith("worker_")

        self.name = name
        self.feature = problem[feature]

        size = list(self.feature.size())
        size[0] = batch_size
        del size[1]
        self.value = self.feature.new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        self.value[b_index] = self.feature[p_index, w_index]


def worker_task_variable(name, feature=None):
    return VarMeta(WorkerTaskVariable, name=name, feature=feature)


class WorkerTaskVariable:
    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        if feature is None:
            feature = name

        assert feature.startswith("worker_task_")

        self.name = name
        self.feature = problem[feature]

        size = list(self.feature.size())
        size[0] = batch_size

        del size[1]
        self._feature = self.feature.new_zeros(size)

        del size[2]
        self.value = self.feature.new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        self._feature[b_index] = self.feature[p_index, w_index]

    def step_task(self, b_index, p_index, t_index):
        self.value[b_index] = self._feature[b_index, t_index]


def worker_task_group(name, feature=None):
    return VarMeta(WorkerTaskGroup, name=name, feature=feature)


class WorkerTaskGroup:
    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        if feature is None:
            feature = name

        assert feature.startswith("task_")

        self.name = name
        self.feature = problem[feature].long()

        NG = self.feature.max() + 1
        assert torch.all(self.feature >= 0)

        self.value = self.feature.new_zeros(batch_size, NG)

    def step_worker_start(self, b_index, p_index, w_index):
        self.value[b_index] = 0

    def step_task(self, b_index, p_index, t_index):
        group = self.feature[p_index, t_index]
        self.value[b_index, group] += 1


def worker_task_item(name, item_id, item_num):
    return VarMeta(WorkerTaskItem, name=name, item_id=item_id, item_num=item_num)


class WorkerTaskItem:
    def __init__(self, name, item_id, item_num, problem, batch_size, sample_num, worker_num, task_num):
        assert item_id.startswith('task_')
        assert item_num.startswith('task_')

        self.name = name
        self.item_id = repeat(problem[item_id], sample_num).long()
        self.item_num = repeat(problem[item_num], sample_num)

        assert torch.all(self.item_id >= 0)

        size = [0, 0]
        size[0] = self.item_id.size(0)
        size[1] = self.item_id.max() + 1
        self.value = self.item_num.new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        self.value[b_index] = 0

    def step_task(self, b_index, p_index, t_index):
        item_id = self.item_id[b_index, t_index]
        item_num = self.item_num[b_index, t_index]
        self.value[b_index[:, None], item_id] += item_num

    def make_feat(self):
        NT = self.item_id.size(1)
        v = self.value[:, None, :]
        v = v.expand(-1, NT, -1)

        v = v.gather(2, self.item_id).clamp(0, 1)
        v = self.item_num.clamp(0, 1) - v
        return v.clamp(0, 1).sum(2)


def task_demand_now(name, feature=None, only_this=False):
    return VarMeta(TaskDemandNow, name=name, feature=feature, only_this=only_this)


class TaskDemandNow:
    def __init__(self, name, feature, only_this, problem, batch_size, sample_num, worker_num, task_num):

        if feature is None:
            feature = name

        assert feature.startswith("task_")

        self.name = name
        self.only_this = only_this
        self._value = repeat(problem[feature], sample_num)

        assert self._value.dtype in \
               (torch.int8, torch.int16, torch.int32, torch.int64)
        assert torch.all(self._value >= 0)

        if only_this:
            size = self._value.size(0)
            self.value = self._value.new_zeros(size)
        else:
            self.value = self._value

    def step_task(self, b_index, p_index, t_index, done):
        if done is not None:
            self._value[b_index, t_index] -= done

        if self.only_this:
            self.value[b_index] = self._value[b_index, t_index]
        else:
            self.value = self._value


def worker_count_now(name, feature=None):
    return VarMeta(WorkerCountNow, name=name, feature=feature)


class WorkerCountNow:
    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        if feature is None:
            feature = name

        assert feature.startswith("worker_")

        self.name = name
        self.value = repeat(problem[feature], sample_num)

        assert self.value.dtype in \
               (torch.int8, torch.int16, torch.int32, torch.int64)
        assert torch.all(self.value >= 0)

    def step_worker_start(self, b_index, p_index, w_index):
        self.value[b_index, w_index] -= 1


def edge_variable(name, feature, last_to_this=False,
                  this_to_task=False, task_to_end=False, last_to_loop=False):
    return VarMeta(EdgeVariable, name=name, feature=feature,
                   last_to_this=last_to_this, this_to_task=this_to_task, task_to_end=task_to_end,
                   last_to_loop=last_to_loop)


class EdgeVariable:
    # node indexing: [0, NW) worker starts, [NW, 2*NW) worker ends, [2*NW, 2*NW+NT) tasks
    def __init__(self, name, feature, last_to_this, this_to_task, task_to_end, last_to_loop,
                 problem, batch_size, sample_num, worker_num, task_num):

        if feature is None:
            feature = name

        assert feature.endswith("_matrix")

        flags = [last_to_this, this_to_task, task_to_end, last_to_loop]
        assert flags.count(True) == 1 and flags.count(False) == 3

        self.name = name
        self.last_to_this = last_to_this
        self.this_to_task = this_to_task
        self.task_to_end = task_to_end
        self.last_to_loop = last_to_loop

        self.worker_num = worker_num
        self.task_num = task_num

        self.feature = problem[feature]

        size = list(self.feature.size())
        size[0] = batch_size
        del size[1:3]

        if self.this_to_task or self.task_to_end:
            size.insert(1, task_num)
            self.value = self.feature.new_zeros(size)
        else:
            self.value = self.feature.new_zeros(size)

        self.end_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.loop_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.last_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.task_index = (torch.arange(task_num) + worker_num * 2)[None, :]

    def step_worker_start(self, b_index, p_index, w_index):
        if self.last_to_this:
            self.value[b_index] = 0
            self.last_index[b_index] = w_index
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, w_index)
        elif self.task_to_end:
            self.end_index[b_index] = w_index + self.worker_num
            self.do_task_to_end(b_index, p_index)
        elif self.last_to_loop:
            self.value[b_index] = 0
            self.last_index[b_index] = w_index

    def step_worker_end(self, b_index, p_index, w_index):
        this_index = w_index + self.worker_num
        if self.last_to_this:
            self.do_last_to_this(b_index, p_index, this_index)
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, this_index)
        elif self.task_to_end:
            pass
        elif self.last_to_loop:
            self.do_last_to_loop(b_index, p_index)

    def step_task(self, b_index, p_index, t_index):
        this_index = t_index + self.worker_num * 2
        if self.last_to_this:
            self.do_last_to_this(b_index, p_index, this_index)
            self.last_index[b_index] = this_index
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, this_index)
        elif self.task_to_end:
            pass
        elif self.last_to_loop:
            last_index = self.last_index[b_index]
            loop_index = self.loop_index[b_index]
            self.loop_index[b_index] = torch.where(last_index < self.worker_num, this_index, loop_index)
            self.last_index[b_index] = this_index

    def do_last_to_this(self, b_index, p_index, this_index):
        last_index = self.last_index[b_index]
        self.value[b_index] = self.feature[p_index, last_index, this_index]

    def do_this_to_task(self, b_index, p_index, this_index):
        p_index2 = p_index[:, None]
        this_index2 = this_index[:, None]
        task_index2 = self.task_index
        self.value[b_index] = self.feature[p_index2, this_index2, task_index2]

    def do_task_to_end(self, b_index, p_index):
        p_index2 = p_index[:, None]
        task_index2 = self.task_index
        end_index = self.end_index[b_index]
        end_index2 = end_index[:, None]
        self.value[b_index] = self.feature[p_index2, task_index2, end_index2]

    def do_last_to_loop(self, b_index, p_index):
        loop_index = self.loop_index[b_index]
        last_index = self.last_index[b_index]
        self.value[b_index] = self.feature[p_index, last_index, loop_index]

    def make_feat(self):
        assert self.this_to_task or self.task_to_end, \
            "one of [this_to_task, task_to_end] must be true"
        return self.value.clone()


def worker_used_resource(name, edge_require=None, task_require=None, task_ready=None, worker_ready=None, task_due=None):
    return VarMeta(WorkerUsedResource, name=name, edge_require=edge_require, task_require=task_require,
                   task_ready=task_ready, worker_ready=worker_ready, task_due=task_due)


class WorkerUsedResource:
    def __init__(self, name, edge_require, task_require, task_ready, worker_ready, task_due,
                 problem, batch_size, sample_num, worker_num, task_num):

        assert edge_require is None or edge_require.endswith("_matrix"), "unsupported edge_require: {}".format(edge_require)
        assert task_require is None or task_require.startswith("task_"), "unsupported task_require: {}".format(
            task_require)
        assert task_ready is None or task_ready.startswith("task_"), "unsupported task_ready: {}".format(task_ready)
        assert worker_ready is None or worker_ready.startswith("worker_") and not worker_ready.startswith(
            "worker_task_")
        assert task_due is None or task_due.startswith("task_"), "unsupported task_due: {}".format(task_due)

        self.name = name

        self.worker_num = worker_num
        self.task_num = task_num

        if edge_require is None:
            self.edge_require = None
        else:
            self.edge_require = problem[edge_require]
            self.last_index = self.edge_require.new_zeros(batch_size, dtype=torch.int64)

        if task_require is None:
            self.task_require = None
        else:
            self.task_require = problem[task_require]
            self.task_require2 = repeat(self.task_require, sample_num)

        if task_ready is None:
            self.task_ready = None
        else:
            self.task_ready = problem[task_ready]

        if worker_ready is None:
            self.worker_ready = None
        else:
            self.worker_ready = problem[worker_ready]

        if task_due is None:
            self.task_due = None
        else:
            self.task_due = problem[task_due]

        tensors = [self.edge_require, self.task_require, self.task_ready, self.worker_ready]
        tensors = list(filter(lambda x: x is not None, tensors))
        assert tensors, "at least one of edge_require, task_require, task_ready, worker_ready is required!"

        size = list(tensors[0].size())
        size[0] = batch_size
        if self.edge_require is None:
            del size[1]
        else:
            del size[1:3]

        self.value = tensors[0].new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        if self.worker_ready is None:
            self.value[b_index] = 0
        else:
            self.value[b_index] = self.worker_ready[p_index, w_index]

        if self.edge_require is not None:
            self.last_index[b_index] = w_index

    def step_worker_end(self, b_index, p_index, w_index):
        if self.edge_require is not None:
            last_index = self.last_index[b_index]
            this_index = w_index + self.worker_num
            self.value[b_index] += self.edge_require[p_index, last_index, this_index]
            self.last_index[b_index] = this_index

    def step_task(self, b_index, p_index, t_index, done):
        if done is None:
            if self.edge_require is not None:
                last_index = self.last_index[b_index]
                this_index = t_index + (self.worker_num * 2)
                self.value[b_index] += self.edge_require[p_index, last_index, this_index]
                self.last_index[b_index] = this_index

            if self.task_ready is not None:
                self.value[b_index] = torch.max(self.value[b_index], self.task_ready[p_index, t_index])

        else:
            if self.task_require is not None:
                if self.value.dim() == 2:
                    done = done[:, None]
                self.value[b_index] += self.task_require[p_index, t_index] * done

    def make_feat(self):
        assert self.value.dim() == 2, \
            "value's dim must be 2, actual: {}".format(self.value.dim())
        assert self.task_require is not None, "task_require is required"

        v = self.value[:, None, :] + self.task_require2
        return v.clamp(0, 1).sum(2, dtype=v.dtype)


def worker_task_sequence(name):
    return VarMeta(WorkerTaskSequence, name=name)


class WorkerTaskSequence:
    def __init__(self, name, problem, batch_size, sample_num, worker_num, task_num):
        self.name = name
        self.value = None

    def step_finish(self, worker_task_seq):
        self.value = worker_task_seq
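These factories only build VarMeta descriptors; greedrl instantiates them per batch and drives their step_* hooks as the greedy decode proceeds. A sketch of a CVRP-style declaration, assuming feature names that follow the worker_/task_/_matrix prefixes enforced in Solver.to_batch (the exact set of variables is problem-specific):

variables = [
    task_demand_now('task_demand_now', feature='task_demand'),
    task_demand_now('task_demand_this', feature='task_demand', only_this=True),
    feature_variable('task_weight'),
    worker_variable('worker_weight_limit'),
    worker_used_resource('worker_used_weight', task_require='task_weight'),
    edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
]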
images/GREEDRL-Framwork.png
ADDED
images/GREEDRL-Framwork_en.png
ADDED
images/GREEDRL-Logo-Original-640.png
ADDED
images/GREEDRL-Network.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch==1.12.1+cu113
torchvision==0.13.1+cu113
torchaudio==0.12.1
numpy==1.24.2
Cython==0.29.34
ortools==9.6.2534
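The +cu113 wheels are not hosted on PyPI, so installation typically needs the PyTorch wheel index as an extra source (URL assumed to be the standard one):

pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113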
setup.py
ADDED
@@ -0,0 +1,44 @@
import os
import re
import sys
import time
import subprocess

from distutils import sysconfig
from setuptools import setup, Extension, find_packages
from Cython.Build import build_ext, cythonize


class CMakeExtension(Extension):
    def __init__(self, name, sourcedir=''):
        Extension.__init__(self, name, sources=[])
        self.sourcedir = os.path.abspath(sourcedir)


class CMakeBuild(build_ext):
    def build_extension(self, ext):
        if isinstance(ext, CMakeExtension):
            extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
            if not extdir.endswith(os.path.sep):
                extdir += os.path.sep

            if not os.path.exists(self.build_temp):
                os.makedirs(self.build_temp)

            subprocess.check_call(['cmake', ext.sourcedir, '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir],
                                  cwd=self.build_temp)
            subprocess.check_call(['cmake', '--build', '.', '--', 'VERBOSE=1', '-j8'], cwd=self.build_temp)
        else:
            super().build_extension(ext)


ext_modules = [CMakeExtension('greedrl_c')]

setup(
    name='greedrl',
    version='1.0.0',
    packages=find_packages(),
    ext_modules=ext_modules,
    cmdclass={'build_ext': CMakeBuild},
    install_requires=["torch==1.12.1+cu113"],
)
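Building the greedrl_c extension goes through CMakeBuild, which shells out to the CMakeLists.txt at the repository root, so cmake and a CUDA toolchain must be on PATH. Typical invocations:

pip install .                           # full install
python setup.py build_ext --inplace     # in-tree build of the extension only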
test/all_test.py
ADDED
@@ -0,0 +1,7 @@
import unittest

from solver_test import *
from function_test import *

if __name__ == '__main__':
    unittest.main()
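The test modules import one another by bare module name (import basetest, from solver_test import *), so they are meant to be run from inside the test/ directory, assuming greedrl itself is importable (built or installed):

cd test
python all_test.py              # everything
python -m unittest solver_test  # a single module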
test/basetest.py
ADDED
@@ -0,0 +1,8 @@
import torch
import unittest


class TestCase(unittest.TestCase):
    def tearDown(self):
        torch.cuda.empty_cache()
test/function_test.py
ADDED
@@ -0,0 +1,79 @@
import sys
import time
import torch
import unittest
import basetest

from greedrl import Solver
from greedrl.function import *

device = Solver().device


class TestFunction(basetest.TestCase):

    def test_task_group_split(self):
        group = torch.ones((8, 8), dtype=torch.int32)
        group[:, 0:4] = 0
        value = torch.zeros((8, 8), dtype=torch.bool)
        value[:, 0:4] = True
        result = task_group_split(group, value)
        assert not torch.any(result)

        value[:, 0:2] = False
        result = task_group_split(group, value)
        assert torch.all(result)

    def test_task_group_split2(self):
        group = torch.randint(48, (1024, 1000), dtype=torch.int32)
        value = torch.randint(2, (1024, 1000), dtype=torch.int8) <= 0
        self.do_test(task_group_split, group, value)

    def test_task_group_priority(self):
        group = torch.ones((8, 8), dtype=torch.int32)
        group[:, 0:4] = 0
        priority = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32)
        priority = priority[None, :].expand(8, -1).clone()
        value = torch.zeros((8, 8), dtype=torch.bool)
        value[:, 4:6] = True

        result = task_group_priority(group, priority, value)
        expected = torch.tensor([False, True, True, True, True, True, False, True])
        expected = expected[None, :].expand(8, -1)
        assert torch.all(result == expected)

    def test_task_group_priority2(self):
        group = torch.randint(48, (1024, 1000), dtype=torch.int32)
        value = torch.randint(2, (1024, 1000), dtype=torch.int8) < 1
        priority = torch.randint(2, (1024, 1000), dtype=torch.int32)
        self.do_test(task_group_priority, group, priority, value)

    def do_test(self, function, *args):
        # run on CPU first, then compare against a timed run on the default device
        print("\ntest {} ...".format(function.__name__))
        start = time.time()
        result1 = function(*args)
        print("time: {:.6f}s, device: {}".format(time.time() - start, args[0].device))

        args = [arg.to(device) for arg in args]
        result1 = result1.to(device)

        function(*args)  # warm up
        self.sync_device(device)

        start = time.time()
        result2 = function(*args)
        self.sync_device(device)
        print("time: {:.6f}s, device: {}".format(time.time() - start, args[0].device))

        if result1.is_floating_point():
            assert torch.all(torch.abs(result1 - result2) < 1e-6)
        else:
            assert torch.all(result1 == result2)

    def sync_device(self, device):
        if device.type == 'cuda':
            torch.cuda.synchronize(device)


if __name__ == '__main__':
    unittest.main()
test/solver_test.py
ADDED
@@ -0,0 +1,40 @@
import sys
import os.path as osp
import torch
import unittest

import basetest
from greedrl import Solver
from greedrl.const import *

sys.path.append(osp.join(osp.dirname(osp.abspath(__file__)), "../"))
from examples.cvrp import cvrp


class TestSolver(basetest.TestCase):
    def test(self):
        problem_list = cvrp.make_problem(1)

        nn_args = {}
        nn_args['decode_rnn'] = 'GRU'
        solver = Solver(None, nn_args)

        solver.train(None, problem_list, problem_list,
                     batch_size=32, max_steps=5, memopt=10)

        solver.train(None, problem_list, problem_list,
                     batch_size=32, max_steps=5, memopt=10, topk_size=10)

        solver.train(None, problem_list, problem_list,
                     batch_size=32, max_steps=5, memopt=10, on_policy=False)

        solution = solver.solve(problem_list[0], batch_size=8)
        assert torch.all(solution.worker_task_sequence[:, -1, 0] == GRL_FINISH)
        problem_list[0].solution = solution.worker_task_sequence[:, 0:-1, :]

        solution2 = solver.solve(problem_list[0], batch_size=1)
        assert torch.all(solution.worker_task_sequence == solution2.worker_task_sequence)


if __name__ == '__main__':
    unittest.main()