daisuke.kikuta committed 719d0db (0 parents): "first commit"

This view is limited to 50 files because it contains too many changes.
- .gitignore +7 -0
- .streamlit/config.toml +2 -0
- Dockerfile +38 -0
- LICENSE +125 -0
- README.md +45 -0
- analyze_dataset.py +71 -0
- app.py +455 -0
- eval_classifier.py +241 -0
- eval_solvers.py +62 -0
- generate_cf_dataset.py +145 -0
- generate_dataset.py +117 -0
- install_solvers.py +68 -0
- models/cf_generator.py +155 -0
- models/classifiers/general_classifier.py +60 -0
- models/classifiers/ground_truth/ground_truth.py +35 -0
- models/classifiers/ground_truth/ground_truth_base.py +285 -0
- models/classifiers/ground_truth/ground_truth_cvrp.py +14 -0
- models/classifiers/ground_truth/ground_truth_cvrptw.py +15 -0
- models/classifiers/ground_truth/ground_truth_pctsp.py +16 -0
- models/classifiers/ground_truth/ground_truth_pctsptw.py +15 -0
- models/classifiers/ground_truth/ground_truth_tsptw.py +14 -0
- models/classifiers/meaningless_models.py +62 -0
- models/classifiers/nn_classifiers/attention_graph_encoder.py +95 -0
- models/classifiers/nn_classifiers/decoders/lstm_decoder.py +57 -0
- models/classifiers/nn_classifiers/decoders/mha_decoder.py +74 -0
- models/classifiers/nn_classifiers/decoders/mlp_decoder.py +50 -0
- models/classifiers/nn_classifiers/encoders/attn_edge_encoder.py +81 -0
- models/classifiers/nn_classifiers/encoders/concat_edge_encoder.py +63 -0
- models/classifiers/nn_classifiers/encoders/max_readout.py +57 -0
- models/classifiers/nn_classifiers/encoders/mean_readout.py +57 -0
- models/classifiers/nn_classifiers/encoders/mha_node_encoder.py +63 -0
- models/classifiers/nn_classifiers/encoders/mlp_node_encoder.py +63 -0
- models/classifiers/nn_classifiers/nn_classifier.py +156 -0
- models/classifiers/predictor.py +203 -0
- models/classifiers/rule_based_models.py +150 -0
- models/loss_functions.py +129 -0
- models/prompts/generate_explanation.py +110 -0
- models/prompts/identify_question.py +76 -0
- models/prompts/template_json_base.py +27 -0
- models/route_explainer.py +293 -0
- models/solvers/concorde/concorde.py +127 -0
- models/solvers/concorde/concorde_utils.py +93 -0
- models/solvers/general_solver.py +44 -0
- models/solvers/lkh/lkh.py +21 -0
- models/solvers/lkh/lkh_base.py +157 -0
- models/solvers/lkh/lkh_cvrp.py +41 -0
- models/solvers/lkh/lkh_cvrptw.py +59 -0
- models/solvers/lkh/lkh_tsp.py +15 -0
- models/solvers/lkh/lkh_tsptw.py +32 -0
- models/solvers/ortools/ortools.py +58 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
__pycache__/
*.png
lkh_io_files/
concorde_io_files/
data/
data_generator/
checkpoints/
.streamlit/config.toml
ADDED
@@ -0,0 +1,2 @@
[server]
enableStaticServing = true
Dockerfile
ADDED
@@ -0,0 +1,38 @@
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime

RUN echo "Building docker image"

RUN apt-get -y update && \
    apt-get install -y \
    curl \
    build-essential \
    git \
    vim \
    tmux

# jupyter-notebook & lab
RUN python3 -m pip install jupyter
RUN python3 -m pip install jupyterlab

# LLMs
RUN python3 -m pip install openai
RUN python3 -m pip install tiktoken
RUN python3 -m pip install langchain

# Web app
RUN python3 -m pip install streamlit
RUN python3 -m pip install streamlit-folium
RUN python3 -m pip install folium

# Google Map API
RUN python3 -m pip install googlemaps

# OR-tools
RUN python3 -m pip install ortools

# other convenient packages
RUN python3 -m pip install torchmetrics
RUN python3 -m pip install scipy
RUN python3 -m pip install pandas
RUN python3 -m pip install matplotlib
RUN python3 -m pip install tqdm
LICENSE
ADDED
@@ -0,0 +1,125 @@
SOFTWARE LICENSE AGREEMENT FOR EVALUATION

This SOFTWARE LICENSE AGREEMENT FOR EVALUATION (this "Agreement") is a legal contract between a person
who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone corporation ("NTT").

READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR
OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS
AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER
THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE
SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER
UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY
TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD
TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR
USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE
ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE.


BACKGROUND
A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and
related documentation except OSS listed in Exhibit A to this Agreement.

B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such
a license to User, pursuant and subject to the terms and conditions of this Agreement.

C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement.

In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows:

1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and
conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the
non-commercial purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper
submitted by NTT to a certain academy or technical contest, etc. ("academy"). User may make a reasonable number of
backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1.

2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User
shall be solely responsible for proper installation of the Software.

3. Term. This Agreement is effective whichever is earlier (i) upon User's acceptance of the Agreement, or (ii) upon User's
installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to
any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any
of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the
research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies
it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by User's
decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this
Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof,
or to destroy all such materials and provide written verification of such destruction to NTT.

4. Proprietary Rights
(a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to
this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges
that all patent rights, copyrights and trade secret rights in the Software except OSS shall remain the exclusive property of
NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software.

(b) NTT shall not be subject to the obligation of licensing the copyright, patent rights, etc. of author when user hope
commercial / noncommercial use of the published / provided software, etc.

(c) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT:
(i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY;
(ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER;
(iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT;
(iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE;
OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE.

(d) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted
under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied.

5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage,
or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE
RIGHT TO CONDUCT DEFEND ANY ACTION RELATING TO THE SOFTWARE.

6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT.
NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY,
OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES.
USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS,
AND UTILITY IN A PRODUCTION ENVIRONMENT.

7. Limitation of Liability. IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL,
OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS,
ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES
PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE.
THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT,
INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION
PURSUANT TO SECTION 3.

8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned,
or otherwise transferred by User without NTT's prior written consent.

9. OSS. The OSS included in the software is shown on the "OSS List" in Exhibit A.
User shall be subject to the license term of each OSS, when User uses the software.

10. General
(d) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by
operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this
Agreement shall remain in full force and effect.

(e) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the
subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the
parties relating to that subject matter.

(f) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and
assigns of NTT and User.

(g) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this
Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not
as damages, its attorneys' fees and other costs associated with such action or proceeding.

(h) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles.
All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo
in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association.
The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final
and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof.

(f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT's obligation set
forth under this Agreement due to any cause beyond NTT's reasonable control.

EXHIBIT A
- Software
N/A

- OSS List
+----+---------+-----+
| No | License | OSS |
+----+---------+-----+
| 1  | N/A     | N/A |
+----+---------+-----+
README.md
ADDED
@@ -0,0 +1,45 @@
# RouteExplainer: An Explanation Framework for Vehicle Routing Problem
This repo is the official implementation of "RouteExplainer: An Explanation Framework for Vehicle Routing Problem" (PAKDD 2024). Please see the project page https://ntt-dkiku.github.io/xai-vrp/ for more details.

## Setup
We recommend using Docker to set up the development environment. Please use the [Dockerfile](./Dockerfile) in this repository.
```
docker build -t route_explainer/route_explainer:1.0 .
```
If you use LKH and Concorde, you need to install them with the following command. LKH and Concorde are required for reproducing the experiments, but not for the demo.
```
python install_solvers.py
```
In the following, all commands are assumed to be run inside the Docker container.

## Reproducibility
<!-- Refer to [reproduce_experiments.ipynb](./reproduct_experiments.ipynb). -->
Coming Soon!

## Training and evaluating edge classifiers
### Generating synthetic data with labels
```
python generate_dataset.py --problem tsptw --annotation --parallel
```

### Training
```
python train.py
```

### Evaluation
```
python eval.py
```

## Explanation generation (demo)
Go to http://localhost:8888 after launching the Streamlit app with the following command. You may change the port number as you like.
```
streamlit run app.py --server.port 8888
```

## Licence
Our code is licensed by NTT. The use of our code is basically limited to research purposes. See [LICENSE](./LICENSE) for more details.

## Citation
Coming Soon!
analyze_dataset.py
ADDED
@@ -0,0 +1,71 @@
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
from utils.util_data import load_dataset


def get_cmap(num_colors):
    if num_colors <= 10:
        cm_name = "tab10"
    elif num_colors <= 20:
        cm_name = "tab20"
    else:
        assert False
    return cm.get_cmap(cm_name)

def analyze_dataset(dataset_path, output_dir):
    dataset = load_dataset(dataset_path)

    #-----------------------------
    # Stepwise frequency analysis
    #-----------------------------
    max_steps = len(dataset[0][0]) # num_nodes
    num_labels = 2
    freq = [[] for _ in range(num_labels)]
    weights = [[] for _ in range(num_labels)]
    for instance in dataset:
        labels = instance[-1]
        for step, label in labels:
            freq[label].append(step)
    # visualize histogram
    fig = plt.figure(figsize=(10, 10))
    binwidth = 1
    bins = np.arange(0, max_steps + binwidth, binwidth)
    cmap = get_cmap(num_labels)
    for i in range(len(weights)):
        weights[i] = np.ones(len(freq[i])) / len(dataset)
        plt.hist(freq[i], bins=bins, alpha=0.5, weights=weights[i], ec=cmap(i), color=cmap(i), label="prioritizing tour length", align="left")
    plt.xlabel("Steps")
    plt.ylabel("Frequency (density)")
    if max_steps <= 20:
        plt.xticks(np.arange(0, max_steps+1, 1))
    plt.title(f"# of samples = {len(dataset)}\n# of nodes = {max_steps}")
    plt.legend()
    plt.savefig(f"{output_dir}/hist.png", dpi=150, bbox_inches="tight")

    #-----------------------------
    # Overall ratio of each class
    #-----------------------------
    total = np.sum([len(freq[i]) for i in range(num_labels)])
    ratio = np.array([len(freq[i]) for i in range(num_labels)])
    ratio = ratio / total
    with open(f"{output_dir}/ratio.dat", "w") as f:
        for i in range(len(ratio)):
            print(f"label{i}, {ratio[i]}", file=f)

if __name__ == "__main__":
    import argparse
    import os
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default=None)
    args = parser.parse_args()

    if args.output_dir is None:
        dataset_dir = os.path.split(args.dataset_path)[0]
        output_dir = dataset_dir
    else:
        output_dir = args.output_dir
    output_dir += "/analysis"
    os.makedirs(output_dir, exist_ok=True)
    analyze_dataset(args.dataset_path, output_dir)
app.py
ADDED
@@ -0,0 +1,455 @@
# standard modules
import os
import pickle
import datetime
from PIL import Image
from typing import List, Union

# useful modules ("pip install" is required)
import numpy as np
import streamlit as st
import pandas as pd
import googlemaps
import langchain
from langchain.globals import set_verbose
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage

# our defined modules
import utils.util_app as util_app
from models.solvers.general_solver import GeneralSolver
from models.cf_generator import CFTourGenerator
from models.classifiers.general_classifier import GeneralClassifier
from models.route_explainer import RouteExplainer

# general setting
SEED = 1234
TOUR_NAME = "static/kyoto_tour"
TOUR_PATH = TOUR_NAME + ".csv"
TOUR_LATLNG_PATH = TOUR_NAME + "_latlng.csv"
TOUR_DISTMAT_PATH = TOUR_NAME + "_distmat.pkl"
EXPANDED = False
DEBUG = True
ROUTE_EXPLAINER_ICON = np.array(Image.open("static/route_explainer_icon.png"))

# for debug
if DEBUG:
    langchain.debug = True
    set_verbose(True)

def load_tour_list():
    # get lat/lng
    if os.path.isfile(TOUR_LATLNG_PATH):
        df_tour = pd.read_csv(TOUR_LATLNG_PATH)
    else:
        df_tour = pd.read_csv(TOUR_PATH)
        if googleapi_key := st.session_state.googleapi_key:
            gmaps = googlemaps.Client(key=googleapi_key)
            lat_list = []; lng_list = []
            for destination in df_tour["destination"]:
                geo_result = gmaps.geocode(destination)
                lat_list.append(geo_result[0]["geometry"]["location"]["lat"])
                lng_list.append(geo_result[0]["geometry"]["location"]["lng"])
            # add lat/lng
            df_tour["lat"] = lat_list
            df_tour["lng"] = lng_list
            df_tour.to_csv(TOUR_LATLNG_PATH)

    # get the central point
    st.session_state.lat_mean = np.mean(df_tour["lat"])
    st.session_state.lng_mean = np.mean(df_tour["lng"])
    st.session_state.sw = df_tour[["lat", "lng"]].min().tolist()
    st.session_state.ne = df_tour[["lat", "lng"]].max().tolist()
    st.session_state.df_tour = df_tour

    # get the distance matrix
    if os.path.isfile(TOUR_DISTMAT_PATH):
        with open(TOUR_DISTMAT_PATH, "rb") as f:
            distmat = pickle.load(f)
    else:
        if googleapi_key := st.session_state.googleapi_key:
            gmaps = googlemaps.Client(key=googleapi_key)
            distmat = []
            for origin in df_tour["destination"]:
                distrow = []
                for dest in df_tour["destination"]:
                    if origin != dest:
                        dist_result = gmaps.distance_matrix(origin, dest, mode="driving")
                        distrow.append(dist_result["rows"][0]["elements"][0]["duration"]["value"]) # unit: seconds
                    else:
                        distrow.append(0)
                distmat.append(distrow)
            distmat = np.array(distmat)
            with open(TOUR_DISTMAT_PATH, "wb") as f:
                pickle.dump(distmat, f)

    # input features
    def convert_clock2seconds(clock):
        return sum([a*b for a, b in zip([3600, 60], map(int, clock.split(':')))])
    time_windows = []
    for i in range(len(df_tour)):
        time_windows.append([convert_clock2seconds(df_tour["open"][i]),
                             convert_clock2seconds(df_tour["close"][i])])
    time_windows = np.array(time_windows)
    time_windows -= time_windows[0, 0]
    node_feats = {
        "time_window": time_windows.clip(0),
        "service_time": df_tour["stay_duration (h)"].to_numpy() * 3600
    }
    st.session_state.node_feats = node_feats
    st.session_state.dist_matrix = distmat
    st.session_state.node_info = {
        "open": df_tour["open"],
        "close": df_tour["close"],
        "stay": df_tour["stay_duration (h)"]
    }

    # tour list
    if os.path.isfile(TOUR_DISTMAT_PATH) & os.path.isfile(TOUR_LATLNG_PATH):
        st.session_state.tour_list = []
        for i in range(len(df_tour)):
            st.session_state.tour_list.append({
                "name": df_tour["destination"][i],
                "latlng": (df_tour["lat"][i], df_tour["lng"][i]),
                "description": f"<font color='silver'>Hours: {df_tour['open'][i]} - {df_tour['close'][i]}<br>Duration of stay: {df_tour['stay_duration (h)'][i]}h<br>Remarks: {df_tour['remarks'][i]}</font>"
            })

def solve_vrp() -> None:
    if ("node_feats" in st.session_state) and ("dist_matrix" in st.session_state):
        solver = GeneralSolver("tsptw", "ortools", scaling=False)
        classifier = GeneralClassifier("tsptw", "gt(ortools)")
        routes = solver.solve(node_feats=st.session_state.node_feats,
                              dist_matrix=st.session_state.dist_matrix)
        inputs = classifier.get_inputs(routes,
                                       0,
                                       st.session_state.node_feats,
                                       st.session_state.dist_matrix)
        labels = classifier(inputs)
        st.session_state.routes = routes.copy()
        st.session_state.labels = labels.copy()
        st.session_state.generated_actual_route = True

#----------
# LLM
#----------
def load_route_explainer(llm_type: str) -> None:
    if st.session_state.openai_key:
        # define llm
        llm = ChatOpenAI(model=llm_type,
                         temperature=0,
                         streaming=True,
                         model_kwargs={"seed": SEED})
                         # model_kwargs={"stop": ["\n\n", "Human"]}

        # define RouteExplainer
        cf_generator = CFTourGenerator(cf_solver=GeneralSolver("tsptw", "ortools", scaling=False))
        classifier = GeneralClassifier("tsptw", "gt(ortools)")
        st.session_state.route_explainer = RouteExplainer(llm=llm,
                                                          cf_generator=cf_generator,
                                                          classifier=classifier)

#----------
# UI
#----------
# css settings
st.set_page_config(layout="wide")
util_app.apply_responsible_map_css()
util_app.apply_centerize_icon_css()
util_app.apply_red_code_css()
util_app.apply_remove_sidebar_topspace()

#------------------
# side bar setting
#------------------
with st.sidebar:
    #-------
    # Title
    #-------
    icon_col, name_col = st.columns((1,10))
    with icon_col:
        util_app.apply_html('<a href="https://ntt-dkiku.github.io/xai-vrp" target="_blank"><img src="./app/static/route_explainer_icon.png" alt="RouteExplainer" width="30" height="30" style="margin-top: 20px;"></a>')
    with name_col:
        st.title("RouteExplainer")

    #----------
    # API keys
    #----------
    st.subheader("API keys")
    openai_key_col1, openai_key_col2 = st.columns((1,10))
    with openai_key_col1:
        util_app.apply_html('<a href="https://openai.com/blog/openai-api" target="_blank"> <img src="./app/static/openai_logo.png" alt="OpenAI API" width="30" height="30"> </a>')
    with openai_key_col2:
        openai_key = st.text_input(label="API keys",
                                   key="openai_key",
                                   placeholder="OpenAI API key",
                                   type="password",
                                   label_visibility="collapsed")
    changed_key = openai_key == os.environ.get('OPENAI_API_KEY')
    os.environ['OPENAI_API_KEY'] = openai_key

    google_key_col1, google_key_col2 = st.columns((1, 10))
    with google_key_col1:
        util_app.apply_html('<a href="https://developers.google.com/maps?hl=en" target="_blank"> <img src="./app/static/googlemap_logo.png" alt="GoogleMap API" width="30" height="30"> </a>')
    with google_key_col2:
        st.text_input(label="GoogleMap API key",
                      key="googleapi_key",
                      placeholder="NOT required in this demo",
                      type="password",
                      label_visibility="collapsed")

    #----------------
    # Foundation LLM
    #----------------
    st.subheader("Foundation LLM")
    llm_type = st.selectbox("LLM", ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo"], key="llm_type", label_visibility="collapsed")

    #-----------
    # Tour plan
    #-----------
    st.subheader("Tour plan")
    col1, col2 = st.columns((2, 1))
    with col1:
        # Coming soon: "Taipei Tour (for PAKDD2024)"
        tour_plan = st.selectbox("Tour plan", ["Kyoto Tour"], key="tour_type", label_visibility="collapsed")
    with col2:
        st.button("Generate", on_click=solve_vrp, use_container_width=True)

    # list destinations
    load_tour_list()
    with st.container():
        if "routes" in st.session_state: # rearrange destinations in the route order if a route was derived
            # re-ordered destinations
            reordered_tour_list = [st.session_state.tour_list[i] for i in st.session_state.routes[0][:-1]] if "routes" in st.session_state else st.session_state.tour_list
            arr_time = datetime.datetime.strptime(st.session_state.node_info["open"][0], "%H:%M")
            for step in range(len(reordered_tour_list)):
                curr = reordered_tour_list[step]
                next = reordered_tour_list[step+1] if step != len(reordered_tour_list) - 1 else reordered_tour_list[0]
                curr_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, curr["name"])
                next_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, next["name"])
                open_time = datetime.datetime.strptime(st.session_state.node_info["open"][curr_node_id], "%H:%M")
                # destination info
                dep_time = max(arr_time, open_time) + datetime.timedelta(hours=st.session_state.node_info["stay"][curr_node_id])
                dep_time_str = dep_time.strftime("%H:%M")
                arr_time_str = arr_time.strftime("%H:%M")
                arr_dep = f"Arr {arr_time_str} - Dep {dep_time_str}" if step != 0 else f"⭐ Dep {dep_time_str}"
                with st.expander(f"{arr_dep} | {curr['name']}", expanded=EXPANDED):
                    st.write(curr["description"], unsafe_allow_html=True)
                # travel time
                travel_time = st.session_state.dist_matrix[curr_node_id][next_node_id].item()
                col1, col2, col3 = st.columns(3)
                with col1:
                    st.markdown(f"<center>{util_app.add_time_unit(travel_time)}</center>", unsafe_allow_html=True)
                with col2:
                    st.markdown("<center>|</center>", unsafe_allow_html=True)
                st.write("")
                arr_time = dep_time + datetime.timedelta(seconds=travel_time)
            # return to the origin
            destination = reordered_tour_list[0]
            arr_time_str = arr_time.strftime("%H:%M")
            with st.expander(f"⭐ Arr {arr_time_str} | {destination['name']}", expanded=EXPANDED):
                st.write(destination["description"], unsafe_allow_html=True)
        else: # just list destinations
            for destination in st.session_state.tour_list:
                with st.expander(destination['name'], expanded=EXPANDED):
                    st.write(destination["description"], unsafe_allow_html=True)

#----------------------
# state initialization
#----------------------
if "count" not in st.session_state:
    st.session_state.count = 0
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
if "generated_actual_route" not in st.session_state:
    st.session_state.generated_actual_route = False
if "generated_cf_route" not in st.session_state:
    st.session_state.generated_cf_route = False
if "curr_route" not in st.session_state:
    st.session_state.curr_route = "Actual Route" # once the CF route is selected, this will be "Current Route"
if "flag_example" not in st.session_state:
    st.session_state.flag_example = False
if "selected_example" not in st.session_state:
    st.session_state.selected_example = None
if "close_chat" not in st.session_state:
    st.session_state.close_chat = False
if "route_explainer" not in st.session_state or llm_type != st.session_state.curr_llm_type or changed_key:
    load_route_explainer(llm_type)
    st.session_state.curr_llm_type = llm_type

#--------------------------------
# The following is the main page
#--------------------------------

#----------
# Greeting
#----------
if "routes" not in st.session_state:
    util_app.apply_html('<center> <img src="./app/static/route_explainer_icon.png" alt="OpenAI API" width="120" height="120"> </center>')
    greeting = "Hi, I'm RouteExplainer :)<br>Choose a tour and hit the <code>Generate</code> button to generate your initial route!"
    if st.session_state.count == 0:
        util_app.stream_words(greeting, prefix="<center><h4>", suffix="</h4></center>", sleep_time=0.02)
    else:
        util_app.apply_html(f"<center><h4>{greeting}</h4></center>")

#--------------
# chat history
#--------------
def find_last_map(lst: List[Union[str, tuple]]) -> int:
    for i in range(len(lst) - 1, -1, -1):
        if isinstance(lst[i], tuple):
            return i
    return None
last_map_idx = find_last_map(st.session_state.chat_history)
for i, msg in enumerate(st.session_state.chat_history):
    if isinstance(msg, tuple): # if the history type is a tuple of maps
        map1, map2 = (0, 1) if i == last_map_idx else (2, 3)
        actual_route, cf_route = st.columns(2)
        if msg[map1] is not None:
            with actual_route:
                util_app.visualize_actual_route(msg[map1])
        if msg[map2] is not None:
            with cf_route:
                util_app.visualize_cf_route(msg[map2])
    else: # if the history type is string
        if isinstance(msg, AIMessage):
            st.chat_message(msg.type, avatar=ROUTE_EXPLAINER_ICON).write(msg.content)
        else:
            st.chat_message(msg.type).write(msg.content)

# examples
if "cf_routes" not in st.session_state and st.session_state.flag_example:
    def pickup_example(example: str):
        st.session_state.selected_example = example

    examples = [
        "Why do we visit Ginkaku-ji Temple from Fushimi-Inari Shrine and why not Kiyomizu-dera Temple?",
        "What if we visit Kinkaku-ji directly from Kyoto Geishinkan, instead of Nijo-jo Castle?",
        "Why was the edge from Kinkaku-ji to Kiyomizu-dera selected and why not the edge from Kinkaku-ji to Hanamikoji Dori?"
    ]
    col1, col2, col3 = st.columns(3)
    with col1:
        st.button(examples[0],
                  use_container_width=True,
                  on_click=pickup_example,
                  args=(examples[0], ))
    with col2:
        st.button(examples[1],
                  use_container_width=True,
                  on_click=pickup_example,
                  args=(examples[1], ))
    with col3:
        st.button(examples[2],
                  use_container_width=True,
                  on_click=pickup_example,
                  args=(examples[2], ))

#----------
# chat box
#----------
def answer(prompt: str):
    st.session_state.chat_history.append(HumanMessage(content=prompt))
    st.chat_message("user").write(prompt)
    if os.environ.get('OPENAI_API_KEY') == "":
        error_msg = "An OpenAI API key has not been set yet :( Please enter a valid key in the side bar!"
        with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
            st.write(error_msg)
        st.session_state.chat_history.append(AIMessage(content=error_msg))
    elif util_app.validate_openai_api_key(os.environ.get('OPENAI_API_KEY')):
        with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
            explanation = st.session_state.route_explainer.generate_explanation(tour_list=st.session_state.tour_list,
                                                                                whynot_question=prompt,
                                                                                actual_routes=st.session_state.routes,
                                                                                actual_labels=st.session_state.labels,
                                                                                node_feats=st.session_state.node_feats,
                                                                                dist_matrix=st.session_state.dist_matrix)
        if len(explanation) > 0:
            st.session_state.chat_history.append(AIMessage(content=explanation))
        if st.session_state.generated_cf_route:
            st.rerun()
    else:
        error_msg = "The input OpenAI API key appears to be invalid :( Please enter a valid key again in the side bar!"
        with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
            st.write(error_msg)
        st.session_state.chat_history.append(AIMessage(content=error_msg))

if st.session_state.selected_example is not None:
    example = st.session_state.selected_example
    st.session_state.selected_example = None
    answer(example)
else:
    if "routes" in st.session_state and not st.session_state.close_chat:
        if prompt := st.chat_input(placeholder="Ask a why-not question", key="chat_input"):
            answer(prompt)

#---------------------
# route visualization
#---------------------
if "tour_list" in st.session_state: # if tour info is loaded
    # first message
    if st.session_state.generated_actual_route: # when an actual route is generated
        with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
            msg = "Here is your initial route. Please ask me a why and why-not question for a specific edge!"
            util_app.stream_words(msg, sleep_time=0.01)
            st.session_state.flag_example = True
            st.session_state.chat_history.append(AIMessage(content=msg))

    # visualize the actual & CF routes
    actual_route, cf_route = st.columns(2)
    m = None; m2 = None; m_ = None; m2_ = None
    if st.session_state.generated_actual_route or st.session_state.generated_cf_route:
        m = util_app.initialize_map() # overwrite m
        m_ = util_app.initialize_map()
        if "labels" in st.session_state:
            cf_step = st.session_state.cf_step-1 if st.session_state.generated_cf_route else -1
            util_app.vis_route("routes", st.session_state.labels, m, cf_step, "actual")
            util_app.vis_route("routes", st.session_state.labels, m_, cf_step, "actual", ant_path=False)
        with actual_route:
            util_app.visualize_actual_route(m)
        if st.session_state.generated_cf_route:
            m2 = util_app.initialize_map() # overwrite m2
            m2_ = util_app.initialize_map()
            if "cf_labels" in st.session_state:
                util_app.vis_route("cf_routes", st.session_state.cf_labels, m2, st.session_state.cf_step-1, "cf")
                util_app.vis_route("cf_routes", st.session_state.cf_labels, m2_, st.session_state.cf_step-1, "cf", ant_path=False)
            with cf_route:
                util_app.visualize_cf_route(m2)

    # update states related to maps
    if m is not None:
        st.session_state.chat_history.append((m, m2, m_, m2_))

    # route selection button
    if len(st.session_state.chat_history) > 0:
        last_msg = st.session_state.chat_history[-1]
        if isinstance(last_msg, tuple):
            if (last_msg[0] is not None) and (last_msg[1] is not None):
                col1, col2 = st.columns(2)
                with col1:
                    st.button("Stay this route", on_click=util_app.select_actual_route)
                with col2:
                    st.button("Replace with this route", on_click=util_app.select_cf_route)
                    util_app.change_hover_color("button", "Replace with this route", "#1e90ff")

    # for displaying examples
    if st.session_state.generated_actual_route:
        st.session_state.generated_actual_route = False
        st.rerun()

st.session_state.generated_actual_route = False
st.session_state.generated_cf_route = False

# update session count
st.session_state.count += 1

js = f"""
<script>
    function scroll(dummy_var_to_force_repeat_execution){{
        var textAreas = parent.document.querySelectorAll('section.main');
        for (let index = 0; index < textAreas.length; index++) {{
            textAreas[index].scrollTop = textAreas[index].scrollHeight;
        }}
    }}
    scroll({len(st.session_state.chat_history)})
</script>
"""
st.components.v1.html(js, height=0, width=0)
eval_classifier.py
ADDED
@@ -0,0 +1,241 @@
import os
import argparse
import json
import multiprocessing
import torch
import time
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
from utils.util_calc import TemporalConfusionMatrix
from models.classifiers.nn_classifiers.nn_classifier import NNClassifier
from models.classifiers.ground_truth.ground_truth import GroundTruth
from models.classifiers.ground_truth.ground_truth_base import FAIL_FLAG
from utils.data_utils.tsptw_dataset import TSPTWDataloader
from utils.data_utils.pctsp_dataset import PCTSPDataloader
from utils.data_utils.pctsptw_dataset import PCTSPTWDataloader
from utils.data_utils.cvrp_dataset import CVRPDataloader
from utils.utils import set_device
from utils.utils import load_dataset

def load_eval_dataset(dataset_path, problem, model_type, batch_size, num_workers, parallel, num_cpus):
    if model_type == "nn":
        if problem == "tsptw":
            eval_dataset = TSPTWDataloader(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)
        elif problem == "pctsp":
            eval_dataset = PCTSPDataloader(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)
        elif problem == "pctsptw":
            eval_dataset = PCTSPTWDataloader(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)
        elif problem == "cvrp":
            eval_dataset = CVRPDataloader(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)
        else:
            raise NotImplementedError

        #------------
        # dataloader
        #------------
        def pad_seq_length(batch):
            data = {}
            for key in batch[0].keys():
                padding_value = True if key == "mask" else 0.0
                # post-padding
                data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch], batch_first=True, padding_value=padding_value)
            pad_mask = torch.nn.utils.rnn.pad_sequence([torch.full((d["mask"].size(0), ), True) for d in batch], batch_first=True, padding_value=False)
            data.update({"pad_mask": pad_mask})
            return data
        eval_dataloader = DataLoader(eval_dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     collate_fn=pad_seq_length,
                                     num_workers=num_workers)
        return eval_dataloader
    else:
        eval_dataset = load_dataset(dataset_path)
        return eval_dataset

def eval_classifier(problem: str,
                    dataset,
                    model_type: str,
                    model_dir: str = None,
                    gpu: int = -1,
                    num_workers: int = 4,
                    batch_size: int = 128,
                    parallel: bool = True,
                    solver: str = "ortools",
                    num_cpus: int = 1):
    #--------------
    # gpu settings
    #--------------
    use_cuda, device = set_device(gpu)

    #-------
    # model
    #-------
    num_classes = 3 if problem == "pctsptw" else 2
    if model_type == "nn":
        assert model_dir is not None, "please specify model_path when model_type is nn."
        params = argparse.ArgumentParser()
        # model_dir = os.path.split(args.model_path)[0]
        with open(f"{model_dir}/cmd_args.dat", "r") as f:
            params.__dict__ = json.load(f)
        assert params.problem == problem, "problem of the trained model should match that of the dataset"
        model = NNClassifier(problem=params.problem,
                             node_enc_type=params.node_enc_type,
                             edge_enc_type=params.edge_enc_type,
                             dec_type=params.dec_type,
                             emb_dim=params.emb_dim,
                             num_enc_mlp_layers=params.num_enc_mlp_layers,
                             num_dec_mlp_layers=params.num_dec_mlp_layers,
                             num_classes=num_classes,
                             dropout=params.dropout,
                             pos_encoder=params.pos_encoder)
        # load trained weights (the best epoch)
        with open(f"{model_dir}/best_epoch.dat", "r") as f:
            best_epoch = int(f.read())
        print(f"loaded {model_dir}/model_epoch{best_epoch}.pth.")
        model.load_state_dict(torch.load(f"{model_dir}/model_epoch{best_epoch}.pth"))
        if use_cuda:
            model.to(device)
        is_sequential = model.is_sequential
    elif model_type == "ground_truth":
        model = GroundTruth(problem=problem, solver_type=solver)
        is_sequential = False
    else:
        assert False, f"Invalid model type: {model_type}"

    #---------
    # Metrics
    #---------
    overall_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
    eval_accuracy_dict = {} # MulticlassAccuracy(num_classes=num_classes, average="macro")
    temp_confmat_dict = {} # TemporalConfusionMatrix(num_classes=num_classes, seq_length=50, device=device)
    temporal_accuracy_dict = {}
    num_nodes_dist_dict = {}

    #------------
    # Evaluation
    #------------
    if model_type == "nn":
        model.eval()
        eval_time = 0.0
        print("Evaluating models ...", end="")
        start_time = time.perf_counter()
        for data in dataset:
            if use_cuda:
                data = {key: value.to(device) for key, value in data.items()}
            if not is_sequential:
                shp = data["curr_node_id"].size()
                data = {key: value.flatten(0, 1) for key, value in data.items()}
            probs = model(data) # [batch_size x num_classes] or [batch_size x max_seq_length x num_classes]
            if not is_sequential:
                probs = probs.view(*shp, -1) # [batch_size x max_seq_length x num_classes]
                data["labels"] = data["labels"].view(*shp)
                data["pad_mask"] = data["pad_mask"].view(*shp)
            #------------
            # evaluation
            #------------
            start_eval_time = time.perf_counter()
            # accuracy
            seq_length_list = torch.unique(data["pad_mask"].sum(-1))
            for seq_length_tensor in seq_length_list:
                seq_length = seq_length_tensor.item()
                if seq_length not in eval_accuracy_dict.keys():
                    eval_accuracy_dict[seq_length] = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
                    temp_confmat_dict[seq_length] = TemporalConfusionMatrix(num_classes=num_classes, seq_length=seq_length, device=device)
                    temporal_accuracy_dict[seq_length] = [MulticlassF1Score(num_classes=num_classes, average="macro").to(device) for _ in range(seq_length)]
                    num_nodes_dist_dict[seq_length] = 0
                seq_length_mask = (data["pad_mask"].sum(-1) == seq_length) # [batch_size]
                extracted_labels = data["labels"][seq_length_mask]
                extracted_probs = probs[seq_length_mask]
                extracted_mask = data["pad_mask"][seq_length_mask].view(-1) # [batch_size x max_seq_length] -> [(batch_size*max_seq_length)]
                eval_accuracy_dict[seq_length](extracted_probs.argmax(-1).view(-1)[extracted_mask], extracted_labels.view(-1)[extracted_mask])
                mask = data["pad_mask"].view(-1)
                overall_accuracy(probs.argmax(-1).view(-1)[mask], data["labels"].view(-1)[mask])
                # confusion matrix
                temp_confmat_dict[seq_length].update(probs.argmax(-1), data["labels"], data["pad_mask"])
                # temporal accuracy
                for step in range(seq_length):
                    temporal_accuracy_dict[seq_length][step](extracted_probs[:, step, :], extracted_labels[:, step])
                # number of samples whose sequence length is seq_length
                num_nodes_dist_dict[seq_length] += len(extracted_labels)
            eval_time += time.perf_counter() - start_eval_time
        calc_time = time.perf_counter() - start_time - eval_time
        total_eval_accuracy = {key: value.compute().item() for key, value in eval_accuracy_dict.items()}
        overall_accuracy = overall_accuracy.compute() #.item()
        temporal_confmat = {key: value.compute() for key, value in temp_confmat_dict.items()}
        temporal_accuracy = {key: [value.compute().item() for value in values] for key, values in temporal_accuracy_dict.items()}
        print("done")
        return overall_accuracy, total_eval_accuracy, temporal_accuracy, calc_time, temporal_confmat, num_nodes_dist_dict
    else:
        eval_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
        print("Loading data ...", end=" ")
        with multiprocessing.Pool(num_cpus) as pool:
            input_list = list(pool.starmap(model.get_inputs, [(instance["tour"], 0, instance) for instance in dataset]))
        print("done")

        print("Inferring labels ...", end="")
        pool = multiprocessing.Pool(num_cpus)
        start_time = time.perf_counter()
        prob_list = list(pool.starmap(model, tqdm([(inputs, False, False) for inputs in input_list])))
        calc_time = time.perf_counter() - start_time
        pool.close()
        print("done")

        print("Evaluating models ...", end="")
        for i, instance in enumerate(dataset):
            labels = instance["labels"]
            for vehicle_id in range(len(labels)):
                for step, label in labels[vehicle_id]:
                    pred_label = prob_list[i][vehicle_id][step-1] # [num_classes]
                    if pred_label == FAIL_FLAG:
                        pred_label = label - 1 if label != 0 else label + 1
                    eval_accuracy(torch.LongTensor([pred_label]).view(1, -1), torch.LongTensor([label]).view(1, -1))
        total_eval_accuracy = eval_accuracy.compute()
        print("done")
        return total_eval_accuracy.item(), calc_time

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    #-----------------
    # general settings
    #-----------------
    parser.add_argument("--gpu", default=-1, type=int, help="Used GPU Number: gpu=-1 indicates using cpu")
    parser.add_argument("--num_workers", default=4, type=int, help="Number of workers in dataloader")

    #-------------
    # data setting
    #-------------
    parser.add_argument("--dataset_path", type=str, help="Path to a dataset", required=True)

    #------------------
    # Metrics settings
    #------------------


    #----------------
    # model settings
    #----------------
    parser.add_argument("--model_type", type=str, default="nn", help="Select from [nn, ground_truth]")
    # nn classifier
    parser.add_argument("--model_dir", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--parallel", action="store_true")
    # ground truth
    parser.add_argument("--solver", type=str, default="ortools")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    args = parser.parse_args()

    problem = str(os.path.basename(os.path.dirname(args.dataset_path)))

    dataset = load_eval_dataset(args.dataset_path, problem, args.model_type, args.batch_size, args.num_workers, args.parallel, args.num_cpus)
    eval_classifier(problem=problem,
                    dataset=dataset,
                    model_type=args.model_type,
                    model_dir=args.model_dir,
                    gpu=args.gpu,
                    num_workers=args.num_workers,
                    batch_size=args.batch_size,
                    parallel=args.parallel,
                    solver=args.solver,
                    num_cpus=args.num_cpus)
eval_solvers.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from tqdm import tqdm
import multiprocessing
import numpy as np
from utils.utils import load_dataset, calc_tour_length
from models.solvers.general_solver import GeneralSolver
from models.classifiers.ground_truth.ground_truth import GroundTruth

def eval_solver(solver, instance):
    tour = solver.solve(instance)
    tour_length = calc_tour_length(tour[0], instance["coords"])
    return tour_length

def eval(data_path, problem, solver_name, fix_edges, parallel):
    dataset = load_dataset(data_path)
    num_cpus = os.cpu_count() if parallel else 1
    if fix_edges:
        solver = GroundTruth(problem, solver_name)
        if parallel:
            with multiprocessing.Pool(num_cpus) as pool:
                tours = list(tqdm(pool.starmap(solver.solve, [(step, instance["tour"][vehicle_id], instance, f"{i}-{vehicle_id}-{step}")
                                                              for i, instance in enumerate(dataset)
                                                              for vehicle_id in range(len(instance["tour"]))
                                                              for step in range(1, len(instance["tour"][vehicle_id]))]), desc=f"Solving {data_path} with {solver_name}"))
        else:
            tours = []
            for i, instance in enumerate(dataset):
                for vehicle_id in range(len(instance["tour"])):
                    for step in range(1, len(instance["tour"][vehicle_id])):
                        tours.append(solver.solve(step, instance["tour"][vehicle_id], instance, f"{i}-{vehicle_id}-{step}"))
        tour_length = {key: [] for key in tours[0].keys()}
        for tour in tours:
            for key, value in tour.items():
                tour_length[key].append(value)
    else:
        solver = GeneralSolver(problem, solver_name)
        with multiprocessing.Pool(num_cpus) as pool:
            tour_length = list(tqdm(pool.starmap(eval_solver, [(solver, instance) for instance in dataset]), total=len(dataset), desc="Solving instances"))

    feasible_ratio = 0.0
    penalty = 0.0
    # tour_length is a dict of lists when fix_edges is set, otherwise a plain list
    lengths = tour_length["tsp"] if isinstance(tour_length, dict) else tour_length
    avg_tour_length = np.mean(lengths)
    std_tour_length = np.std(lengths)
    return avg_tour_length, std_tour_length, feasible_ratio, penalty

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--problem", default="tsptw", type=str, help="Problem type: [tsptw, pctsp, pctsptw, cvrp]")
    parser.add_argument("--solver_name", type=str, default="ortools", help="Select from [ortools, lkh, concorde]")
    parser.add_argument("--data_path", type=str, help="Path to a dataset", required=True)
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--fix_edges", action="store_true")
    args = parser.parse_args()

    avg_tour_length, std_tour_length, feasible_ratio, penalty = eval(data_path=args.data_path,
                                                                     problem=args.problem,
                                                                     solver_name=args.solver_name,
                                                                     fix_edges=args.fix_edges,
                                                                     parallel=args.parallel)
    print(f"tour_length: {avg_tour_length} +/- {std_tour_length}")
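For reference, a minimal usage sketch of the eval() entry point above; the dataset path below is a hypothetical placeholder, not a file included in this commit:

# Minimal usage sketch (hypothetical dataset path); equivalent to:
#   python eval_solvers.py --problem tsptw --solver_name ortools --data_path <dataset.pkl>
from eval_solvers import eval

avg_len, std_len, feasible_ratio, penalty = eval(
    data_path="data/tsptw/eval_tsptw_20nodes_100samples_seed1234.pkl",  # placeholder path
    problem="tsptw",
    solver_name="ortools",
    fix_edges=False,
    parallel=False,
)
print(f"tour_length: {avg_len} +/- {std_len}")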
generate_cf_dataset.py
ADDED
@@ -0,0 +1,145 @@
import random
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool
from utils.utils import load_dataset, save_dataset
from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_tw_mask, get_cap_mask
from models.classifiers.ground_truth.ground_truth import GroundTruth
from models.solvers.general_solver import GeneralSolver
from models.cf_generator import CFTourGenerator


class CFDatasetBase():
    def __init__(self, problem, cf_generator, classifier, base_dataset, num_samples, random_seed, parallel, num_cpus):
        self.problem = problem
        self.parallel = parallel
        self.num_cpus = num_cpus
        self.seed = random_seed
        self.cf_generator = CFTourGenerator(cf_solver=GeneralSolver(problem, cf_generator))
        self.classifier = GroundTruth(problem, classifier)
        self.node_mask = NodeMask(problem)
        self.dataset = load_dataset(base_dataset)
        self.num_samples = len(self.dataset) if num_samples is None else num_samples

    def generate_cf_dataset(self):
        random.seed(self.seed)
        cf_dataset = []
        num_required_samples = self.num_samples
        end = False
        print("Data generation started.", flush=True)
        while(not end):
            dataset = self.dataset[:num_required_samples]
            self.dataset = np.roll(self.dataset, -num_required_samples)
            if self.parallel:
                instances = self.generate_labeldata_para(dataset, self.num_cpus)
            else:
                instances = self.generate_labeldata(dataset)
            cf_dataset.extend(filter(None, instances))
            num_required_samples = self.num_samples - len(cf_dataset)
            if num_required_samples == 0:
                end = True
            else:
                print(f"No feasible tour was found in {num_required_samples} instances. Trying other {num_required_samples} instances.", flush=True)
        print("Data generation completed.", flush=True)
        return cf_dataset

    def generate_labeldata(self, dataset):
        return [self.annotate(instance) for instance in tqdm(dataset, desc="Annotating instances")]

    def generate_labeldata_para(self, dataset, num_cpus):
        with Pool(num_cpus) as pool:
            annotation_data = list(tqdm(pool.imap(self.annotate, [instance for instance in dataset]), total=len(dataset), desc="Annotating instances"))
        return annotation_data

    def annotate(self, instance):
        # generate a counterfactual route randomly
        routes = instance["tour"]
        vehicle_id = random.randint(0, len(routes) - 1)
        if len(routes[vehicle_id]) - 2 <= 2:
            return
        cf_step = random.randint(2, len(routes[vehicle_id]) - 2)
        route = routes[vehicle_id]
        mask = self.node_mask.get_mask(route, cf_step, instance)
        node_id = np.arange(len(instance["coords"]))
        feasible_node_id = node_id[mask]
        feasible_node_id = feasible_node_id[feasible_node_id != route[cf_step]].tolist()
        if len(feasible_node_id) == 0:
            return
        cf_visit = random.choice(feasible_node_id)
        cf_routes = self.cf_generator(routes, vehicle_id, cf_step, cf_visit, instance)
        if cf_routes is None:
            return

        # annotate each edge
        inputs = self.classifier.get_inputs(cf_routes, 0, instance)
        labels = self.classifier(inputs, annotation=True)

        # update tours and labels
        instance["tour"] = cf_routes
        instance["labels"] = labels
        return instance

class NodeMask():
    def __init__(self, problem):
        self.problem = problem

        if self.problem == "tsptw":
            self.mask_func = get_tsptw_mask
        elif self.problem == "pctsp":
            self.mask_func = get_pctsp_mask
        elif self.problem == "pctsptw":
            self.mask_func = get_pctsptw_mask
        elif self.problem == "cvrp":
            self.mask_func = get_cvrp_mask
        else:
            raise NotImplementedError

    def get_mask(self, route, step, instance):
        return self.mask_func(route, step, instance)

def get_tsptw_mask(route, step, instance):
    visited = get_visited_mask(route, step, instance)
    not_exceed_tw = get_tw_mask(route, step, instance)
    return ~visited & not_exceed_tw

def get_pctsp_mask(route, step, instance):
    visited = get_visited_mask(route, step, instance)
    return ~visited

def get_pctsptw_mask(route, step, instance):
    visited = get_visited_mask(route, step, instance)
    not_exceed_tw = get_tw_mask(route, step, instance)
    return ~visited & not_exceed_tw

def get_cvrp_mask(route, step, instance):
    visited = get_visited_mask(route, step, instance)
    less_than_cap = get_cap_mask(route, step, instance)
    return ~visited & less_than_cap

if __name__ == "__main__":
    import os
    import argparse
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--base_dataset", type=str, required=True)
    parser.add_argument("--cf_generator", type=str, default="ortools")
    parser.add_argument("--classifier", type=str, default="ortools")
    parser.add_argument("--num_samples", type=int, default=None)
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=4)
    parser.add_argument("--output_dir", type=str, default="data")
    args = parser.parse_args()

    dataset_gen = CFDatasetBase(args.problem,
                                args.cf_generator,
                                args.classifier,
                                args.base_dataset,
                                args.num_samples,
                                args.random_seed,
                                args.parallel,
                                args.num_cpus)
    cf_dataset = dataset_gen.generate_cf_dataset()

    output_fname = f"{args.output_dir}/{args.problem}/cf_{dataset_gen.num_samples}samples_seed{args.random_seed}_base_{os.path.basename(args.base_dataset)}.pkl"
    save_dataset(cf_dataset, output_fname)
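A toy sketch (made-up boolean arrays, not repository data) of how the per-problem masks above combine to pick counterfactual candidates:

# Toy sketch of the TSPTW mask composition above: a node is a counterfactual
# candidate only if it is unvisited AND still reachable within its time window.
import numpy as np

visited       = np.array([True,  True,  False, False, False])  # nodes 0 and 1 are already on the route
not_exceed_tw = np.array([True,  True,  True,  False, True ])  # node 3 can no longer be reached in time
candidates = ~visited & not_exceed_tw
print(candidates)  # [False False  True False  True] -> nodes 2 and 4 may be swapped in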
generate_dataset.py
ADDED
@@ -0,0 +1,117 @@
from utils.data_utils.tsptw_dataset import TSPTWDataset
from utils.data_utils.pctsp_dataset import PCTSPDataset
from utils.data_utils.pctsptw_dataset import PCTSPTWDataset
from utils.data_utils.cvrp_dataset import CVRPDataset
from utils.data_utils.cvrptw_dataset import CVRPTWDataset
from utils.utils import save_dataset

def generate_dataset(num_samples, args):
    if args.problem == "tsptw":
        data_generator = TSPTWDataset(coord_dim=args.coord_dim,
                                      num_samples=num_samples,
                                      num_nodes=args.num_nodes,
                                      random_seed=args.random_seed,
                                      solver=args.solver,
                                      classifier=args.classifier,
                                      annotation=args.annotation,
                                      parallel=args.parallel,
                                      num_cpus=args.num_cpus,
                                      distribution=args.distribution)
    elif args.problem == "pctsp":
        data_generator = PCTSPDataset(coord_dim=args.coord_dim,
                                      num_samples=num_samples,
                                      num_nodes=args.num_nodes,
                                      random_seed=args.random_seed,
                                      solver=args.solver,
                                      classifier=args.classifier,
                                      annotation=args.annotation,
                                      parallel=args.parallel,
                                      num_cpus=args.num_cpus,
                                      penalty_factor=args.penalty_factor)
    elif args.problem == "pctsptw":
        data_generator = PCTSPTWDataset(coord_dim=args.coord_dim,
                                        num_samples=num_samples,
                                        num_nodes=args.num_nodes,
                                        random_seed=args.random_seed,
                                        solver=args.solver,
                                        classifier=args.classifier,
                                        annotation=args.annotation,
                                        parallel=args.parallel,
                                        num_cpus=args.num_cpus,
                                        penalty_factor=args.penalty_factor)
    elif args.problem == "cvrp":
        data_generator = CVRPDataset(coord_dim=args.coord_dim,
                                     num_samples=num_samples,
                                     num_nodes=args.num_nodes,
                                     random_seed=args.random_seed,
                                     solver=args.solver,
                                     classifier=args.classifier,
                                     annotation=args.annotation,
                                     parallel=args.parallel,
                                     num_cpus=args.num_cpus)
    elif args.problem == "cvrptw":
        data_generator = CVRPTWDataset(coord_dim=args.coord_dim,
                                       num_samples=num_samples,
                                       num_nodes=args.num_nodes,
                                       random_seed=args.random_seed,
                                       solver=args.solver,
                                       classifier=args.classifier,
                                       annotation=args.annotation,
                                       parallel=args.parallel,
                                       num_cpus=args.num_cpus)
    else:
        raise NotImplementedError

    return data_generator.generate_dataset()

if __name__ == "__main__":
    import argparse
    import os
    import numpy as np
    parser = argparse.ArgumentParser(description='')
    # common settings
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--data_type", type=str, nargs="*", default=["all"], help="data type: 'all' or combo. of ['train', 'valid', 'test'].")
    parser.add_argument("--num_samples", type=int, nargs="*", default=[1000, 100, 100])
    parser.add_argument("--num_nodes", type=int, default=20)
    parser.add_argument("--coord_dim", type=int, default=2, help="only coord_dim=2 is supported for now.")
    parser.add_argument("--solver", type=str, default="ortools", help="solver that outputs a tour")
    parser.add_argument("--classifier", type=str, default="ortools", help="classifier for annotation")
    parser.add_argument("--annotation", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    parser.add_argument("--output_dir", type=str, default="data")
    # for TSPTW
    parser.add_argument("--distribution", type=str, default="da_silva")
    # for PCTSP
    parser.add_argument("--penalty_factor", type=float, default=3.)
    args = parser.parse_args()

    # 3d problems are not supported
    assert args.coord_dim == 2, "only coord_dim=2 is supported for now."

    # calc num. of total samples (train + valid + test samples)
    if args.data_type[0] == "all":
        assert len(args.num_samples) == 3, "please specify # samples for each of the three types (train/valid/test) when you set data_type 'all'. (e.g., --num_samples 1280000 1000 1000)"
    else:
        assert len(args.data_type) == len(args.num_samples), "please match # data_types and # elements in num_samples-arg"
    num_samples = np.sum(args.num_samples)

    # generate a dataset
    dataset = generate_dataset(num_samples, args)

    # split the dataset
    if args.data_type[0] == "all":
        types = ["train", "valid", "eval"]
    else:
        types = args.data_type
    num_sample_list = args.num_samples
    num_sample_list.insert(0, 0)
    start = 0
    for i, type_name in enumerate(types):
        start += num_sample_list[i]
        end = start + num_sample_list[i+1]
        divided_dataset = dataset[start:end]
        output_fname = f"{args.output_dir}/{args.problem}/{type_name}_{args.problem}_{args.num_nodes}nodes_{num_sample_list[i+1]}samples_seed{args.random_seed}.pkl"
        save_dataset(divided_dataset, output_fname)
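A short worked example (toy numbers) of the train/valid/eval split indexing performed in the loop above:

# With --num_samples 1000 100 100 and --data_type all, the slices become:
num_samples = [1000, 100, 100]
num_sample_list = [0] + num_samples
start = 0
for i, type_name in enumerate(["train", "valid", "eval"]):
    start += num_sample_list[i]
    end = start + num_sample_list[i + 1]
    print(type_name, start, end)
# train 0 1000, valid 1000 1100, eval 1100 1200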
install_solvers.py
ADDED
@@ -0,0 +1,68 @@
import os
import urllib.request
import subprocess

def _run(cmd, cwd):
    subprocess.check_call(cmd, shell=True, cwd=cwd)

def install_concorde():
    QSOPT_A_URL = "https://www.math.uwaterloo.ca/~bico/qsopt/beta/codes/PIC/qsopt.PIC.a"
    QSOPT_H_URL = "https://www.math.uwaterloo.ca/~bico/qsopt/beta/codes/PIC/qsopt.h"
    CONCORDE_URL = "https://www.math.uwaterloo.ca/tsp/concorde/downloads/codes/src/co031219.tgz"

    concorde_path = "models/solvers/concorde"
    concorde_src_path = f"{concorde_path}/src"
    os.makedirs(concorde_src_path, exist_ok=True)
    # download qsopt, which is a dependency library
    print("Downloading QSOPT...", end=" ", flush=True)
    qsopt_path = f"{concorde_src_path}/qsopt"
    qsopt_a_path = f"{qsopt_path}/qsopt.a"
    qsopt_h_path = f"{qsopt_path}/qsopt.h"
    os.makedirs(qsopt_path, exist_ok=True)
    urllib.request.urlretrieve(QSOPT_A_URL, qsopt_a_path)
    urllib.request.urlretrieve(QSOPT_H_URL, qsopt_h_path)
    print("done")

    # download concorde tsp
    print("Downloading Concorde TSP...", end=" ", flush=True)
    concorde_tgz_path = f"{concorde_src_path}/concorde.tgz"
    urllib.request.urlretrieve(CONCORDE_URL, concorde_tgz_path)
    print("done")

    # build concorde
    _run("tar -xzf concorde.tgz", concorde_src_path)
    _run("mv concorde/* .", concorde_src_path)
    _run("rm -r concorde.tgz concorde", concorde_src_path)
    cflags = "-fPIC -O2 -g"
    datadir = os.path.abspath(qsopt_path)
    cmd = f"CFLAGS='{cflags}' ./configure --prefix {datadir} --with-qsopt={datadir}"
    _run(cmd, concorde_src_path)
    _run("make", concorde_src_path)

def install_lkh():
    LKH_URL = "http://webhotel4.ruc.dk/~keld/research/LKH-3/LKH-3.0.8.tgz"

    lkh_path = "models/solvers/lkh"
    lkh_src_path = f"{lkh_path}/src"
    os.makedirs(lkh_src_path, exist_ok=True)

    # download LKH
    urllib.request.urlretrieve(LKH_URL, f"{lkh_src_path}/LKH-3.0.8.tgz")

    # build LKH
    _run("tar -xzf LKH-3.0.8.tgz", lkh_src_path)
    _run("mv LKH-3.0.8/* .", lkh_src_path)
    _run("rm -r LKH-3.0.8.tgz LKH-3.0.8", lkh_src_path)
    _run("make", lkh_src_path)

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--installed_solvers", default="all", type=str, help="Solvers: [all, concorde, lkh]")
    args = parser.parse_args()

    if args.installed_solvers == "all" or args.installed_solvers == "concorde":
        install_concorde()

    if args.installed_solvers == "all" or args.installed_solvers == "lkh":
        install_lkh()
models/cf_generator.py
ADDED
@@ -0,0 +1,155 @@
import torch.nn as nn
import numpy as np

class CFTourGenerator(nn.Module):
    def __init__(self, cf_solver):
        super().__init__()
        self.solver = cf_solver
        self.problem = cf_solver.problem

    def forward(self, factual_tour, vehicle_id, cf_step, cf_next_node_id, node_feats, dist_matrix=None):
        """
        solve an input instance with visited edges fixed

        Parameters
        ----------
        factual_tour: list [seq_length]
        cf_step: int
        cf_next_node_id: int
        node_feats:

        Returns
        -------
        cf_tour: np.array [seq_length]
        """
        fixed_paths = self.get_fixed_paths(factual_tour, vehicle_id, cf_step, cf_next_node_id)
        cf_tours = self.solver.solve(node_feats, fixed_paths, dist_matrix=dist_matrix)
        if cf_tours is None:
            return
        if (cf_step > 0):
            for vehicle_id, cf_tour in enumerate(cf_tours):
                if cf_next_node_id in cf_tour:
                    if cf_step == 1:
                        if cf_tour[1] != cf_next_node_id:
                            cf_tours[vehicle_id] = np.flipud(cf_tour)
                        break
                    else:
                        if (factual_tour[vehicle_id][1] != cf_tour[1]):
                            cf_tours[vehicle_id] = np.flipud(cf_tour)  # make direction of the cf tour the same as the factual one
                        break
        return cf_tours

    def get_fixed_paths(self, factual_tour, vehicle_id, cf_step, cf_next_node_id):
        visited_paths = np.append(factual_tour[vehicle_id][:cf_step], cf_next_node_id)
        return visited_paths

    # def get_avail_edges(self, factual_tour, cf_step, cf_next_node_id):
    #     visited_paths = np.append(factual_tour[:cf_step], cf_next_node_id)
    #     avail_edges = []
    #     # add fixed edges
    #     for i in range(len(visited_paths) - 1):
    #         avail_edges.append([visited_paths[i], visited_paths[i + 1]])
    #     print(avail_edges)

    #     # add the rest of the available edges
    #     num_nodes = np.max(factual_tour) + 1
    #     visited = np.array([0] * num_nodes)
    #     for id in visited_paths:
    #         visited[id] = 1
    #     visited[factual_tour[0]] = 0
    #     visited[cf_next_node_id] = 0
    #     mask = visited < 1
    #     node_id = np.arange(num_nodes)
    #     feasible_node_id = node_id[mask]
    #     for j in range(len(feasible_node_id) - 1):
    #         for i in range(j + 1, len(feasible_node_id)):
    #             if ((feasible_node_id[j] == factual_tour[0]) and (feasible_node_id[i] == cf_next_node_id)) or ((feasible_node_id[i] == factual_tour[0]) and (feasible_node_id[j] == cf_next_node_id)):
    #                 continue
    #             avail_edges.append([feasible_node_id[j], feasible_node_id[i]])
    #     return np.array(avail_edges)

#-----------
# unit test
#-----------
if __name__ == "__main__":
    import argparse
    import random
    import matplotlib.pyplot as plt
    # FYI:
    # - https://yu-nix.com/archives/python-path-get/
    # - https://www.delftstack.com/ja/howto/python/python-get-parent-directory/
    # - https://stackoverflow.com/questions/2817264/how-to-get-the-parent-dir-location
    import os
    import sys
    CURR_DIR = os.path.dirname(os.path.abspath(__file__))
    PARENT_DIR = os.path.abspath(os.path.join(CURR_DIR, os.pardir))
    sys.path.append(PARENT_DIR)
    from utils.util_vis import visualize_factual_and_cf_tours
    from lkh.lkh import LKH
    from models.ortools.ortools import ORTools
    from data_generator.tsptw.tsptw_dataset import generate_tsptw_instance

    parser = argparse.ArgumentParser(description='')
    # general settings
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--num_samples", type=int, default=5)
    parser.add_argument("--num_nodes", type=int, default=100)
    parser.add_argument("--coord_dim", type=int, default=2)
    # LKH settings
    parser.add_argument("--max_trials", type=int, default=1000)
    parser.add_argument("--lkh_dir", type=str, default="lkh", help="Path to the binary of LKH")
    parser.add_argument("--io_dir", type=str, default="lkh_io_files")
    args = parser.parse_args()

    # models
    # cf_solver = LKH(args.problem, args.max_trials, args.random_seed, lkh_dir=args.lkh_dir, io_dir=args.io_dir)
    cf_solver = ORTools(args.problem)
    cf_generator = CFTourGenerator(cf_solver)

    # dataset
    if args.problem == "tsp":
        np.random.seed(args.random_seed)
        node_feats = np.random.uniform(size=[args.num_samples, args.num_nodes, args.coord_dim])
    elif args.problem == "tsptw":
        coords, time_window, grid_size = generate_tsptw_instance(num_nodes=args.num_nodes, grid_size=100, max_tw_gap=10, max_tw_size=1000, is_integer_instance=True, da_silva_style=True)
        node_feats = np.concatenate([coords, time_window], -1)
        node_feats = node_feats[None, :, :]

    # function to automatically generate a counterfactual visit
    def get_random_cf_visit(factual_tour, random_seed=1234):
        # random.seed(random_seed)
        num_nodes = np.max(factual_tour) + 1
        step = random.randrange(len(factual_tour) - 2)  # remove the last step (returning to the start-point)
        visited = np.array([0] * num_nodes)
        for i in range(step+1):
            visited[factual_tour[i]] = 1
        mask = visited < 1
        node_id = np.arange(num_nodes)
        feasible_node_id = node_id[mask]
        cf_next_id = random.choice(feasible_node_id)  # select counterfactual
        return step, cf_next_id

    for i in range(len(node_feats)):
        # obtain a factual tour
        factual_tour = cf_solver.solve(node_feats[i])

        # counterfactual visit
        cf_step, cf_next_node_id = get_random_cf_visit(factual_tour, random_seed=args.random_seed)

        print(cf_step, cf_next_node_id)
        # obtain a counterfactual tour
        cf_tour = cf_generator(factual_tour, cf_step, cf_next_node_id, node_feats[i])

        print(factual_tour)
        print(cf_tour)

        # visualize the factual and counterfactual tours
        if args.problem == "tsp":
            coords = node_feats[i]
        elif args.problem == "tsptw":
            coord_dim = 2
            coords = node_feats[i, :, :coord_dim]
        visualize_factual_and_cf_tours(factual_tour, cf_tour, coords, cf_step, f"test{i}.png")
        break
models/classifiers/general_classifier.py
ADDED
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
import argparse
import json
import os
from models.classifiers.predictor import DecisionPredictor
from models.classifiers.meaningless_models import FixedClassPredictor, RandomPredictor
from models.classifiers.rule_based_models import kNearestPredictor
from models.classifiers.ground_truth.ground_truth import GroundTruth

class GeneralClassifier(nn.Module):
    def __init__(self, problem, model_type):
        super().__init__()
        self.model_type = model_type
        self.problem = problem
        self.model = self.get_model(problem, model_type)

    def change_model(self, problem, model_type):
        if self.model_type != model_type or self.problem != problem:
            self.model_type = model_type
            self.problem = problem
            self.model = self.get_model(problem, model_type)

    def get_model(self, problem, model_type):
        if model_type == "gnn":
            model_path = "checkpoints/model_20230309_101058/model_epoch4.pth"
            params = argparse.ArgumentParser()
            model_dir = os.path.split(model_path)[0]
            with open(f"{model_dir}/cmd_args.dat", "r") as f:
                params.__dict__ = json.load(f)
            model = DecisionPredictor(params.problem,
                                      params.emb_dim,
                                      params.num_mlp_layers,
                                      params.num_classes,
                                      params.dropout)
            model.load_state_dict(torch.load(model_path))
            return model
        elif model_type == "gt(ortools)":
            return GroundTruth(problem, solver_type="ortools")
        elif model_type == "gt(lkh)":
            return GroundTruth(problem, solver_type="lkh")
        elif model_type == "gt(concorde)":
            return GroundTruth(problem, solver_type="concorde")
        elif model_type == "random":
            return RandomPredictor(num_classes=2)
        elif model_type == "fixed":
            predicted_class = 0
            return FixedClassPredictor(predicted_class=predicted_class, num_classes=2)
        elif model_type == "knn":
            k = 5
            k_type = "num"
            return kNearestPredictor(problem, k, k_type)
        else:
            assert False, f"Invalid model type: {model_type}"

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        return self.model.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def forward(self, inputs):
        return self.model(inputs)
models/classifiers/ground_truth/ground_truth.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
import numpy as np
from models.classifiers.ground_truth.ground_truth_tsptw import GroundTruthTSPTW
from models.classifiers.ground_truth.ground_truth_pctsp import GroundTruthPCTSP
from models.classifiers.ground_truth.ground_truth_pctsptw import GroundTruthPCTSPTW
from models.classifiers.ground_truth.ground_truth_cvrp import GroundTruthCVRP
from models.classifiers.ground_truth.ground_truth_cvrptw import GroundTruthCVRPTW

class GroundTruth(nn.Module):
    def __init__(self, problem, solver_type):
        super().__init__()
        self.problem = problem
        self.solver_type = solver_type
        if problem == "tsptw":
            self.ground_truth = GroundTruthTSPTW(solver_type)
        elif problem == "pctsp":
            self.ground_truth = GroundTruthPCTSP(solver_type)
        elif problem == "pctsptw":
            self.ground_truth = GroundTruthPCTSPTW(solver_type)
        elif problem == "cvrp":
            self.ground_truth = GroundTruthCVRP(solver_type)
        elif problem == "cvrptw":
            self.ground_truth = GroundTruthCVRPTW(solver_type)
        else:
            raise NotImplementedError

    def forward(self, inputs, annotation=False, parallel=False):
        return self.ground_truth(inputs, annotation, parallel)

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        return self.ground_truth.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def solve(self, step, input_tour, node_feats, instance_name=None):
        return self.ground_truth.solve(step, input_tour, node_feats, instance_name)
models/classifiers/ground_truth/ground_truth_base.py
ADDED
@@ -0,0 +1,285 @@
import torch
import torch.nn as nn
import numpy as np
import os
import multiprocessing
from models.solvers.general_solver import GeneralSolver
from utils.utils import calc_tour_length

def get_visited_mask(tour, step, node_feats, dist_matrix=None):
    """
    Visited nodes -> feasible, unvisited nodes -> infeasible.
    When solving a problem with visited_paths fixed, they should be included in the solution.
    Therefore, visited nodes are set to feasible nodes.
    """
    if dist_matrix is not None:
        num_nodes = len(dist_matrix)
    else:
        num_nodes = len(node_feats["coords"])
    visited = np.isin(np.arange(num_nodes), tour[:step])
    return visited

def get_tw_mask(tour, step, node_feats, dist_matrix=None):
    """
    Nodes whose tw exceeds current_time -> infeasible, otherwise -> feasible.

    Parameters
    ----------
    tour: list [seq_length]
    step: int
    node_feats: dict of np.array

    Returns
    -------
    mask_tw: np.array [num_nodes]
    """
    node_feats = node_feats.copy()
    time_window = node_feats["time_window"]
    if dist_matrix is not None:
        num_nodes = len(dist_matrix)
        curr_time = 0.0
        not_exceed_tw = np.ones(num_nodes).astype(np.int32)
        for i in range(1, step):
            prev_id = tour[i - 1]
            curr_id = tour[i]
            travel_time = dist_matrix[prev_id, curr_id]
            # assert curr_time + travel_time < time_window[curr_id, 1], f"Invalid tour! arrival_time: {curr_time + travel_time}, time_window: {time_window[curr_id]}"
            if curr_time + travel_time < time_window[curr_id, 0]:
                curr_time = time_window[curr_id, 0].copy()
            else:
                curr_time += travel_time
        curr_time = curr_time + dist_matrix[tour[step-1]]  # [num_nodes] TODO: check
    else:
        coords = node_feats["coords"]
        num_nodes = len(coords)
        curr_time = 0.0
        not_exceed_tw = np.ones(num_nodes).astype(np.int32)
        for i in range(1, step):
            prev_id = tour[i - 1]
            curr_id = tour[i]
            travel_time = np.linalg.norm(coords[prev_id] - coords[curr_id])
            # assert curr_time + travel_time < time_window[curr_id, 1], f"Invalid tour! arrival_time: {curr_time + travel_time}, time_window: {time_window[curr_id]}"
            if curr_time + travel_time < time_window[curr_id, 0]:
                curr_time = time_window[curr_id, 0].copy()
            else:
                curr_time += travel_time
        curr_time = curr_time + np.linalg.norm(coords[tour[step-1]][None, :] - coords, axis=-1)  # [num_nodes] TODO: check
    not_exceed_tw[curr_time > time_window[:, 1]] = 0
    not_exceed_tw = not_exceed_tw > 0
    return not_exceed_tw

def get_cap_mask(tour, step, node_feats):
    num_nodes = len(node_feats["coords"])
    demands = node_feats["demand"]
    remaining_cap = node_feats["capacity"].copy()
    less_than_cap = np.ones(num_nodes).astype(np.int32)
    for i in range(step):
        remaining_cap -= demands[tour[i]]
    less_than_cap[remaining_cap < demands] = 0
    less_than_cap = less_than_cap > 0
    return less_than_cap

def get_pc_mask(tour, step, node_feats):
    """
    Mask for prize-collecting problems (e.g., PCTSP, PCTSPTW, PCCVRP, PCCVRPTW, ...)

    Returns
    -------
    not_exceed_max_length
    """
    large_value = 1e+5
    coords = node_feats["coords"]
    max_length = (node_feats["max_length"] * large_value).astype(np.int64)
    tour_length = 0
    for i in range(1, step):
        prev_id = tour[i - 1]
        curr_id = tour[i]
        tour_length += (np.linalg.norm(coords[prev_id] - coords[curr_id]) * large_value).astype(np.int64)
    curr_to_next = (np.linalg.norm(coords[tour[step-1]][None, :] - coords, axis=-1) * large_value).astype(np.int64)  # [num_nodes]
    next_to_depot = (np.linalg.norm(coords[tour[0]][None, :] - coords, axis=-1) * large_value).astype(np.int64)  # [num_nodes]
    not_exceed_max_length = (tour_length + curr_to_next + next_to_depot) <= max_length  # [num_nodes]
    return not_exceed_max_length

def analyze_tour(tour, node_feats):
    coords = node_feats["coords"]
    time_window = node_feats["time_window"]
    curr_time = 0
    for i in range(1, len(tour)):
        prev_id = tour[i - 1]
        curr_id = tour[i]
        travel_time = np.linalg.norm(coords[prev_id] - coords[curr_id])
        valid = curr_time + travel_time < time_window[curr_id, 1]
        print(f"visit #{i}: {prev_id} -> {curr_id}, travel_time: {travel_time}, arrival_time: {curr_time + travel_time}, time_window: {time_window[curr_id]}, valid: {valid}")
        if curr_time + travel_time < time_window[curr_id, 0]:
            curr_time = time_window[curr_id, 0]
        else:
            curr_time += travel_time

FAIL_FLAG = -1

class GroundTruthBase(nn.Module):
    def __init__(self, problem, compared_problems, solver_type):
        """
        Parameters
        ----------

        """
        super().__init__()
        self.problem = problem
        self.compared_problems = compared_problems
        self.num_compared_problems = len(compared_problems)
        self.solver_type = solver_type
        self.solvers = []
        for i in range(self.num_compared_problems):
            # TODO:
            self.solvers.append(GeneralSolver(self.compared_problems[i], self.solver_type, scaling=False))

    def forward(self, inputs, annotation=False, parallel=True):
        """
        Parameters
        ----------
        inputs: dict
            tour: 2d list [num_vehicles x seq_length]
            first_explained_step: int
            node_feats: dict of np.array
        annotation: bool
            please set it True when annotating data

        Returns
        -------
        labels:
        probs: torch.tensor [batch_size (num_vehicles) x max_seq_length x num_classes]
        """
        input_tours = inputs["tour"]
        node_feats = inputs["node_feats"]
        dist_matrix = inputs["dist_matrix"]
        first_explained_step = inputs["first_explained_step"]
        num_vehicles = len(input_tours)
        if annotation:
            labels = [[] for _ in range(num_vehicles)]
            for vehicle_id in range(num_vehicles):
                input_tour = input_tours[vehicle_id]
                # analyze_tour(input_tour, node_feats)
                for step in range(first_explained_step + 1, len(input_tour)):
                    _, __, label = self.label_path(vehicle_id, step, input_tour, node_feats)
                    if label == FAIL_FLAG:
                        return
                    labels[vehicle_id].append((step, label))
            return labels
        else:
            if parallel:
                labels = [[-1] * (len(range(first_explained_step+1, len(input_tours[vehicle_id])))) for vehicle_id in range(num_vehicles)]
                num_cpus = os.cpu_count()
                with multiprocessing.Pool(num_cpus) as pool:
                    for vehicle_id, step, label in pool.starmap(self.label_path, [(vehicle_id, step, input_tours[vehicle_id], node_feats, dist_matrix)
                                                                                  for vehicle_id in range(num_vehicles)
                                                                                  for step in range(first_explained_step+1, len(input_tours[vehicle_id]))]):
                        labels[vehicle_id][step-(first_explained_step+1)] = label
            else:
                labels = [[-1] * (len(range(first_explained_step+1, len(input_tours[vehicle_id])))) for vehicle_id in range(num_vehicles)]
                for vehicle_id in range(num_vehicles):
                    for step in range(first_explained_step+1, len(input_tours[vehicle_id])):
                        vehicle_id, step, label = self.label_path(vehicle_id, step, input_tours[vehicle_id], node_feats, dist_matrix)
                        labels[vehicle_id][step-(first_explained_step+1)] = label
            # validate labels
            for vehicle_id in range(num_vehicles):
                assert (len(input_tours[vehicle_id]) - 1) == len(labels[vehicle_id]), f"vehicle_id={vehicle_id}, {input_tours}, {labels}"
            return labels
            # labels = [torch.LongTensor(label) for label in labels] # [num_vehicles x seq_length]
            # labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) # [num_vehicles x max_seq_length]
            # probs = torch.zeros((labels.size(0), labels.size(1), self.num_compared_problems+1)) # [num_vehicles x max_seq_length x (num_compared_problems+1)]
            # probs.scatter_(-1, labels.unsqueeze(-1).expand_as(probs), 1.0)
            # return probs

    def label_path(self, vehicle_id, step, input_tour, node_feats, dist_matrix=None):
        compared_tour_list = [[] for _ in range(self.num_compared_problems)]
        visited_path = input_tour[:step].copy()
        new_node_id, new_node_feats, new_dist_matrix = self.get_feasible_nodes(input_tour, step, node_feats, dist_matrix)
        new_visited_path = np.array(list(map(lambda x: np.where(new_node_id==x)[0].item(), visited_path)))
        for i in range(self.num_compared_problems):
            # TODO: in CVRPTW / PCCVRPTW, need to modify classification of the first and last paths
            compared_tours = self.solvers[i].solve(new_node_feats, new_visited_path, new_dist_matrix)
            if compared_tours is None:
                return vehicle_id, step, FAIL_FLAG
            compared_tour = None
            for compared_tour_tmp in compared_tours:
                if new_visited_path[-1] in compared_tour_tmp:
                    compared_tour = compared_tour_tmp
                    break
            assert compared_tour is not None, f"Found no appropriate vehicle. {compared_tours}, {new_visited_path}"
            compared_tour = np.array(list(map(lambda x: new_node_id[x], compared_tour)))
            if (step > 0) and (compared_tour[1] != input_tour[1]):
                compared_tour = np.flipud(compared_tour)  # make direction of the compared tour the same as the factual one
            compared_tour_list[i] = compared_tour
        # print("fixed_paths  :", visited_path)
        # print("input_tour   :", input_tour)
        # print("compared_tour:", compared_tour)
        # print()
        # annotation
        label = self.get_label(input_tour, compared_tour_list, step)
        return vehicle_id, step, label

    def solve(self, step, input_tour, node_feats, instance_name=None):
        compared_tours = {}
        visited_path = input_tour[:step].copy()
        new_node_id, new_node_feats, _ = self.get_feasible_nodes(input_tour, step, node_feats)
        new_visited_path = np.array(list(map(lambda x: np.where(new_node_id==x)[0].item(), visited_path)))
        for i, compared_problem in enumerate(self.compared_problems):
            compared_tours[compared_problem] = self.solvers[i].solve(new_node_feats, new_visited_path, instance_name)
            compared_tours[compared_problem] = list(map(lambda compared_tour: list(map(lambda x: new_node_id[x], compared_tour)), compared_tours[compared_problem]))
            compared_tours[compared_problem] = list(map(lambda compared_tour: calc_tour_length(compared_tour, node_feats["coords"]), compared_tours[compared_problem]))
        return compared_tours

    def get_label(self, input_tour, compared_tours, step):
        for i in range(self.num_compared_problems):
            compared_tour = compared_tours[i]
            if input_tour[step] == compared_tour[step]:
                return i
        return self.num_compared_problems

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        input_features = {
            "tour": tour,
            "first_explained_step": first_explained_step,
            "node_feats": node_feats,
            "dist_matrix": dist_matrix
        }
        return input_features

    def get_feasible_nodes(self, tour, step, node_feats, dist_matrix=None):
        """
        Parameters
        ----------
        tour: np.array [seq_length]
        step: int
        node_feats: np.array [num_nodes x node_dim]

        Returns
        -------
        new_node_id: np.array [num_feasible_nodes]
        new_node_feats: dict of np.array [num_feasible_nodes x coord_dim]
        """
        if dist_matrix is not None:
            num_nodes = len(dist_matrix)
        else:
            num_nodes = len(node_feats["coords"])
        mask = self.get_mask(tour, step, node_feats, dist_matrix)
        node_id = np.arange(num_nodes)
        new_node_id = node_id[mask].copy()
        new_node_feats = {
            key: node_feat[mask].copy()
            if key in ["coords", "time_window", "demand", "penalties", "prizes"] else
            node_feat.copy()
            for key, node_feat in node_feats.items()
        }
        if dist_matrix is not None:
            delete_id = node_id[~mask]
            new_dist_matrix = np.delete(np.delete(dist_matrix, delete_id, 0), delete_id, 1)
        else:
            new_dist_matrix = None
        return new_node_id, new_node_feats, new_dist_matrix

    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        raise NotImplementedError

    def check_feasibility(self, tour, node_feats):
        raise NotImplementedError
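A toy sketch of the labeling rule in get_label above, with hand-made arrays rather than repository data: the label is the index of the first compared problem whose re-solved tour agrees with the factual next visit, and num_compared_problems when none agrees.

# Toy sketch: a TSPTW ground truth compared against a plain TSP re-solve.
# compared_tours[0] plays the role of the TSP tour obtained with the visited path fixed.
input_tour     = [0, 3, 1, 4, 2, 0]
compared_tours = [[0, 3, 1, 2, 4, 0]]  # index 0 = "tsp"

def toy_get_label(input_tour, compared_tours, step, num_compared_problems=1):
    for i in range(num_compared_problems):
        if input_tour[step] == compared_tours[i][step]:
            return i               # edge also chosen by compared problem i
    return num_compared_problems   # edge explained only by the full problem (e.g., time windows)

print(toy_get_label(input_tour, compared_tours, step=2))  # 0 -> same move as the TSP re-solve
print(toy_get_label(input_tour, compared_tours, step=3))  # 1 -> move forced by the extra constraints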
models/classifiers/ground_truth/ground_truth_cvrp.py
ADDED
@@ -0,0 +1,14 @@
from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
from models.classifiers.ground_truth.ground_truth_base import get_cap_mask, get_visited_mask

class GroundTruthCVRP(GroundTruthBase):
    def __init__(self, solver_type):
        problem = "cvrp"
        compared_problems = ["tsp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        visited = get_visited_mask(tour, step, node_feats, dist_matrix)
        less_than_cap = get_cap_mask(tour, step, node_feats)
        return visited | less_than_cap
models/classifiers/ground_truth/ground_truth_cvrptw.py
ADDED
@@ -0,0 +1,15 @@
from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
from models.classifiers.ground_truth.ground_truth_base import get_cap_mask, get_visited_mask, get_tw_mask

class GroundTruthCVRPTW(GroundTruthBase):
    def __init__(self, solver_type):
        problem = "cvrptw"
        compared_problems = ["tsp", "cvrp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        visited = get_visited_mask(tour, step, node_feats, dist_matrix)
        less_than_cap = get_cap_mask(tour, step, node_feats)
        not_exceed_tw = get_tw_mask(tour, step, node_feats, dist_matrix)
        return visited | (less_than_cap & not_exceed_tw)
models/classifiers/ground_truth/ground_truth_pctsp.py
ADDED
@@ -0,0 +1,16 @@
import numpy as np
from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_pc_mask

class GroundTruthPCTSP(GroundTruthBase):
    def __init__(self, solver_type):
        problem = "pctsp"
        compared_problems = ["tsp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        # visited = get_visited_mask(tour, step, node_feats)
        # not_exceed_max_length = get_pc_mask(tour, step, node_feats)
        num_nodes = len(dist_matrix) if dist_matrix is not None else len(node_feats["coords"])
        return np.full(num_nodes, True)
models/classifiers/ground_truth/ground_truth_pctsptw.py
ADDED
@@ -0,0 +1,15 @@
import numpy as np
from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_tw_mask

class GroundTruthPCTSPTW(GroundTruthBase):
    def __init__(self, solver_type):
        problem = "pctsptw"
        compared_problems = ["tsp", "pctsp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        visited = get_visited_mask(tour, step, node_feats, dist_matrix)
        not_exceed_tw = get_tw_mask(tour, step, node_feats, dist_matrix)
        return visited | not_exceed_tw
models/classifiers/ground_truth/ground_truth_tsptw.py
ADDED
@@ -0,0 +1,14 @@
from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
from models.classifiers.ground_truth.ground_truth_base import get_tw_mask, get_visited_mask

class GroundTruthTSPTW(GroundTruthBase):
    def __init__(self, solver_type):
        problem = "tsptw"
        compared_problems = ["tsp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        visited = get_visited_mask(tour, step, node_feats, dist_matrix)
        not_exceed_tw = get_tw_mask(tour, step, node_feats, dist_matrix)
        return visited | not_exceed_tw
models/classifiers/meaningless_models.py
ADDED
@@ -0,0 +1,62 @@
import torch
import torch.nn as nn

class RandomPredictor(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: int or dict
            batch_size or dict of input features

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        random_index = torch.randint(self.num_classes, (batch_size, 1))  # one random class per sample
        probs = torch.zeros(batch_size, self.num_classes).to(torch.float)
        probs.scatter_(-1, random_index, 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        return len(tour[first_explained_step:-1])

class FixedClassPredictor(nn.Module):
    def __init__(self, predicted_class, num_classes):
        """
        Parameters
        ----------
        predicted_class: int
            the class that this predictor always predicts
        num_classes: int
            number of classes
        """
        super().__init__()
        self.predicted_class = predicted_class
        self.num_classes = num_classes
        assert predicted_class < num_classes, f"predicted_class must be in [0, {num_classes})."

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: int or dict
            batch_size or dict of input features

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        index = torch.full((batch_size, 1), self.predicted_class)  # always the same class
        probs = torch.zeros(batch_size, self.num_classes).to(torch.float)
        probs.scatter_(-1, index, 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        return len(tour[first_explained_step:-1])
models/classifiers/nn_classifiers/attention_graph_encoder.py
ADDED
@@ -0,0 +1,95 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class AttentionGraphEncoder(nn.Module):
    def __init__(self, coord_dim, node_dim, state_dim, emb_dim, dropout):
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.state_dim = state_dim
        self.norm_factor = 1 / math.sqrt(emb_dim)

        # initial embedding
        self.init_linear_node = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # An attention layer
        self.w_q = nn.Parameter(torch.FloatTensor((2 + int(state_dim > 0)) * emb_dim, emb_dim))
        self.w_k = nn.Parameter(torch.FloatTensor(2 * emb_dim, emb_dim))
        self.w_v = nn.Parameter(torch.FloatTensor(2 * emb_dim, emb_dim))

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size]
            next_node_id: torch.LongTensor [batch_size]
            node_feat: torch.FloatTensor [batch_size x num_nodes x node_dim]
            mask: torch.LongTensor [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x state_dim]

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            graph embeddings
        """
        #----------------
        # input features
        #----------------
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        node_feat = inputs["node_feats"]
        mask = inputs["mask"]
        state = inputs["state"]

        #---------------------------
        # initial linear projection
        #---------------------------
        node_emb = self.init_linear_node(node_feat[:, 1:, :]) # [batch_size x num_loc x emb_dim]
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :2]) # [batch_size x 1 x emb_dim]
        new_node_feat = torch.cat((depot_emb, node_emb), 1) # [batch_size x num_nodes x emb_dim]
        new_node_feat = self.dropout(new_node_feat)

        #---------------
        # preprocessing
        #---------------
        batch_size = curr_node_id.size(0)
        curr_emb = new_node_feat.gather(1, curr_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        next_emb = new_node_feat.gather(1, next_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state) # [batch_size x emb_dim]
            input_q = torch.cat((curr_emb, next_emb, state_emb[:, None, :]), -1) # [batch_size x 1 x (3*emb_dim)]
        else:
            input_q = torch.cat((curr_emb, next_emb), -1) # [batch_size x 1 x (2*emb_dim)]
        input_kv = torch.cat((curr_emb.expand_as(new_node_feat), new_node_feat), -1) # [batch_size x num_nodes x (2*emb_dim)]

        #--------------------
        # An attention layer
        #--------------------
        q = torch.matmul(input_q, self.w_q) # [batch_size x 1 x emb_dim]
        k = torch.matmul(input_kv, self.w_k) # [batch_size x num_nodes x emb_dim]
        v = torch.matmul(input_kv, self.w_v) # [batch_size x num_nodes x emb_dim]
        compatibility = self.norm_factor * torch.matmul(q, k.transpose(-2, -1)) # [batch_size x 1 x num_nodes]
        compatibility[(~mask).unsqueeze(1).expand_as(compatibility)] = -math.inf
        attn = torch.softmax(compatibility, dim=-1)
        h = torch.matmul(attn, v) # [batch_size x 1 x emb_dim]
        h = h.squeeze(1) # [batch_size x emb_dim]

        return h
models/classifiers/nn_classifiers/decoders/lstm_decoder.py
ADDED
@@ -0,0 +1,57 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class LSTMDecoder(nn.Module):
    def __init__(self, emb_dim, num_mlp_layers, num_classes, dropout):
        super().__init__()
        self.num_mlp_layers = num_mlp_layers

        # LSTM
        self.lstm = nn.LSTM(emb_dim, emb_dim, batch_first=True)

        # Decoder (MLP)
        self.mlp = nn.ModuleList()
        for _ in range(num_mlp_layers):
            self.mlp.append(nn.Linear(emb_dim, emb_dim, bias=True))
        self.mlp.append(nn.Linear(emb_dim, num_classes, bias=True))

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, graph_emb):
        """
        Parameters
        ----------
        graph_emb: torch.tensor [batch_size x max_seq_length x emb_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x max_seq_length x num_classes]
            probabilities of classes
        """
        #---------------
        # LSTM encoding
        #---------------
        h, _ = self.lstm(graph_emb) # [batch_size x max_seq_length x emb_dim]

        #----------
        # Decoding
        #----------
        for i in range(self.num_mlp_layers):
            h = self.dropout(h)
            h = torch.relu(self.mlp[i](h))
        h = self.dropout(h)
        logits = self.mlp[-1](h)
        probs = F.log_softmax(logits, dim=-1)
        return probs
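A minimal sketch of how this decoder might be called on a padded sequence of per-step embeddings (the sizes below are illustrative assumptions):

    import torch
    from models.classifiers.nn_classifiers.decoders.lstm_decoder import LSTMDecoder

    batch_size, max_seq_length, emb_dim, num_classes = 4, 9, 128, 2   # illustrative only
    decoder = LSTMDecoder(emb_dim=emb_dim, num_mlp_layers=2, num_classes=num_classes, dropout=0.1)
    graph_emb = torch.rand(batch_size, max_seq_length, emb_dim)   # one embedding per step
    log_probs = decoder(graph_emb)   # [batch_size x max_seq_length x num_classes], log-probabilities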
models/classifiers/nn_classifiers/decoders/mha_decoder.py
ADDED
@@ -0,0 +1,74 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model) # batch_first
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size x max_seq_length x embedding_dim]``
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class SelfMHADecoder(nn.Module):
    def __init__(self, emb_dim: int, num_heads: int, num_mha_layers: int, num_classes: int, dropout: float, pos_encoder: str = None, max_len: int = 100):
        super().__init__()
        self.num_mha_layers = num_mha_layers

        # positional encoding
        self.pos_encoder_type = pos_encoder
        if pos_encoder == "sincos":
            self.pos_encoder = PositionalEncoding(d_model=emb_dim, dropout=dropout, max_len=max_len)

        # MHA blocks
        mha_layer = nn.TransformerEncoderLayer(d_model=emb_dim,
                                               nhead=num_heads,
                                               dim_feedforward=emb_dim,
                                               dropout=dropout,
                                               batch_first=True)
        self.mha = nn.TransformerEncoder(mha_layer, num_layers=num_mha_layers)

        # linear projection for adjusting out_dim to num_classes
        self.out_linear = nn.Linear(emb_dim, num_classes, bias=True)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, edge_emb):
        """
        Parameters
        ----------
        edge_emb: torch.tensor [batch_size x max_seq_length x emb_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x max_seq_length x num_classes]
            probabilities of classes
        """
        #---------------
        # MHA decoding
        #---------------
        if self.pos_encoder_type == "sincos":
            edge_emb = self.pos_encoder(edge_emb)
        h = self.mha(edge_emb, is_causal=True) # [batch_size x max_seq_length x emb_dim]
        logits = self.out_linear(h)
        probs = F.log_softmax(logits, dim=-1)
        return probs
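A minimal usage sketch (sizes illustrative). Note that, per PyTorch's documentation, `is_causal` is a hint that the supplied attention mask is causal; if an explicit causal mask were preferred, it could be built as below. The mask line is an assumption about the intended behavior, not something the module does as written:

    import torch
    import torch.nn as nn
    from models.classifiers.nn_classifiers.decoders.mha_decoder import SelfMHADecoder

    batch_size, max_seq_length, emb_dim = 4, 9, 128   # illustrative only
    decoder = SelfMHADecoder(emb_dim=emb_dim, num_heads=8, num_mha_layers=2, num_classes=2,
                             dropout=0.1, pos_encoder="sincos", max_len=100)
    edge_emb = torch.rand(batch_size, max_seq_length, emb_dim)
    log_probs = decoder(edge_emb)   # [batch_size x max_seq_length x num_classes]

    # hypothetical explicit causal mask, if one wanted to pass it through self.mha(...)
    causal_mask = nn.Transformer.generate_square_subsequent_mask(max_seq_length)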
models/classifiers/nn_classifiers/decoders/mlp_decoder.py
ADDED
@@ -0,0 +1,50 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MLPDecoder(nn.Module):
    def __init__(self, emb_dim, num_mlp_layers, num_classes, dropout):
        super().__init__()
        self.num_mlp_layers = num_mlp_layers

        # Decoder (MLP)
        self.mlp = nn.ModuleList()
        for _ in range(num_mlp_layers):
            self.mlp.append(nn.Linear(emb_dim, emb_dim, bias=True))
        self.mlp.append(nn.Linear(emb_dim, num_classes, bias=True))

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, graph_emb):
        """
        Parameters
        ----------
        graph_emb: torch.tensor [batch_size x emb_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            probabilities of classes
        """
        #----------
        # Decoding
        #----------
        h = graph_emb
        for i in range(self.num_mlp_layers):
            h = self.dropout(h)
            h = torch.relu(self.mlp[i](h))
        h = self.dropout(h)
        logits = self.mlp[-1](h)
        probs = F.log_softmax(logits, dim=-1)
        return probs
models/classifiers/nn_classifiers/encoders/attn_edge_encoder.py
ADDED
@@ -0,0 +1,81 @@
import math
import torch
import torch.nn as nn

class AttentionEdgeEncoder(nn.Module):
    def __init__(self, state_dim, emb_dim, dropout):
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim
        self.norm_factor = 1 / math.sqrt(emb_dim)

        # initial embedding for state
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # An attention layer
        self.w_q = nn.Parameter(torch.FloatTensor((2 + int(state_dim > 0)) * emb_dim, emb_dim))
        self.w_k = nn.Parameter(torch.FloatTensor(2 * emb_dim, emb_dim))
        self.w_v = nn.Parameter(torch.FloatTensor(2 * emb_dim, emb_dim))

        # out linear layer
        self.out_linear = nn.Linear(emb_dim, emb_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs, node_emb):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size]
            next_node_id: torch.LongTensor [batch_size]
            mask: torch.LongTensor [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x state_dim]
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings obtained from the node encoder

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            edge embeddings
        """
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        mask = inputs["mask"]
        state = inputs["state"]
        batch_size = curr_node_id.size(0)

        #--------------------------------
        # generate queries, keys, values
        #--------------------------------
        node_emb = self.dropout(node_emb)
        curr_emb = node_emb.gather(1, curr_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        next_emb = node_emb.gather(1, next_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state) # [batch_size x emb_dim]
            input_q = torch.cat((curr_emb, next_emb, state_emb[:, None, :]), -1) # [batch_size x 1 x (3*emb_dim)]
        else:
            input_q = torch.cat((curr_emb, next_emb), -1) # [batch_size x 1 x (2*emb_dim)]
        input_kv = torch.cat((curr_emb.expand_as(node_emb), node_emb), -1) # [batch_size x num_nodes x (2*emb_dim)]

        #--------------------
        # An attention layer
        #--------------------
        q = torch.matmul(input_q, self.w_q) # [batch_size x 1 x emb_dim]
        k = torch.matmul(input_kv, self.w_k) # [batch_size x num_nodes x emb_dim]
        v = torch.matmul(input_kv, self.w_v) # [batch_size x num_nodes x emb_dim]
        compatibility = self.norm_factor * torch.matmul(q, k.transpose(-2, -1)) # [batch_size x 1 x num_nodes]
        compatibility[(~mask).unsqueeze(1).expand_as(compatibility)] = -math.inf
        attn = torch.softmax(compatibility, dim=-1)
        h = torch.matmul(attn, v) # [batch_size x 1 x emb_dim]
        h = h.squeeze(1) # [batch_size x emb_dim]
        return self.out_linear(h) + q.squeeze(1)
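A minimal sketch of feeding this edge encoder with precomputed node embeddings (the sizes and boolean mask are illustrative assumptions; in the full pipeline `node_emb` would come from one of the node encoders added below):

    import torch
    from models.classifiers.nn_classifiers.encoders.attn_edge_encoder import AttentionEdgeEncoder

    batch_size, num_nodes, state_dim, emb_dim = 4, 10, 1, 128   # illustrative only
    edge_enc = AttentionEdgeEncoder(state_dim=state_dim, emb_dim=emb_dim, dropout=0.1)
    inputs = {
        "curr_node_id": torch.randint(num_nodes, (batch_size,)),
        "next_node_id": torch.randint(num_nodes, (batch_size,)),
        "mask": torch.ones(batch_size, num_nodes, dtype=torch.bool),  # True = feasible node
        "state": torch.rand(batch_size, state_dim),
    }
    node_emb = torch.rand(batch_size, num_nodes, emb_dim)   # stand-in for a node encoder's output
    edge_emb = edge_enc(inputs, node_emb)                    # [batch_size x emb_dim]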
models/classifiers/nn_classifiers/encoders/concat_edge_encoder.py
ADDED
@@ -0,0 +1,63 @@
import math
import torch
import torch.nn as nn

class ConcatEdgeEncoder(nn.Module):
    def __init__(self, state_dim, emb_dim, dropout):
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim
        self.norm_factor = 1 / math.sqrt(emb_dim)

        # initial embedding for state
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # out linear layer
        self.out_linear = nn.Linear((2 + int(state_dim > 0)) * emb_dim, emb_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs, node_emb):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size]
            next_node_id: torch.LongTensor [batch_size]
            mask: torch.LongTensor [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x state_dim]
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings obtained from the node encoder

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            edge embeddings
        """
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        state = inputs["state"]
        batch_size = curr_node_id.size(0)

        #--------------------------------
        # generate queries, keys, values
        #--------------------------------
        node_emb = self.dropout(node_emb)
        curr_emb = node_emb.gather(1, curr_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        next_emb = node_emb.gather(1, next_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state) # [batch_size x emb_dim]
            edge_emb = torch.cat((curr_emb, next_emb, state_emb[:, None, :]), -1) # [batch_size x 1 x (3*emb_dim)]
        else:
            edge_emb = torch.cat((curr_emb, next_emb), -1) # [batch_size x 1 x (2*emb_dim)]
        edge_emb = edge_emb.squeeze(1) # [batch_size x (2*emb_dim)]
        return self.out_linear(edge_emb)
ADDED
@@ -0,0 +1,57 @@
|
import math
import torch
import torch.nn as nn

class MaxReadout(nn.Module):
    def __init__(self, state_dim, emb_dim, dropout):
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim

        # initial embedding for state
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # out linear layer
        self.out_linear = nn.Linear((1 + int(state_dim > 0))*emb_dim, emb_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs, node_emb):
        """
        Parameters
        ----------
        inputs: dict
            mask: torch.LongTensor [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x state_dim]
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings obtained from the node encoder

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            graph embeddings
        """
        mask = inputs["mask"]
        state = inputs["state"]
        node_emb = self.dropout(node_emb)

        # pooling with a mask
        mask = mask.unsqueeze(-1).expand_as(node_emb)
        node_emb = node_emb * mask
        h, _ = torch.max(node_emb, dim=1) # [batch_size x emb_dim]

        # out linear layer
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state) # [batch_size x emb_dim]
            h = torch.cat((h, state_emb), -1) # [batch_size x (2*emb_dim)]

        return self.out_linear(h)
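A minimal sketch of the masked max-pooling readout (sizes illustrative; masked-out nodes are zeroed before the max):

    import torch
    from models.classifiers.nn_classifiers.encoders.max_readout import MaxReadout

    batch_size, num_nodes, state_dim, emb_dim = 4, 10, 1, 128   # illustrative only
    readout = MaxReadout(state_dim=state_dim, emb_dim=emb_dim, dropout=0.1)
    inputs = {
        "mask": torch.ones(batch_size, num_nodes, dtype=torch.bool),
        "state": torch.rand(batch_size, state_dim),
    }
    node_emb = torch.rand(batch_size, num_nodes, emb_dim)
    graph_emb = readout(inputs, node_emb)   # [batch_size x emb_dim]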
models/classifiers/nn_classifiers/encoders/mean_readout.py
ADDED
@@ -0,0 +1,57 @@
import math
import torch
import torch.nn as nn

class MeanReadout(nn.Module):
    def __init__(self, state_dim, emb_dim, dropout):
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim

        # initial embedding for state
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # out linear layer
        self.out_linear = nn.Linear((1 + int(state_dim > 0))*emb_dim, emb_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs, node_emb):
        """
        Parameters
        ----------
        inputs: dict
            mask: torch.LongTensor [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x state_dim]
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings obtained from the node encoder

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            graph embeddings
        """
        mask = inputs["mask"]
        state = inputs["state"]
        node_emb = self.dropout(node_emb)

        # pooling with a mask
        mask = mask.unsqueeze(-1).expand_as(node_emb)
        node_emb = node_emb * mask
        h = torch.mean(node_emb, dim=1) # [batch_size x emb_dim]

        # out linear layer
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state) # [batch_size x emb_dim]
            h = torch.cat((h, state_emb), -1) # [batch_size x (2*emb_dim)]

        return self.out_linear(h)
models/classifiers/nn_classifiers/encoders/mha_node_encoder.py
ADDED
@@ -0,0 +1,63 @@
import math
import torch
import torch.nn as nn

class SelfMHANodeEncoder(nn.Module):
    def __init__(self, coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout):
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.num_mha_layers = num_mha_layers

        # initial embedding
        self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)

        # MHA Encoder (w/o positional encoding)
        mha_layer = nn.TransformerEncoderLayer(d_model=emb_dim,
                                               nhead=num_heads,
                                               dim_feedforward=emb_dim,
                                               dropout=dropout,
                                               batch_first=True)
        self.mha = nn.TransformerEncoder(mha_layer, num_layers=num_mha_layers)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            node_feat: torch.FloatTensor [batch_size x num_nodes x node_dim]

        Returns
        -------
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings
        """
        #----------------
        # input features
        #----------------
        node_feat = inputs["node_feats"]

        #------------------------------------------------------------------------
        # initial linear projection for adjusting dimensions of locs & the depot
        #------------------------------------------------------------------------
        # node_feat = self.dropout(node_feat)
        loc_emb = self.init_linear_nodes(node_feat[:, 1:, :]) # [batch_size x num_loc x emb_dim]
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :2]) # [batch_size x 1 x emb_dim]
        node_emb = torch.cat((depot_emb, loc_emb), 1) # [batch_size x num_nodes x emb_dim]

        #--------------
        # MHA encoding
        #--------------
        node_emb = self.mha(node_emb) # [batch_size x num_nodes x emb_dim]

        return node_emb
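A minimal sketch of node encoding with the self-attention encoder (sizes and the TSPTW-style feature layout are illustrative assumptions):

    import torch
    from models.classifiers.nn_classifiers.encoders.mha_node_encoder import SelfMHANodeEncoder

    batch_size, num_nodes, node_dim, emb_dim = 4, 10, 4, 128   # illustrative only (xy + time window)
    node_enc = SelfMHANodeEncoder(coord_dim=2, node_dim=node_dim, emb_dim=emb_dim,
                                  num_heads=8, num_mha_layers=2, dropout=0.1)
    node_feats = torch.rand(batch_size, num_nodes, node_dim)   # depot at index 0
    node_emb = node_enc({"node_feats": node_feats})            # [batch_size x num_nodes x emb_dim]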
models/classifiers/nn_classifiers/encoders/mlp_node_encoder.py
ADDED
@@ -0,0 +1,63 @@
import math
import torch
import torch.nn as nn

class MLPNodeEncoder(nn.Module):
    def __init__(self, coord_dim, node_dim, emb_dim, num_mlp_layers, dropout):
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.num_mlp_layers = num_mlp_layers

        # initial embedding
        self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)

        # MLP Encoder
        self.mlp = nn.ModuleList()
        for _ in range(num_mlp_layers):
            self.mlp.append(nn.Linear(emb_dim, emb_dim, bias=True))

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            node_feat: torch.FloatTensor [batch_size x num_nodes x node_dim]

        Returns
        -------
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings
        """
        #----------------
        # input features
        #----------------
        node_feat = inputs["node_feats"]

        #------------------------------------------------------------------------
        # initial linear projection for adjusting dimensions of locs & the depot
        #------------------------------------------------------------------------
        # node_feat = self.dropout(node_feat)
        loc_emb = self.init_linear_nodes(node_feat[:, 1:, :]) # [batch_size x num_loc x emb_dim]
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :2]) # [batch_size x 1 x emb_dim]
        node_emb = torch.cat((depot_emb, loc_emb), 1) # [batch_size x num_nodes x emb_dim]

        #--------------
        # MLP encoding
        #--------------
        for i in range(self.num_mlp_layers):
            # node_emb = self.dropout(node_emb)
            node_emb = self.mlp[i](node_emb)
            if i != self.num_mlp_layers - 1:
                node_emb = torch.relu(node_emb)
        return node_emb
models/classifiers/nn_classifiers/nn_classifier.py
ADDED
@@ -0,0 +1,156 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# Node encoder
from models.classifiers.nn_classifiers.encoders.mlp_node_encoder import MLPNodeEncoder
from models.classifiers.nn_classifiers.encoders.mha_node_encoder import SelfMHANodeEncoder
# Edge encoder
from models.classifiers.nn_classifiers.encoders.attn_edge_encoder import AttentionEdgeEncoder
from models.classifiers.nn_classifiers.encoders.concat_edge_encoder import ConcatEdgeEncoder
# Decoder
from models.classifiers.nn_classifiers.decoders.lstm_decoder import LSTMDecoder
from models.classifiers.nn_classifiers.decoders.mlp_decoder import MLPDecoder
from models.classifiers.nn_classifiers.decoders.mha_decoder import SelfMHADecoder

# data loader
from utils.data_utils.tsptw_dataset import load_tsptw_sequentially
from utils.data_utils.pctsp_dataset import load_pctsp_sequentially
from utils.data_utils.pctsptw_dataset import load_pctsptw_sequentially
from utils.data_utils.cvrp_dataset import load_cvrp_sequentially

NODE_ENC_LIST = ["mlp", "mha"]
EDGE_ENC_LIST = ["concat", "attn"]
DEC_LIST = ["mlp", "lstm", "mha"]

class NNClassifier(nn.Module):
    def __init__(self,
                 problem: str,
                 node_enc_type: str,
                 edge_enc_type: str,
                 dec_type: str,
                 emb_dim: int,
                 num_enc_mlp_layers: int,
                 num_dec_mlp_layers: int,
                 num_classes: int,
                 dropout: float,
                 pos_encoder: str = "sincos"):
        super().__init__()
        self.problem = problem
        self.node_enc_type = node_enc_type
        self.edge_enc_type = edge_enc_type
        self.dec_type = dec_type
        assert node_enc_type in NODE_ENC_LIST, f"Invalid enc_type. select from {NODE_ENC_LIST}"
        assert dec_type in DEC_LIST, f"Invalid dec_type. select from {DEC_LIST}"
        self.is_sequential = True if dec_type in ["lstm", "mha"] else False
        coord_dim = 2 # only support 2d problem
        if problem == "tsptw":
            node_dim = 4 # coords (2) + time window (2)
            state_dim = 1 # current time (1)
        elif problem == "pctsp":
            node_dim = 4 # coords (2) + prize (1) + penalty (1)
            state_dim = 2 # current prize (1) + current penalty (1)
        elif problem == "pctsptw":
            node_dim = 6 # coords (2) + prize (1) + penalty (1) + time window (2)
            state_dim = 3 # current prize (1) + current penalty (1) + current time (1)
        elif problem == "cvrp":
            node_dim = 3 # coords (2) + demand (1)
            state_dim = 1 # remaining capacity (1)
        else:
            raise NotImplementedError

        #----------------
        # Graph encoding
        #----------------
        # Node encoder
        if node_enc_type == "mlp":
            self.node_enc = MLPNodeEncoder(coord_dim, node_dim, emb_dim, num_enc_mlp_layers, dropout)
        elif node_enc_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.node_enc = SelfMHANodeEncoder(coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout)
        else:
            raise NotImplementedError

        # Readout
        if edge_enc_type == "concat":
            self.readout = ConcatEdgeEncoder(state_dim, emb_dim, dropout)
        elif edge_enc_type == "attn":
            self.readout = AttentionEdgeEncoder(state_dim, emb_dim, dropout)
        else:
            raise NotImplementedError

        #------------------------
        # Classification Decoder
        #------------------------
        if dec_type == "mlp":
            self.decoder = MLPDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "lstm":
            self.decoder = LSTMDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.decoder = SelfMHADecoder(emb_dim, num_heads, num_mha_layers, num_classes, dropout, pos_encoder)
        else:
            raise NotImplementedError

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            next_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            node_feat: torch.FloatTensor [batch_size x max_seq_length x num_nodes x node_dim] if self.is_sequential else [batch_size x num_nodes x node_dim]
            mask: torch.LongTensor [batch_size x max_seq_length x num_nodes] if self.is_sequential else [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x max_seq_length x state_dim] if self.is_sequential else [batch_size x state_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x seq_length x num_classes] if self.is_sequential else [batch_size x num_classes]
            probabilities of classes
        """
        #-----------------
        # Encoding graphs
        #-----------------
        if self.is_sequential:
            shp = inputs["curr_node_id"].size()
            inputs = {key: value.flatten(0, 1) for key, value in inputs.items()}
        node_emb = self.node_enc(inputs) # [(batch_size*max_seq_length) x emb_dim] if self.is_sequential else [batch_size x emb_dim]
        graph_emb = self.readout(inputs, node_emb)
        if self.is_sequential:
            graph_emb = graph_emb.view(*shp, -1) # [batch_size x max_seq_length x emb_dim]

        #----------
        # Decoding
        #----------
        probs = self.decoder(graph_emb)

        return probs

    def get_inputs(self, routes, first_explained_step, node_feats):
        node_feats_ = node_feats.copy()
        node_feats_["tour"] = routes
        if self.problem == "tsptw":
            seq_data = load_tsptw_sequentially(node_feats_)
        elif self.problem == "pctsp":
            seq_data = load_pctsp_sequentially(node_feats_)
        elif self.problem == "pctsptw":
            seq_data = load_pctsptw_sequentially(node_feats_)
        elif self.problem == "cvrp":
            seq_data = load_cvrp_sequentially(node_feats_)
        else:
            raise NotImplementedError

        def pad_seq_length(batch):
            data = {}
            for key in batch[0].keys():
                padding_value = True if key == "mask" else 0.0
                # post-padding
                data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch], batch_first=True, padding_value=padding_value)
            pad_mask = torch.nn.utils.rnn.pad_sequence([torch.full((d["mask"].size(0), ), True) for d in batch], batch_first=True, padding_value=False)
            data.update({"pad_mask": pad_mask})
            return data
        instance = pad_seq_length(seq_data)
        return instance
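A minimal end-to-end sketch of the classifier with a sequential (LSTM) decoder. The configuration and tensor sizes are illustrative assumptions; in practice the padded batch would come from `get_inputs` and the sequential data loaders rather than random tensors:

    import torch
    from models.classifiers.nn_classifiers.nn_classifier import NNClassifier

    model = NNClassifier(problem="tsptw", node_enc_type="mha", edge_enc_type="attn",
                         dec_type="lstm", emb_dim=128, num_enc_mlp_layers=2,
                         num_dec_mlp_layers=2, num_classes=2, dropout=0.1)   # illustrative config

    batch_size, max_seq_length, num_nodes = 2, 9, 10   # illustrative only
    inputs = {
        "curr_node_id": torch.randint(num_nodes, (batch_size, max_seq_length)),
        "next_node_id": torch.randint(num_nodes, (batch_size, max_seq_length)),
        "node_feats": torch.rand(batch_size, max_seq_length, num_nodes, 4),    # xy + time window
        "mask": torch.ones(batch_size, max_seq_length, num_nodes, dtype=torch.bool),
        "state": torch.rand(batch_size, max_seq_length, 1),                    # current time
    }
    log_probs = model(inputs)   # [batch_size x max_seq_length x num_classes]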
models/classifiers/predictor.py
ADDED
@@ -0,0 +1,203 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

#------------
# base class
#------------
class DecisionPredictorBase(nn.Module):
    def __init__(self, coord_dim, node_dim, state_dim, emb_dim, num_mlp_layers, num_classes, dropout):
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.state_dim = state_dim
        self.num_mlp_layers = num_mlp_layers
        self.norm_factor = 1 / math.sqrt(emb_dim)

        # initial embedding
        self.init_linear_node = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # An attention layer
        self.w_q = nn.Parameter(torch.FloatTensor((2 + int(state_dim > 0)) * emb_dim, emb_dim))
        self.w_k = nn.Parameter(torch.FloatTensor(2 * emb_dim, emb_dim))
        self.w_v = nn.Parameter(torch.FloatTensor(2 * emb_dim, emb_dim))

        # MLP
        self.mlp = nn.ModuleList()
        for i in range(self.num_mlp_layers):
            self.mlp.append(nn.Linear(emb_dim, emb_dim, bias=True))
        self.mlp.append(nn.Linear(emb_dim, num_classes, bias=True))

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size]
            next_node_id: torch.LongTensor [batch_size]
            node_feat: torch.FloatTensor [batch_size x num_nodes x node_dim]
            mask: torch.LongTensor [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x state_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
        """
        #----------------
        # input features
        #----------------
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        node_feat = inputs["node_feats"]
        mask = inputs["mask"]
        state = inputs["state"]

        #---------------------------
        # initial linear projection
        #---------------------------
        node_emb = self.init_linear_node(node_feat[:, 1:, :]) # [batch_size x num_loc x emb_dim]
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :2]) # [batch_size x 1 x emb_dim]
        new_node_feat = torch.cat((depot_emb, node_emb), 1) # [batch_size x num_nodes x emb_dim]
        new_node_feat = self.dropout(new_node_feat)

        #---------------
        # preprocessing
        #---------------
        batch_size = curr_node_id.size(0)
        curr_emb = new_node_feat.gather(1, curr_node_id.unsqueeze(-1).expand(batch_size, 1, self.emb_dim))
        next_emb = new_node_feat.gather(1, next_node_id.unsqueeze(-1).expand(batch_size, 1, self.emb_dim))
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state) # [batch_size x emb_dim]
            input_q = torch.cat((curr_emb, next_emb, state_emb[:, None, :]), -1) # [batch_size x 1 x (3*emb_dim)]
        else:
            input_q = torch.cat((curr_emb, next_emb), -1) # [batch_size x 1 x (2*emb_dim)]
        input_kv = torch.cat((curr_emb.expand_as(new_node_feat), new_node_feat), -1) # [batch_size x num_nodes x (2*emb_dim)]

        #--------------------
        # An attention layer
        #--------------------
        q = torch.matmul(input_q, self.w_q) # [batch_size x 1 x emb_dim]
        k = torch.matmul(input_kv, self.w_k) # [batch_size x num_nodes x emb_dim]
        v = torch.matmul(input_kv, self.w_v) # [batch_size x num_nodes x emb_dim]
        compatibility = self.norm_factor * torch.matmul(q, k.transpose(-2, -1)) # [batch_size x 1 x num_nodes]
        compatibility[(~mask).unsqueeze(1).expand_as(compatibility)] = -math.inf
        attn = torch.softmax(compatibility, dim=-1)
        h = torch.matmul(attn, v) # [batch_size x 1 x emb_dim]
        h = h.squeeze(1) # [batch_size x emb_dim]

        #---------------
        # MLP (decoder)
        #---------------
        for i in range(self.num_mlp_layers):
            h = self.dropout(h)
            h = torch.relu(self.mlp[i](h))
        h = self.dropout(h)
        logits = self.mlp[-1](h)
        probs = F.log_softmax(logits, dim=-1)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """
        For TSPTW
        TODO: refactoring

        Parameters
        ----------
        tour: list [seq_length]
        first_explained_step: int
        node_feats: np.array [num_nodes x node_dim]

        Returns
        -------
        out: dict (key: data type [data_size])
            curr_node_id: torch.tensor [num_explained_paths]
            next_node_id: torch.tensor [num_explained_paths]
            node_feats: torch.tensor [num_explained_paths x num_nodes x node_dim]
            mask: torch.tensor [num_explained_paths x num_nodes]
            state: torch.tensor [num_explained_paths x state_dim]
        """
        node_feats = {
            key: torch.from_numpy(node_feat.astype(np.float32).copy()).clone()
            if isinstance(node_feat, np.ndarray) else
            torch.tensor([node_feat])
            for key, node_feat in node_feats.items()
        }
        if isinstance(tour, np.ndarray):
            tour = torch.from_numpy(tour.astype(np.int64).copy()).clone()
        else:
            tour = torch.LongTensor(tour)

        out = {"curr_node_id": [], "next_node_id": [], "mask": [], "state": []}
        for step in range(first_explained_step, len(tour) - 1):
            # node ids
            curr_node_id = tour[step]
            next_node_id = tour[step + 1]
            # mask & state
            max_coord = node_feats["grid_size"]
            coord = node_feats["coords"] / max_coord # [num_nodes x coord_dim]
            time_window = node_feats["time_window"] # [num_nodes x 2(start, end)]
            time_window = (time_window - time_window[1:].min()) / time_window[1:].max() # min-max normalization
            curr_time = torch.FloatTensor([0.0])
            raw_coord = node_feats["coords"]
            raw_time_window = node_feats["time_window"]
            raw_curr_time = torch.FloatTensor([0.0])
            num_nodes = len(node_feats["coords"])
            mask = torch.ones(num_nodes, dtype=torch.long) # feasible -> 1, infeasible -> 0
            for i in range(step + 1):
                curr_id = tour[i]
                if i > 0:
                    prev_id = tour[i - 1]
                    raw_curr_time += torch.norm(raw_coord[curr_id] - raw_coord[prev_id])
                    curr_time += torch.norm(coord[curr_id] - coord[prev_id])
                # visited?
                mask[curr_id] = 0
                # curr_time exceeds the time window?
                mask[curr_time > time_window[:, 1]] = 0
            curr_time = (raw_curr_time - raw_time_window[1:].min()) / raw_time_window[1:].max() # min-max normalization
            out["curr_node_id"].append(curr_node_id)
            out["next_node_id"].append(next_node_id)
            out["mask"].append(mask)
            out["state"].append(curr_time)
        out = {key: torch.stack(value, 0) for key, value in out.items()}
        node_feats = {
            key: node_feat.unsqueeze(0).expand(out["mask"].size(0), *node_feat.size())
            for key, node_feat in node_feats.items()
        }
        out.update({"node_feats": node_feats})
        return out

#---------------
# general class
#---------------
class DecisionPredictor(DecisionPredictorBase):
    def __init__(self, problem, emb_dim, num_mlp_layers, num_classes, drop):
        coord_dim = 2
        self.problem = problem
        if problem == "tsptw":
            node_dim = coord_dim + 2 # + time_window(start, end)
            state_dim = 1 # current_time
        elif problem == "cvrp":
            node_dim = coord_dim + 1 # + demand
            state_dim = 1 # used_capacity
        elif problem == "cvrptw":
            node_dim = coord_dim + 1 + 2 # + demand + time_window(start, end)
            state_dim = 2 # used_capacity + current_time
        else:
            assert False, f"problem {problem} is not supported!"
        super().__init__(coord_dim, node_dim, state_dim, emb_dim, num_mlp_layers, num_classes, drop)
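A minimal sketch of preparing per-step inputs with `get_inputs` for a TSPTW instance. The coordinates, time windows, and tour below are made-up illustrative values, and the sketch stops at input preparation (the returned `node_feats` entry remains a dict keyed by feature name):

    import numpy as np
    from models.classifiers.predictor import DecisionPredictor

    # made-up 5-node TSPTW instance; node 0 is the depot
    node_feats = {
        "grid_size": 100,
        "coords": np.random.rand(5, 2) * 100,
        "time_window": np.sort(np.random.rand(5, 2) * 100, axis=-1),
    }
    tour = [0, 2, 1, 3, 4, 0]

    model = DecisionPredictor(problem="tsptw", emb_dim=128, num_mlp_layers=2, num_classes=2, drop=0.1)
    inputs = model.get_inputs(tour, first_explained_step=0, node_feats=node_feats)
    # inputs["mask"] / inputs["state"] hold one feasibility mask and current time per explained step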
models/classifiers/rule_based_models.py
ADDED
@@ -0,0 +1,150 @@
import numpy as np
import torch
import torch.nn as nn

TOUR_LENGTH = 0
TIME_WINDOW = 1

class kNearestPredictor(nn.Module):
    def __init__(self, problem, k, k_type):
        """
        Parameters
        ----------
        problem: str
            problem type
        k: float
            if the vehicle visits one of the k% nearest nodes, this model labels the visit as prioritizing tour length
        """
        super().__init__()
        self.problem = problem
        self.num_classes = 2
        self.k_type = k_type
        if k_type == "num":
            self.k = int(k)
        elif k_type == "ratio":
            self.k = k
        else:
            assert False, "Invalid k_type. select from [num, ratio]"

    def forward(self, inputs):
        """
        Parameters
        ----------

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
        """
        #----------------
        # input features
        #----------------
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        node_feat = inputs["node_feats"]
        mask = inputs["mask"]

        coord_dim = 2
        batch_size = curr_node_id.size(0)
        coords = node_feat[:, :, :coord_dim] # [batch_size x num_nodes x coord_dim]
        num_candidates = (mask > 0).sum(dim=-1) # [batch_size]
        topk = torch.round(num_candidates * self.k).to(torch.long) # [batch_size]
        curr_coord = coords.gather(1, curr_node_id[:, None, None].expand_as(coords)) # [batch_size x 1 x coord_dim]
        dist_from_curr_node = torch.norm(curr_coord - coords, dim=-1) # [batch_size x 1 x num_nodes]
        visit_topk = []
        for i in range(batch_size):
            if self.k_type == "num":
                k = self.k
            else:
                k = topk[i].item()
            id = torch.topk(input=dist_from_curr_node[i], k=k, dim=-1, largest=True)[1]
            visit_topk.append(torch.isin(next_node_id[i], id))
        visit_topk = torch.stack(visit_topk, 0)
        idx = (1 - visit_topk.int()).to(torch.long)
        probs = torch.zeros(batch_size, self.num_classes).to(torch.float)
        probs.scatter_(-1, idx.unsqueeze(-1).expand_as(probs), 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """
        For TSPTW
        TODO: refactoring

        Parameters
        ----------
        tour: list [seq_length]
        first_explained_step: int
        node_feats: np.array [num_nodes x node_dim]

        Returns
        -------
        out: dict (key: data type [data_size])
            curr_node_id: torch.tensor [num_explained_paths]
            next_node_id: torch.tensor [num_explained_paths]
            node_feats: torch.tensor [num_explained_paths x num_nodes x node_dim]
            mask: torch.tensor [num_explained_paths x num_nodes]
            state: torch.tensor [num_explained_paths x state_dim]
        """
        if isinstance(node_feats, np.ndarray):
            node_feats = torch.from_numpy(node_feats.astype(np.float32)).clone()
        tour = torch.LongTensor(tour)
        coord_dim = 2
        out = {"curr_node_id": [], "next_node_id": [], "mask": [], "state": []}
        for step in range(first_explained_step, len(tour) - 1):
            # node ids
            curr_node_id = tour[step]
            next_node_id = tour[step + 1]
            # mask & state
            max_coord = 100
            coord = node_feats[:, :coord_dim] / max_coord # [num_nodes x coord_dim]
            time_window = node_feats[:, coord_dim:] # [num_nodes x 2(start, end)]
            time_window = (time_window - time_window[1:].min()) / time_window[1:].max() # min-max normalization
            curr_time = torch.FloatTensor([0.0])
            raw_coord = node_feats[:, :coord_dim]
            raw_time_window = node_feats[:, coord_dim:]
            raw_curr_time = torch.FloatTensor([0.0])
            mask = torch.ones(node_feats.size(0), dtype=torch.long) # feasible -> 1, infeasible -> 0
            for i in range(step + 1):
                curr_id = tour[i]
                if i > 0:
                    prev_id = tour[i - 1]
                    raw_curr_time += torch.norm(raw_coord[curr_id] - raw_coord[prev_id])
                    curr_time += torch.norm(coord[curr_id] - coord[prev_id])
                # visited?
                mask[curr_id] = 0
                # curr_time exceeds the time window?
                mask[curr_time > time_window[:, 1]] = 0
            curr_time = (raw_curr_time - raw_time_window[1:].min()) / raw_time_window[1:].max() # min-max normalization
            out["curr_node_id"].append(curr_node_id)
            out["next_node_id"].append(next_node_id)
            out["mask"].append(mask)
            out["state"].append(curr_time)
        out = {key: torch.stack(value, 0) for key, value in out.items()}
        node_feats = node_feats.unsqueeze(0).expand(out["mask"].size(0), node_feats.size(-2), node_feats.size(-1))
        out.update({"node_feats": node_feats})
        return out

    def get_topk_ids(self, input, k, dim, largest):
        """
        Parameters
        ----------
        input: torch.tensor [batch_size x num_nodes x num_nodes]
        k: torch.tensor [batch_size]
        dim: int
        largest: bool

        Returns
        -------
        topk_ids: torch.tensor [batch_size x num_node x k]
        """
        batch_size = input.size(0)
        max_k = k.max()
        ids = []
        for i in range(batch_size):
            id = torch.topk(input=input[i], k=k[i].item(), dim=dim, largest=largest)[1]

            # adjust tensor size
            if id.size(0) == 0:
                id = torch.full((max_k, ), -1000)
            elif id.size(0) < max_k:
                id = torch.cat((id, torch.full((max_k - id.size(0), ), id[0])), -1)
            ids.append(id)
        return torch.stack(ids, 0)
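A minimal sketch of the rule-based baseline on dummy tensors (sizes and k are illustrative assumptions; with k_type="ratio" the threshold is a fraction of the remaining feasible nodes):

    import torch
    from models.classifiers.rule_based_models import kNearestPredictor

    batch_size, num_nodes = 3, 8   # illustrative only
    model = kNearestPredictor(problem="tsptw", k=0.25, k_type="ratio")
    inputs = {
        "curr_node_id": torch.randint(num_nodes, (batch_size,)),
        "next_node_id": torch.randint(num_nodes, (batch_size,)),
        "node_feats": torch.rand(batch_size, num_nodes, 4),           # xy + time window
        "mask": torch.ones(batch_size, num_nodes, dtype=torch.long),  # 1 = feasible
    }
    probs = model(inputs)   # one-hot prediction per edge: [batch_size x 2]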
models/loss_functions.py
ADDED
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
from utils.utils import batched_bincount
import torch.nn.functional as F


class GeneralCrossEntropy(nn.Module):
    def __init__(self, weight_type: str, beta: float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.weight_type = weight_type
        self.beta = beta
        if weight_type == "seq_cbce":
            assert is_sequential == True
            self.loss_func = SeqCBCrossEntropy(beta=beta)
        elif weight_type == "cbce":
            self.loss_func = CBCrossEntropy(beta=beta, is_sequential=is_sequential)
        elif weight_type == "wce":
            self.loss_func = WeightedCrossEntropy(is_sequential=is_sequential)
        elif weight_type == "ce":
            self.loss_func = CrossEntropy(is_sequential=is_sequential)
        else:
            raise NotImplementedError

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        return self.loss_func(preds, labels, pad_mask)


class SeqCBCrossEntropy(nn.Module):
    def __init__(self, beta: float = 0.99):
        super().__init__()
        self.beta = beta

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor):
        """
        Sequential Class-Balanced Cross Entropy Loss (Our proposal)

        Parameters
        ----------
        preds: torch.Tensor [batch_size, max_seq_length, num_classes]
        labels: torch.Tensor [batch_size, max_seq_length]
        pad_mask: torch.Tensor [batch_size, max_seq_length]

        Returns
        -------
        loss: torch.Tensor [1]
        """
        seq_length_batch = pad_mask.sum(-1) # [batch_size]
        seq_length_list = torch.unique(seq_length_batch) # [num_unique_seq_length]
        batch_size = preds.size(0)
        loss = 0
        for seq_length in seq_length_list:
            extracted_batch = (seq_length_batch == seq_length) # [batch_size]
            extracted_preds = preds[extracted_batch] # [num_extracted_batch]
            extracted_labels = labels[extracted_batch] # [num_extracted_batch]
            extracted_batch_size = extracted_labels.size(0)
            bin = batched_bincount(extracted_labels.T, 1, extracted_preds.size(-1)) # [seq_length x num_classes]
            weight = (1 - self.beta) / (1 - self.beta**bin + 1e-8)
            for seq_no in range(seq_length.item()):
                loss += (extracted_batch_size / batch_size) * F.nll_loss(extracted_preds[:, seq_no], extracted_labels[:, seq_no], weight=weight[seq_no])
        return loss

class CBCrossEntropy(nn.Module):
    def __init__(self, beta: float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.beta = beta
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        if self.is_sequential:
            mask = pad_mask.view(-1)
            preds = preds.view(-1, preds.size(-1))
            bin = labels.view(-1)[mask].bincount()
            weight = (1 - self.beta) / (1 - self.beta**bin + 1e-8)
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask], weight=weight)
        else:
            bincount = labels.view(-1).bincount()
            weight = (1 - self.beta) / (1 - self.beta**bincount + 1e-8)
            loss = F.nll_loss(preds, labels.squeeze(-1), weight=weight)
        return loss

class WeightedCrossEntropy(nn.Module):
    def __init__(self, is_sequential: bool = True, norm: str = "min"):
        super().__init__()
        self.is_sequential = is_sequential
        if norm == "min":
            self.norm = torch.min
        elif norm == "max":
            self.norm = torch.max

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        if self.is_sequential:
            mask = pad_mask.view(-1)
            preds = preds.view(-1, preds.size(-1))
            bin = labels.view(-1)[mask].bincount()
            weight = self.norm(bin) / (bin + 1e-8)
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask], weight=weight)
        else:
            bin = labels.view(-1).bincount()
            weight = self.norm(bin) / (bin + 1e-8)
            loss = F.nll_loss(preds, labels.squeeze(-1), weight=weight)
        return loss

class CrossEntropy(nn.Module):
    def __init__(self, is_sequential: bool = True):
        super().__init__()
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        if self.is_sequential:
            mask = pad_mask.view(-1)
            preds = preds.view(-1, preds.size(-1))
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask])
        else:
            loss = F.nll_loss(preds, labels.squeeze(-1))
        return loss
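A minimal sketch of computing the class-balanced loss on padded sequential predictions (sizes illustrative; the predictions are log-probabilities, matching the decoders' log_softmax outputs):

    import torch
    from models.loss_functions import GeneralCrossEntropy

    batch_size, max_seq_length, num_classes = 4, 6, 2   # illustrative only
    log_probs = torch.rand(batch_size, max_seq_length, num_classes).log_softmax(-1)
    labels = torch.arange(batch_size * max_seq_length).remainder(num_classes).view(batch_size, max_seq_length)
    pad_mask = torch.ones(batch_size, max_seq_length, dtype=torch.bool)   # True = real (non-padded) step

    criterion = GeneralCrossEntropy(weight_type="cbce", beta=0.99, is_sequential=True)
    loss = criterion(log_probs, labels, pad_mask)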
models/prompts/generate_explanation.py
ADDED
@@ -0,0 +1,110 @@
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
from langchain.schema import StrOutputParser
|
3 |
+
from langchain_core.runnables.base import Runnable
|
4 |
+
from models.prompts.template_json_base import TemplateJsonBase
|
5 |
+
|
6 |
+
GENERATE_EXPLANATION = """\
|
7 |
+
You are RouteExplainer, an explanation system for justifying a specific edge (i.e., actual edge) in a route automatically generated by a VRP solver.
|
8 |
+
Here, you address a scenario where a tourist (user) wonders why the actual edge was selected at the step in the tourist route and why another edge was not selected instead at that step.
|
9 |
+
As an expert tour guide, you will justify why the actual edge was selected in the route if it outperforms another edge. That helps to convince the tourist of the actual edge or to make the tourist's decision to change to another edge from the actual edge while accepting some disadvantages.
|
10 |
+
Please carefully read the contents below and follow the instructions faithfully.
|
11 |
+
|
12 |
+
[Terminology]
|
13 |
+
The following terms are used here.
|
14 |
+
- Node: A destination.
|
15 |
+
- Edge: A directed edge representing the movement from one node to another.
|
16 |
+
- Edge intention: The underlying purpose of the edge. An edge intention here is either “prioritizing route length (route_len)” or “prioritizing time windows (time_window)”.
|
17 |
+
- Step: The visited order in a route.
|
18 |
+
- Actual edge: A user-specified edge in the optimal route generated by a VRP solver. You will justify this edge in this task.
|
19 |
+
- Counterfactual (CF) edge: A user-specified edge that was not selected at the step of the actual edge in the optimal route but could have been. This is a different edge from the actual edge. The user wonders why the CF edge was not selected at the step instead of the actual edge.
|
20 |
+
- Actual route: The optimal route generated by a VRP solver.
|
21 |
+
- CF route: An alternative route where the CF edge is selected at the step instead of the actual edge. The subsequent edges to the CF edge in the CF route are the best-effort ones.
|
22 |
+
|
23 |
+
[Example]
|
24 |
+
Please refer to the following input-output example when generating a counterfactual explanation.
|
25 |
+
***** START EXAMPLE *****
|
26 |
+
[input]
|
27 |
+
Question:
|
28 |
+
- The question asks about replacing the edge from node2 to node3 with the edge from node2 to node5.
|
29 |
+
Actual route:
|
30 |
+
- route: node1 > node2 > (actual edge) > node3 > node4 > node5 > node6 > node7 > node1
|
31 |
+
- short-term effect (immediate travel time): 20 minutes
|
32 |
+
- long-term effect (total travel time): 100 minutes
|
33 |
+
- missed nodes: none
|
34 |
+
- edge-intention ratio after the actual edge: time_window 75%, route_len 25%
|
35 |
+
CF route:
|
36 |
+
- route: node1 > node2 > (CF edge) > node5 > node6 > node7 > node1
|
37 |
+
- short-term effect (immediate travel time): 10 minutes
|
38 |
+
- long-term effect (total travel time): 77.8 minutes
|
39 |
+
- missed nodes: node3, node4
|
40 |
+
- edge-intention ratio after the CF edge: time_window 100%, route_len 0%
|
41 |
+
Difference between two routes:
|
42 |
+
- short-term effect: The actual route increases it by 10 minutes
|
43 |
+
- long-term effect: The actual route increases it by 22.2 minutes
|
44 |
+
- missed nodes: The actual route visits 2 more nodes
|
45 |
+
- difference of edge-intention ratio after the actual and CF edges: time_window -25%, route_len +25%
|
46 |
+
Planned destination information:
|
47 |
+
- node1: start/end point
|
48 |
+
- node2: none
|
49 |
+
- node3: take lunch
|
50 |
+
- node4: attend a tour
|
51 |
+
- node5: most favorite destination
|
52 |
+
- node6: take dinner
|
53 |
+
- node7: none
|
54 |
+
|
55 |
+
[Explanation]
|
56 |
+
Here are the pros and cons of the actual and CF edges.
|
57 |
+
#### Actual edge:
|
58 |
+
- Pros:
|
59 |
+
- It allows you to visit all your destinations within time windows. That is essential for maximizing your tour experience.
|
60 |
+
- Cons:
|
61 |
+
- Immediate travel time will increase by 10 minutes.
|
62 |
+
- The total travel time will increase by 22.2 minutes, but it is natural because the actual route visits two more nodes than the CF route.
|
63 |
+
- Remarks:
|
64 |
+
- The route balances both prioritizing travel time and time windows.
|
65 |
+
#### CF edge:
|
66 |
+
- Pros:
|
67 |
+
- Immediate travel time will decrease by 10 minutes.
|
68 |
+
- The total travel time will decrease by 22.2 minutes. However, note that this reduction in time is the result of not visiting two nodes.
|
69 |
+
- Cons:
|
70 |
+
- You will miss node3 and node4. You plan to take lunch and attend a tour, so the loss could significantly degrade your tour experience.
|
71 |
+
- Remarks:
|
72 |
+
- Even though you miss node3 and node4, you will still be constantly pressed by time windows in the subsequent moves.
|
73 |
+
#### Summary:
|
74 |
+
- Given the pros and cons and the fact that adhering to time constraints is essential, the actual edge is objectively more optimal.
|
75 |
+
- However, you might prefer the CF edge, despite its cons, depending on your preferences.
|
76 |
+
***** END EXAMPLE *****
|
77 |
+
|
78 |
+
[Instruction]
|
79 |
+
Now, please generate a counterfactual explanation for the [input] below.
|
80 |
+
You MUST keep the following rules:
|
81 |
+
- Summarize the pros and cons, including short-term effects, long-term effects, missed nodes, and edge-intention ratio.
|
82 |
+
- Enrich explanations by leveraging destination information.
|
83 |
+
- Carefully consider causality when discussing travel-time reduction. If both routes miss the same number of nodes, one edge may genuinely reduce travel time; however, if the quicker route misses more nodes, the reduction comes from skipping those nodes.
|
84 |
+
- A high route_len ratio emphasizes speed over schedule adherence, while a high time_window ratio prioritizes sticking to a schedule, sacrificing travel efficiency for timely arrivals.
|
85 |
+
- Discuss the edge-intention ratio in "Remarks". Do NOT do it in "Pros" or "Cons".
|
86 |
+
- Travel time efficiency is solely determined by the total travel time.
|
87 |
+
- Never say that all planned destinations are visited if there is even one missed node. If some nodes are missed, you must specify which nodes are missed.
|
88 |
+
- If the CF edge outperforms the actual edge, you do NOT have to force a justification for the actual edge.
|
89 |
+
- Please associate the user's intention, "{intent}", with your summary. If the intention is blank, it means no intention was provided.
|
90 |
+
|
91 |
+
[input]
|
92 |
+
{comparison_results}
|
93 |
+
|
94 |
+
[Explanation]
|
95 |
+
"""
|
96 |
+
|
97 |
+
# - Routes are assessed based on the following priorities: fewer missed nodes are better > shorter total travel time is better.
|
98 |
+
# - In "Summary", Clearly and specifically explain the differences between the actual and CF edges to help the tourist convince the actual edge or make a decision to change to another edge from the actual edge while accepting some cons.
|
99 |
+
# - Ensure consistency in comparisons: the pros of the actual edge should be the cons of the CF edge and vice versa.
|
100 |
+
|
101 |
+
class Template4GenerateExplanation(TemplateJsonBase):
|
102 |
+
parser: Runnable = StrOutputParser()
|
103 |
+
template: str = GENERATE_EXPLANATION
|
104 |
+
prompt: Runnable = PromptTemplate(
|
105 |
+
template=template,
|
106 |
+
input_variables=["comparison_results", "intent"],
|
107 |
+
)
|
108 |
+
|
109 |
+
def _get_output_key(self) -> str:
|
110 |
+
return ""
|
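For reference, this template is consumed through the `sandwiches` chain defined in `template_json_base.py` and returns a plain string because the parser is `StrOutputParser`. A hedged usage sketch; the `ChatOpenAI` model and the abbreviated `comparison_results` string are assumptions, and the real input is built by `RouteExplainer.get_comparison_results`:

```python
from langchain_openai import ChatOpenAI  # assumed provider; any langchain chat model works
from models.prompts.generate_explanation import Template4GenerateExplanation

llm = ChatOpenAI(model="gpt-4", temperature=0)           # hypothetical model choice
chain = Template4GenerateExplanation().sandwiches(llm)   # prompt | llm | StrOutputParser

explanation = chain.invoke({
    "comparison_results": "Question:\n- ...\nActual route:\n- ...\nCF route:\n- ...",
    "intent": "",   # empty string when the user gave no intention
})
print(explanation)
```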
models/prompts/identify_question.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
from langchain_core.output_parsers import JsonOutputParser
|
3 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
4 |
+
from langchain_core.runnables.base import Runnable
|
5 |
+
from models.prompts.template_json_base import TemplateJsonBase
|
6 |
+
|
7 |
+
IDENTIFY_QUESTION = """\
|
8 |
+
Given a route and a question that asks what would happen if we replaced a specific edge in the route with another edge, please extract the step number of the replaced edge (i.e., cf_step) and the node id of the destination of the new edge (i.e., cf_visit) from the question, which is written in natural language.
|
9 |
+
Please use the following examples as a reference when you answer:
|
10 |
+
***** START EXAMPLE *****
|
11 |
+
[route info]
|
12 |
+
Nodes(node id, name): (1, node1), (2, node2), (3, node3), (4, node4), (5, node5)
|
13 |
+
Route: node1 > (step1) > node5 > (step2) > node3 > (step3) > node2 > (step4) > node4 > (step 5) > node1
|
14 |
+
[question]
|
15 |
+
Why node3, and why not node2?
|
16 |
+
[outputs]
|
17 |
+
```json
|
18 |
+
{{
|
19 |
+
"success": true,
|
20 |
+
"summary": "The answer asks about replacing the edge from node5 to node3 with the edge from node6 to node2.",
|
21 |
+
"intent": "",
|
22 |
+
"process": "The edge from node5 to node3 is at step2 because of 'node5 > (step2) > node3'. The node id of the destination of the new edge is 2 (node2). Thus, the final answers are cf_step=2 and cf_visit=2.",
|
23 |
+
"cf_step": 2,
|
24 |
+
"cf_visit": 2,
|
25 |
+
}}
|
26 |
+
```
|
27 |
+
|
28 |
+
[route info]
|
29 |
+
Nodes(node id, name): (1, node1), (2, node2), (3, node3), (4, node4), (5, node5)
|
30 |
+
Route: node1 > (step1) > node5 > (step2) > node3 > (step3) > node2 > (step4) > node4 > (step 5) > node1
|
31 |
+
[question]
|
32 |
+
What if we visited node4 instead of node2? We would personally like to visit node4 first.
|
33 |
+
[outputs]
|
34 |
+
```json
|
35 |
+
{{
|
36 |
+
"success": true,
|
37 |
+
"summary": "The answer asks about replacing the edge from node3 to node2 with the edge from node3 to node4.",
|
38 |
+
"intent": "The user would personally like to visit node4 first",
|
39 |
+
"process": "The edge from node3 to node2 is at step3 because of "node3 > (step3) > node2". The node id of the destination of the new edge is 4 (node4). Thus, the final answers are cf_step=3 and cf_visit=4.",
|
40 |
+
"cf_step": 3,
|
41 |
+
"cf_visit": 4,
|
42 |
+
}}
|
43 |
+
```
|
44 |
+
***** END EXAMPLE *****
|
45 |
+
|
46 |
+
Given the following route and question, please extract the step number of the replaced edge (i.e., cf_step) and the node id of the destination of the new edge (i.e., cf_visit) from the question.
|
47 |
+
Please keep the following rules:
|
48 |
+
- Do not output any sentences outside of JSON format.
|
49 |
+
- {format_instructions}
|
50 |
+
|
51 |
+
[route_info]
|
52 |
+
{route_info}
|
53 |
+
[question]
|
54 |
+
{whynot_question}
|
55 |
+
[outputs]
|
56 |
+
"""
|
57 |
+
|
58 |
+
class WhyNotQuestion(BaseModel):
|
59 |
+
success: bool = Field(description="Whether cf_step and cf_visit are successfully extracted (True) or not (False).")
|
60 |
+
summary: str = Field(description="Your summary for the given question. If success=False, instead state here what information is missing to extract cf_step/cf_visit and what additional information should be clarified (Additionally, provide an example).")
|
61 |
+
intent: str = Field(description="Your summary for user's intent (if provided). If not provided, this is set to ''.")
|
62 |
+
process: str = Field(description="The thought (reasoning) process in extracting cf_step and cf_visit. if success=False, this is set to ''.")
|
63 |
+
cf_step: int = Field(description="The step number of the replaced edge. if success=False, this is set to -1.")
|
64 |
+
cf_visit: int = Field(description="The node id of the destination of the new edge. if success=False, this is set to -1.")
|
65 |
+
|
66 |
+
class Template4IdentifyQuestion(TemplateJsonBase):
|
67 |
+
parser: Runnable = JsonOutputParser(pydantic_object=WhyNotQuestion)
|
68 |
+
template: str = IDENTIFY_QUESTION
|
69 |
+
prompt: Runnable = PromptTemplate(
|
70 |
+
template=template,
|
71 |
+
input_variables=["whynot_question", "route_info"],
|
72 |
+
partial_variables={"format_instructions": parser.get_format_instructions()}
|
73 |
+
)
|
74 |
+
|
75 |
+
def _get_output_key(self) -> str:
|
76 |
+
return ""
|
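Because the parser is a `JsonOutputParser` bound to `WhyNotQuestion`, a successful invocation yields a plain dict with the fields above. A sketch of the round trip; the LLM setup is an assumption and the route/question strings mirror the in-prompt examples:

```python
from langchain_openai import ChatOpenAI  # assumed provider
from models.prompts.identify_question import Template4IdentifyQuestion

llm = ChatOpenAI(model="gpt-4", temperature=0)        # hypothetical model choice
chain = Template4IdentifyQuestion().sandwiches(llm)

result = chain.invoke({
    "route_info": "Nodes(node id, name): (1, node1), (2, node2), (3, node3)\n"
                  "Route: node1 > (step1) > node3 > (step2) > node2 > (step3) > node1\n",
    "whynot_question": "Why node3, and why not node2?",
})
# expected shape of the parsed output:
# {"success": True, "summary": "...", "intent": "", "process": "...", "cf_step": 1, "cf_visit": 2}
if result["success"]:
    cf_step, cf_visit = result["cf_step"], result["cf_visit"]
```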
models/prompts/template_json_base.py
ADDED
@@ -0,0 +1,27 @@
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Dict, Any
|
3 |
+
from langchain_core.runnables import RunnableLambda
|
4 |
+
from langchain_core.runnables.base import Runnable
|
5 |
+
|
6 |
+
class TemplateJsonBase(ABC):
|
7 |
+
parser: Runnable
|
8 |
+
template: str
|
9 |
+
prompt: Runnable
|
10 |
+
|
11 |
+
@abstractmethod
|
12 |
+
def _get_output_key(self) -> str:
|
13 |
+
raise NotImplementedError
|
14 |
+
|
15 |
+
def get_template(self) -> str:
|
16 |
+
return self.template
|
17 |
+
|
18 |
+
def extract_value(self, input: Dict[str, Any]) -> Any:
|
19 |
+
return input[self._get_output_key()]
|
20 |
+
|
21 |
+
def sandwiches(self,
|
22 |
+
llm: Runnable,
|
23 |
+
extract_value: bool = False) -> Runnable:
|
24 |
+
if extract_value:
|
25 |
+
return self.prompt | llm | self.parser | RunnableLambda(self.extract_value)
|
26 |
+
else:
|
27 |
+
return self.prompt | llm | self.parser
|
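The `sandwiches` helper is just LCEL composition (`prompt | llm | parser`), optionally followed by a value extractor. A minimal illustrative subclass, with invented names, showing what each concrete template has to provide:

```python
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser
from langchain_core.runnables.base import Runnable
from models.prompts.template_json_base import TemplateJsonBase

class GreetingTemplate(TemplateJsonBase):      # hypothetical example subclass
    parser: Runnable = StrOutputParser()
    template: str = "Write a one-line greeting for {name}."
    prompt: Runnable = PromptTemplate(template=template, input_variables=["name"])

    def _get_output_key(self) -> str:
        return ""   # unused here because StrOutputParser already returns a string

# chain = GreetingTemplate().sandwiches(llm)   # -> prompt | llm | parser
# print(chain.invoke({"name": "RouteExplainer"}))
```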
models/route_explainer.py
ADDED
@@ -0,0 +1,293 @@
1 |
+
# templates
|
2 |
+
import numpy as np
|
3 |
+
import streamlit as st
|
4 |
+
from typing import Dict, List
|
5 |
+
from models.prompts.identify_question import Template4IdentifyQuestion
|
6 |
+
from models.prompts.generate_explanation import Template4GenerateExplanation
|
7 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
8 |
+
from langchain.schema import AIMessage
|
9 |
+
import utils.util_app as util_app
|
10 |
+
|
11 |
+
class StreamingChatCallbackHandler(BaseCallbackHandler):
|
12 |
+
def __init__(self):
|
13 |
+
pass
|
14 |
+
|
15 |
+
def on_llm_start(self, *args, **kwargs):
|
16 |
+
self.container = st.empty()
|
17 |
+
self.text = ""
|
18 |
+
|
19 |
+
def on_llm_new_token(self, token: str, *args, **kwargs):
|
20 |
+
self.text += token
|
21 |
+
self.container.markdown(
|
22 |
+
body=self.text,
|
23 |
+
unsafe_allow_html=False,
|
24 |
+
)
|
25 |
+
|
26 |
+
def on_llm_end(self, response: str, *args, **kwargs):
|
27 |
+
self.container.markdown(
|
28 |
+
body=response.generations[0][0].text,
|
29 |
+
unsafe_allow_html=False,
|
30 |
+
)
|
31 |
+
|
32 |
+
class RouteExplainer():
|
33 |
+
template_identify_question = Template4IdentifyQuestion()
|
34 |
+
template_generate_explanation = Template4GenerateExplanation()
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
llm,
|
38 |
+
cf_generator,
|
39 |
+
classifier) -> None:
|
40 |
+
assert cf_generator.problem == classifier.problem, "Problem types of cf_generator and classifier should coincide!"
|
41 |
+
self.coord_dim = 2
|
42 |
+
self.problem = cf_generator.problem
|
43 |
+
self.cf_generator = cf_generator
|
44 |
+
self.classifier = classifier
|
45 |
+
self.actual_route = None
|
46 |
+
self.cf_route = None
|
47 |
+
# templates
|
48 |
+
self.question_extractor = self.template_identify_question.sandwiches(llm)
|
49 |
+
self.explanation_generator = self.template_generate_explanation.sandwiches(llm)
|
50 |
+
|
51 |
+
#----------------
|
52 |
+
# whole pipeline
|
53 |
+
#----------------
|
54 |
+
def generate_explanation(self,
|
55 |
+
tour_list,
|
56 |
+
whynot_question: str,
|
57 |
+
actual_routes: list,
|
58 |
+
actual_labels: list,
|
59 |
+
node_feats: dict,
|
60 |
+
dist_matrix: np.array) -> str:
|
61 |
+
#--------------------------------
|
62 |
+
# define why & why-not questions
|
63 |
+
#--------------------------------
|
64 |
+
route_info_text = self.get_route_info_text(tour_list, actual_routes)
|
65 |
+
inputs = self.question_extractor.invoke({
|
66 |
+
"whynot_question": whynot_question,
|
67 |
+
"route_info": route_info_text
|
68 |
+
})
|
69 |
+
util_app.stream_words(inputs["summary"] + " " + inputs["intent"])
|
70 |
+
st.session_state.chat_history.append(AIMessage(content=inputs["summary"] + inputs["intent"]))
|
71 |
+
if not inputs["success"]:
|
72 |
+
return ""
|
73 |
+
|
74 |
+
#----------------------
|
75 |
+
# validate the CF edge
|
76 |
+
#----------------------
|
77 |
+
is_cf_edge_feasible, reason = self.validate_cf_edge(node_feats,
|
78 |
+
dist_matrix,
|
79 |
+
actual_routes[0],
|
80 |
+
inputs["cf_step"],
|
81 |
+
inputs["cf_visit"]-1)
|
82 |
+
# exception
|
83 |
+
if not is_cf_edge_feasible:
|
84 |
+
util_app.stream_words(reason)
|
85 |
+
return reason
|
86 |
+
|
87 |
+
#---------------------
|
88 |
+
# generate a cf route
|
89 |
+
#---------------------
|
90 |
+
cf_routes = self.cf_generator(actual_routes,
|
91 |
+
vehicle_id=0,
|
92 |
+
cf_step=inputs["cf_step"],
|
93 |
+
cf_next_node_id=inputs["cf_visit"]-1,
|
94 |
+
node_feats=node_feats,
|
95 |
+
dist_matrix=dist_matrix)
|
96 |
+
st.session_state.generated_cf_route = True
|
97 |
+
st.session_state.close_chat = True
|
98 |
+
st.session_state.cf_step = inputs["cf_step"]
|
99 |
+
|
100 |
+
#--------------------------------------
|
101 |
+
# classify the intentions of each edge
|
102 |
+
#--------------------------------------
|
103 |
+
cf_labels = self.classifier(self.classifier.get_inputs(cf_routes,
|
104 |
+
0,
|
105 |
+
node_feats,
|
106 |
+
dist_matrix))
|
107 |
+
st.session_state.cf_routes = cf_routes
|
108 |
+
st.session_state.cf_labels = cf_labels
|
109 |
+
|
110 |
+
#-------------------------------------
|
111 |
+
# generate a contrastive explanation
|
112 |
+
#-------------------------------------
|
113 |
+
comparison_results = self.get_comparison_results(question_summary=inputs["summary"],
|
114 |
+
tour_list=tour_list,
|
115 |
+
actual_routes=actual_routes,
|
116 |
+
actual_labels=actual_labels,
|
117 |
+
cf_routes=cf_routes,
|
118 |
+
cf_labels=cf_labels,
|
119 |
+
cf_step=inputs["cf_step"])
|
120 |
+
|
121 |
+
explanation = self.explanation_generator.invoke({
|
122 |
+
"comparison_results": comparison_results,
|
123 |
+
"intent": inputs["intent"]
|
124 |
+
}, config={"callbacks": [StreamingChatCallbackHandler()]})
|
125 |
+
|
126 |
+
return explanation
|
127 |
+
|
128 |
+
#-------------------------
|
129 |
+
# for extracting inputs
|
130 |
+
#-------------------------
|
131 |
+
def get_route_info_text(self, tour_list, routes) -> str:
|
132 |
+
route_info = ""
|
133 |
+
# nodes
|
134 |
+
route_info += "Nodes(node id, name): "
|
135 |
+
for i, destination in enumerate(tour_list):
|
136 |
+
if i != len(tour_list) - 1:
|
137 |
+
route_info += f"({i+1}, {destination['name']}), "
|
138 |
+
else:
|
139 |
+
route_info += f"({i+1}, {destination['name']})\n"
|
140 |
+
|
141 |
+
# routes
|
142 |
+
route_info += "Route: "
|
143 |
+
for i, node_id in enumerate(routes[0]):
|
144 |
+
if i == 0:
|
145 |
+
route_info += f"{tour_list[node_id]['name']} "
|
146 |
+
else:
|
147 |
+
route_info += f"> (step {i}) > {tour_list[node_id]['name']})"
|
148 |
+
if i == len(routes[0]) - 1:
|
149 |
+
route_info += "\n"
|
150 |
+
else:
|
151 |
+
route_info += " "
|
152 |
+
return route_info
|
153 |
+
|
154 |
+
#--------------------------
|
155 |
+
# for validating a CF edge
|
156 |
+
#--------------------------
|
157 |
+
def validate_cf_edge(self,
|
158 |
+
node_feats: Dict[str, np.array],
|
159 |
+
dist_matrix: np.array,
|
160 |
+
route: List[int],
|
161 |
+
cf_step: int,
|
162 |
+
cf_visit: int) -> bool:
|
163 |
+
# calc current time
|
164 |
+
curr_time = node_feats["time_window"][route[0]][0] # start point's open time
|
165 |
+
for step in range(1, cf_step):
|
166 |
+
curr_node_id = route[step-1]
|
167 |
+
next_node_id = route[step]
|
168 |
+
curr_time += node_feats["service_time"][curr_node_id] + dist_matrix[curr_node_id][next_node_id]
|
169 |
+
curr_time = max(curr_time, node_feats["time_window"][next_node_id][0]) # waiting
|
170 |
+
|
171 |
+
# validate the cf edge
|
172 |
+
curr_node_id = route[cf_step-1]
|
173 |
+
next_node_id = cf_visit
|
174 |
+
next_node_close_time = node_feats["time_window"][next_node_id][1]
|
175 |
+
arrival_time = curr_time + node_feats["service_time"][curr_node_id] + dist_matrix[curr_node_id][next_node_id]
|
176 |
+
if next_node_close_time < arrival_time:
|
177 |
+
exceed_time = (arrival_time - next_node_close_time)
|
178 |
+
return False, f"Oops, your CF edge is infeasible because it does not meet the destination's close time by {util_app.add_time_unit(exceed_time)}."
|
179 |
+
else:
|
180 |
+
return True, "The CF edge is feasible!"
|
181 |
+
|
182 |
+
#-------------------------------
|
183 |
+
# for generating an explanation
|
184 |
+
#-------------------------------
|
185 |
+
def get_comparison_results(self,
|
186 |
+
tour_list,
|
187 |
+
question_summary,
|
188 |
+
actual_routes: List[List[int]],
|
189 |
+
actual_labels: List[List[int]],
|
190 |
+
cf_routes: List[List[int]],
|
191 |
+
cf_labels: List[List[int]],
|
192 |
+
cf_step: int) -> str:
|
193 |
+
comparison_results = "Question:\n" + question_summary + "\n"
|
194 |
+
comparison_results += "Actual route:\n" + \
|
195 |
+
self.get_route_info(tour_list, actual_routes[0], actual_labels[0], cf_step-1, "actual") + \
|
196 |
+
self.get_representative_values(actual_routes[0], actual_labels[0], cf_step-1, "actual")
|
197 |
+
comparison_results += "CF route:\n" + \
|
198 |
+
self.get_route_info(tour_list, cf_routes[0], cf_labels[0], cf_step-1, "CF") + \
|
199 |
+
self.get_representative_values(cf_routes[0], cf_labels[0], cf_step-1, "CF")
|
200 |
+
comparison_results += "Difference between two routes:\n" + self.get_diff(cf_step-1, actual_routes[0], cf_routes[0])
|
201 |
+
comparison_results += "Planed desination information:\n" + self.get_node_info()
|
202 |
+
return comparison_results
|
203 |
+
|
204 |
+
def get_route_info(self,
|
205 |
+
tour_list,
|
206 |
+
route: List[int],
|
207 |
+
label: List[int],
|
208 |
+
ex_step: int,
|
209 |
+
type: str) -> str:
|
210 |
+
def get_labelname(label_number):
|
211 |
+
return "route_len" if label_number == 0 else "time_window"
|
212 |
+
route_info = "- route: "
|
213 |
+
for i, node_id in enumerate(route):
|
214 |
+
if i == ex_step and i != len(route) - 1:
|
215 |
+
if type == "actual":
|
216 |
+
edge_label = get_labelname(label[i])
|
217 |
+
else:
|
218 |
+
edge_label = "user_preference"
|
219 |
+
route_info += f"{tour_list[node_id]['name']} > ({type} edge: {edge_label}) > "
|
220 |
+
elif i != len(route) - 1:
|
221 |
+
route_info += f"{tour_list[node_id]['name']} > ({get_labelname(label[i])}) > "
|
222 |
+
else:
|
223 |
+
route_info += f"{tour_list[node_id]['name']}\n"
|
224 |
+
return route_info
|
225 |
+
|
226 |
+
def get_representative_values(self, route, labels, ex_step, type) -> str:
|
227 |
+
time_window_ratio = self.get_intention_ratio(1, labels, ex_step) * 100
|
228 |
+
route_len_ratio = self.get_intention_ratio(0, labels, ex_step) * 100
|
229 |
+
return f"- short-term effect (immediate travel time): {self.get_immediate_state(route, ex_step)//60} minutes\n- long-term effect (total travel time): {self.get_route_length(route)//60} minutes\n- missed nodes: {self.get_infeasible_node_name(route)}\n- edge-intention ratio after the {type} edge: time_window {time_window_ratio: .1f}%, route_len {route_len_ratio: .1f}%"
|
230 |
+
|
231 |
+
def get_immediate_state(self, route, ex_step) -> float:
|
232 |
+
return st.session_state.dist_matrix[route[ex_step]][route[ex_step+1]]
|
233 |
+
|
234 |
+
def get_route_length(self, route) -> float:
|
235 |
+
route_length = 0.0
|
236 |
+
for i in range(len(route)-1):
|
237 |
+
route_length += st.session_state.dist_matrix[route[i]][route[i+1]]
|
238 |
+
return route_length
|
239 |
+
|
240 |
+
def get_infeasible_nodes(self, route) -> int:
|
241 |
+
return len(route) - (len(st.session_state.dist_matrix) - 1)
|
242 |
+
|
243 |
+
def get_infeasible_node_name(self, route) -> str:
|
244 |
+
if len(route) == len(st.session_state.dist_matrix) - 1:
|
245 |
+
return "none"
|
246 |
+
else:
|
247 |
+
num_nodes = np.arange(len(st.session_state.dist_matrix))
|
248 |
+
for node_id in route:
|
249 |
+
num_nodes = num_nodes[num_nodes != node_id]
|
250 |
+
return ",".join([st.session_state.tour_list[node_id]["name"] for node_id in num_nodes])
|
251 |
+
|
252 |
+
def get_intention_ratio(self,
|
253 |
+
intention: int,
|
254 |
+
labels: List[int],
|
255 |
+
ex_step: int) -> float:
|
256 |
+
np_labels = np.array(labels)
|
257 |
+
return np.sum(np_labels[ex_step:] == intention) / len(labels[ex_step:])
|
258 |
+
|
259 |
+
def get_diff(self, ex_step, actual_route, cf_route) -> str:
|
260 |
+
def get_str(effect: float):
|
261 |
+
long_effect_str = "The actual route increases it by" if effect > 0 else "The actual route reduces it by"
|
262 |
+
long_effect_str += util_app.add_time_unit(abs(effect))
|
263 |
+
return long_effect_str
|
264 |
+
|
265 |
+
def get_str2(num_nodes: int, num_missed_nodes):
|
266 |
+
if num_nodes < 0:
|
267 |
+
num_nodes_str = f"The actual route visits {abs(num_nodes)} more nodes"
|
268 |
+
elif num_nodes == 0:
|
269 |
+
if num_missed_nodes == 0:
|
270 |
+
num_nodes_str = f"Both routes missed no node,"
|
271 |
+
else:
|
272 |
+
num_nodes_str = f"Both routes missed the same number of nodes ({abs(num_missed_nodes)} node(s))"
|
273 |
+
else:
|
274 |
+
num_nodes_str = f"The actual route visits {abs(num_nodes)} less nodes"
|
275 |
+
return num_nodes_str
|
276 |
+
|
277 |
+
# short/long-term effects
|
278 |
+
short_effect = self.get_immediate_state(actual_route, ex_step) - self.get_immediate_state(cf_route, ex_step)
|
279 |
+
long_effect = self.get_route_length(actual_route) - self.get_route_length(cf_route)
|
280 |
+
short_effect_str = get_str(short_effect)
|
281 |
+
long_effect_str = get_str(long_effect)
|
282 |
+
|
283 |
+
# missed nodes
|
284 |
+
missed_nodes = self.get_infeasible_nodes(actual_route) - self.get_infeasible_nodes(cf_route)
|
285 |
+
missed_nodes_str = get_str2(missed_nodes, self.get_infeasible_nodes(actual_route))
|
286 |
+
|
287 |
+
return f"- short-term effect: {short_effect_str}\n - long-term effect: {long_effect_str}\n- missed nodes: {missed_nodes_str}\n"
|
288 |
+
|
289 |
+
def get_node_info(self) -> str:
|
290 |
+
node_info = ""
|
291 |
+
for i in range(len(st.session_state.df_tour)):
|
292 |
+
node_info += f"- {st.session_state.df_tour['destination'][i]}: {st.session_state.df_tour['remarks'][i]}\n"
|
293 |
+
return node_info
|
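The feasibility check in `validate_cf_edge` is a forward simulation: accumulate service and travel time along the fixed prefix, wait for node openings, then compare the arrival time against the CF destination's closing time. A worked sketch of the same arithmetic on a toy instance (all numbers are made up):

```python
import numpy as np

# toy instance: 3 nodes, times in minutes, node 0 is the start point
time_window = np.array([[0, 600], [60, 120], [0, 90]])   # [open, close] per node
service_time = np.array([0, 15, 15])
dist = np.array([[0, 30, 50],
                 [30, 0, 40],
                 [50, 40, 0]])

route = [0, 1, 2, 0]        # actual route
cf_step, cf_visit = 1, 2    # "what if we visited node 2 first?"

# simulate up to the node just before the CF edge
curr_time = time_window[route[0]][0]
for step in range(1, cf_step):
    prev_node, next_node = route[step - 1], route[step]
    curr_time += service_time[prev_node] + dist[prev_node][next_node]
    curr_time = max(curr_time, time_window[next_node][0])   # wait if arriving early

# test the CF edge itself against the destination's closing time
prev_node = route[cf_step - 1]
arrival = curr_time + service_time[prev_node] + dist[prev_node][cf_visit]
print("feasible" if arrival <= time_window[cf_visit][1] else "infeasible")
```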
models/solvers/concorde/concorde.py
ADDED
@@ -0,0 +1,127 @@
1 |
+
import torch.nn as nn
|
2 |
+
import scipy
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import datetime
|
6 |
+
import subprocess
|
7 |
+
import models.solvers.concorde.concorde_utils as concorde_utils
|
8 |
+
import glob
|
9 |
+
import random
|
10 |
+
|
11 |
+
class ConcordeTSP(nn.Module):
|
12 |
+
def __init__(self, large_value=1e+6, scaling=False, random_seed=1234, solver_dir="models/solvers/concorde/src/TSP", io_dir="concorde_io_files"):
super().__init__()
|
13 |
+
self.random_seed = random_seed
|
14 |
+
self.large_value = large_value
|
15 |
+
self.scaling = scaling
|
16 |
+
self.solver_dir = solver_dir
|
17 |
+
self.io_dir = io_dir
|
18 |
+
self.redirector_stdout = concorde_utils.Redirector(fd=concorde_utils.STDOUT)
|
19 |
+
self.redirector_stderr = concorde_utils.Redirector(fd=concorde_utils.STDERR)
|
20 |
+
os.makedirs(io_dir, exist_ok=True)
|
21 |
+
|
22 |
+
def get_instance_name(self):
|
23 |
+
now = datetime.datetime.now()
|
24 |
+
random_value = random.random() # for avoiding duplicated file name
|
25 |
+
instance_name = f"{os.getpid()}_{random_value}_{now.strftime('%Y%m%d_%H%M%S%f')}"
|
26 |
+
return instance_name
|
27 |
+
|
28 |
+
def write_instance(self, node_feats, fixed_paths=None, instance_name=None):
|
29 |
+
if instance_name is None:
|
30 |
+
instance_name = self.get_instance_name()
|
31 |
+
instance_fname = f"{self.io_dir}/{instance_name}.tsp"
|
32 |
+
tour_fname = f"{self.io_dir}/{instance_name}.sol"
|
33 |
+
with open(instance_fname, "w") as f:
|
34 |
+
f.write(f"NAME : {instance_name}\n")
|
35 |
+
f.write(f"TYPE : TSP\n")
|
36 |
+
f.write(f"DIMENSION : {len(node_feats['coords'])}\n")
|
37 |
+
self.write_data(f, node_feats, fixed_paths)
|
38 |
+
f.write("EOF\n")
|
39 |
+
return instance_fname, tour_fname
|
40 |
+
|
41 |
+
def write_data(self, f, node_feats, fixed_paths=None):
|
42 |
+
coords = node_feats["coords"]
|
43 |
+
if fixed_paths is None:
|
44 |
+
f.write("EDGE_WEIGHT_TYPE : EUC_2D\n")
|
45 |
+
f.write("NODE_COORD_SECTION\n")
|
46 |
+
for i in range(len(coords)):
|
47 |
+
f.write(f" {i + 1} {str(coords[i][0])[:10]} {str(coords[i][1])[:10]}\n")
|
48 |
+
else:
|
49 |
+
f.write("EDGE_WEIGHT_TYPE : EXPLICIT\n")
|
50 |
+
f.write("EDGE_WEIGHT_FORMAT : FULL_MATRIX\n")
|
51 |
+
f.write("EDGE_WEIGHT_SECTION\n")
|
52 |
+
dist = scipy.spatial.distance.cdist(coords, coords).round().astype(np.int64)
|
53 |
+
for i in range(len(fixed_paths)):
|
54 |
+
curr_id = fixed_paths[i]
|
55 |
+
if i != 0 and i != len(fixed_paths) - 1:
|
56 |
+
# NOTE: Concorde seems to use int32, so 1e+9 causes an overflow.
|
57 |
+
# 1e+8 could also overflow when N (tour length) is large.
|
58 |
+
dist[curr_id, :] = 1e+8; dist[:, curr_id] = 1e+8
|
59 |
+
if i != 0:
|
60 |
+
prev_id = fixed_paths[i - 1]
|
61 |
+
dist[prev_id, curr_id] = 0; dist[curr_id, prev_id] = 0
|
62 |
+
if i != len(fixed_paths) - 1:
|
63 |
+
next_id = fixed_paths[i + 1]
|
64 |
+
dist[curr_id, next_id] = 0; dist[next_id, curr_id] = 0
|
65 |
+
f.write("\n".join([
|
66 |
+
" ".join(map(str, row))
|
67 |
+
for row in dist
|
68 |
+
]))
|
69 |
+
|
70 |
+
def solve(self, node_feats, fixed_paths=None, instance_name=None):
|
71 |
+
if self.scaling:
|
72 |
+
node_feats = self.preprocess_data(node_feats)
|
73 |
+
self.redirector_stdout.start()
|
74 |
+
self.redirector_stderr.start()
|
75 |
+
instance_fname, tour_fname = self.write_instance(node_feats, fixed_paths, instance_name)
|
76 |
+
subprocess.run(f"{self.solver_dir}/concorde -o {tour_fname} -x {instance_fname}", shell=True) # run Concorde
|
77 |
+
self.redirector_stderr.stop()
|
78 |
+
self.redirector_stdout.stop()
|
79 |
+
tours = self.read_tour(tour_fname)
|
80 |
+
# remove dump (?) files
|
81 |
+
try:
|
82 |
+
os.remove(instance_fname); os.remove(tour_fname)
|
83 |
+
except OSError as e:
|
84 |
+
pass
|
85 |
+
fname_list = glob.glob("*.sol")
|
86 |
+
fname_list.extend(glob.glob("*.res"))
|
87 |
+
for fname in fname_list:
|
88 |
+
try:
|
89 |
+
os.remove(fname)
|
90 |
+
except OSError as e:
|
91 |
+
# do nothing
|
92 |
+
pass
|
93 |
+
# subprocess.run(f"rm {instance_name}.sol", shell=True)
|
94 |
+
return tours
|
95 |
+
|
96 |
+
def read_tour(self, tour_fname):
|
97 |
+
"""
|
98 |
+
Parameters
|
99 |
+
----------
|
100 |
+
tour_fname: str
|
101 |
+
path to an output tour
|
102 |
+
|
103 |
+
Returns
|
104 |
+
-------
|
105 |
+
tour: 2d list [num_vehicles(1) x seq_length]
|
106 |
+
"""
|
107 |
+
if not os.path.exists(tour_fname): # fails to solve the instance
|
108 |
+
return
|
109 |
+
|
110 |
+
tour = []
|
111 |
+
with open(tour_fname, "r") as f:
|
112 |
+
for i, line in enumerate(f):
|
113 |
+
if i == 0:
|
114 |
+
continue
|
115 |
+
read_tour = line.split()
|
116 |
+
tour.extend(read_tour)
|
117 |
+
tour.append(tour[0])
|
118 |
+
return [list(map(int, tour))]
|
119 |
+
|
120 |
+
def preprocess_data(self, node_feats):
|
121 |
+
# convert float to integer approximately
|
122 |
+
return {
|
123 |
+
key: (node_feat * self.large_value).astype(np.int64)
|
124 |
+
if key == "coords" else
|
125 |
+
node_feat
|
126 |
+
for key, node_feat in node_feats.items()
|
127 |
+
}
|
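When `fixed_paths` is given, the wrapper above encodes the partial route directly in the explicit distance matrix: edges along the fixed path get cost 0 and every other edge touching an intermediate fixed node gets a prohibitively large cost. A standalone sketch of that matrix surgery (it does not call Concorde):

```python
import numpy as np
from scipy.spatial.distance import cdist

coords = np.random.rand(5, 2) * 1000
dist = cdist(coords, coords).round().astype(np.int64)

fixed_paths = np.array([0, 3, 1])   # the tour must contain the path 0 -> 3 -> 1
BIG = int(1e8)                      # large, but still safe for Concorde's int32 arithmetic

for i, curr in enumerate(fixed_paths):
    if 0 < i < len(fixed_paths) - 1:
        dist[curr, :] = BIG         # forbid all other edges at intermediate fixed nodes
        dist[:, curr] = BIG
    if i > 0:
        prev = fixed_paths[i - 1]
        dist[prev, curr] = dist[curr, prev] = 0   # force each consecutive fixed edge
print(dist)
```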
models/solvers/concorde/concorde_utils.py
ADDED
@@ -0,0 +1,93 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import tempfile
|
4 |
+
|
5 |
+
# for stopping std out from concorde solver
|
6 |
+
# https://github.com/machine-reasoning-ufrgs/TSP-GNN/blob/master/redirector.py
|
7 |
+
STDOUT = 1
|
8 |
+
STDERR = 2
|
9 |
+
|
10 |
+
class Redirector(object):
|
11 |
+
def __init__(self, fd=STDOUT):
|
12 |
+
self.fd = fd
|
13 |
+
self.started = False
|
14 |
+
|
15 |
+
def start(self):
|
16 |
+
if not self.started:
|
17 |
+
self.tmpfd, self.tmpfn = tempfile.mkstemp()
|
18 |
+
|
19 |
+
self.oldhandle = os.dup(self.fd)
|
20 |
+
os.dup2(self.tmpfd, self.fd)
|
21 |
+
os.close(self.tmpfd)
|
22 |
+
|
23 |
+
self.started = True
|
24 |
+
|
25 |
+
def flush(self):
|
26 |
+
if self.fd == STDOUT:
|
27 |
+
sys.stdout.flush()
|
28 |
+
elif self.fd == STDERR:
|
29 |
+
sys.stderr.flush()
|
30 |
+
|
31 |
+
def stop(self):
|
32 |
+
if self.started:
|
33 |
+
self.flush()
|
34 |
+
os.dup2(self.oldhandle, self.fd)
|
35 |
+
os.close(self.oldhandle)
|
36 |
+
tmpr = open(self.tmpfn, 'rb')
|
37 |
+
output = tmpr.read()
|
38 |
+
tmpr.close() # this also closes self.tmpfd
|
39 |
+
os.unlink(self.tmpfn)
|
40 |
+
|
41 |
+
self.started = False
|
42 |
+
return output
|
43 |
+
else:
|
44 |
+
return None
|
45 |
+
|
46 |
+
class RedirectorOneFile(object):
|
47 |
+
def __init__(self, fd=STDOUT):
|
48 |
+
self.fd = fd
|
49 |
+
self.started = False
|
50 |
+
self.inited = False
|
51 |
+
|
52 |
+
self.initialize()
|
53 |
+
|
54 |
+
def initialize(self):
|
55 |
+
if not self.inited:
|
56 |
+
self.tmpfd, self.tmpfn = tempfile.mkstemp()
|
57 |
+
self.pos = 0
|
58 |
+
self.tmpr = open(self.tmpfn, 'rb')
|
59 |
+
self.inited = True
|
60 |
+
|
61 |
+
def start(self):
|
62 |
+
if not self.started:
|
63 |
+
self.oldhandle = os.dup(self.fd)
|
64 |
+
os.dup2(self.tmpfd, self.fd)
|
65 |
+
self.started = True
|
66 |
+
|
67 |
+
def flush(self):
|
68 |
+
if self.fd == STDOUT:
|
69 |
+
sys.stdout.flush()
|
70 |
+
elif self.fd == STDERR:
|
71 |
+
sys.stderr.flush()
|
72 |
+
|
73 |
+
def stop(self):
|
74 |
+
if self.started:
|
75 |
+
self.flush()
|
76 |
+
os.dup2(self.oldhandle, self.fd)
|
77 |
+
os.close(self.oldhandle)
|
78 |
+
output = self.tmpr.read()
|
79 |
+
self.pos = self.tmpr.tell()
|
80 |
+
self.started = False
|
81 |
+
return output
|
82 |
+
else:
|
83 |
+
return None
|
84 |
+
|
85 |
+
def close(self):
|
86 |
+
if self.inited:
|
87 |
+
self.flush()
|
88 |
+
output = self.tmpr.read()  # read any remaining redirected output before closing
self.tmpr.close() # this also closes self.tmpfd
|
89 |
+
os.unlink(self.tmpfn)
|
90 |
+
self.inited = False
|
91 |
+
return output
|
92 |
+
else:
|
93 |
+
return None
|
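The `Redirector` works at the file-descriptor level, so it also silences output written by child processes (such as the Concorde binary), not just Python prints. A minimal usage sketch:

```python
import subprocess
from models.solvers.concorde.concorde_utils import Redirector, STDOUT, STDERR

out, err = Redirector(fd=STDOUT), Redirector(fd=STDERR)
out.start(); err.start()
subprocess.run("echo 'this would normally reach the terminal'", shell=True)
err.stop()
captured = out.stop()        # stop() returns the bytes written while redirected
print(captured.decode())
```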
models/solvers/general_solver.py
ADDED
@@ -0,0 +1,44 @@
1 |
+
import torch.nn as nn
|
2 |
+
from models.solvers.ortools.ortools import ORTools
|
3 |
+
from models.solvers.lkh.lkh import LKH
|
4 |
+
from models.solvers.concorde.concorde import ConcordeTSP
|
5 |
+
|
6 |
+
class GeneralSolver(nn.Module):
|
7 |
+
def __init__(self, problem, solver_type, large_value=1e+6, scaling=True):
|
8 |
+
super().__init__()
|
9 |
+
self.problem = problem
|
10 |
+
self.large_value = large_value
|
11 |
+
self.scaling = scaling
|
12 |
+
self.solver_type = solver_type
|
13 |
+
supported_problem = {
|
14 |
+
"ortools": ["tsp", "tsptw", "pctsp", "pctsptw", "cvrp", "cvrptw"],
|
15 |
+
"lkh": ["tsp", "tsptw", "cvrp", "cvrptw"],
|
16 |
+
"concorde": ["tsp"]
|
17 |
+
}
|
18 |
+
# validate solver_type & problem
|
19 |
+
assert solver_type in supported_problem.keys(), f"Invalid solver type: {solver_type}. Please select from {supported_problem.keys()}"
|
20 |
+
assert problem in supported_problem[solver_type], f"{solver_type} does not support {problem}."
|
21 |
+
self.solver = self.get_solver(problem, solver_type)
|
22 |
+
|
23 |
+
def change_solver(self, problem, solver_type):
|
24 |
+
if self.solver_type != solver_type or self.problem != problem:
|
25 |
+
self.problem = problem
|
26 |
+
self.solver_type = solver_type
|
27 |
+
self.solver = self.get_solver(problem, solver_type)
|
28 |
+
|
29 |
+
def get_solver(self, problem, solver_type):
|
30 |
+
if solver_type == "ortools":
|
31 |
+
return ORTools(problem, self.large_value, self.scaling)
|
32 |
+
elif solver_type == "lkh":
|
33 |
+
return LKH(problem, self.large_value, self.scaling)
|
34 |
+
elif solver_type == "concorde":
|
35 |
+
assert problem == "tsp", "Concorde solver supports only TSP."
|
36 |
+
return ConcordeTSP(self.large_value, self.scaling)
|
37 |
+
else:
|
38 |
+
assert False, f"Invalid solver type: {solver_type}"
|
39 |
+
|
40 |
+
def solve(self, node_feats, fixed_paths=None, dist_matrix=None, instance_name=None):
|
41 |
+
if isinstance(self.solver, ORTools):
|
42 |
+
return self.solver.solve(node_feats, fixed_paths, dist_matrix, instance_name)
|
43 |
+
else:
|
44 |
+
return self.solver.solve(node_feats, fixed_paths, instance_name)
|
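A hedged usage sketch of the dispatcher above. It assumes the backends are installed (`install_solvers.py` builds LKH/Concorde) and that the OR-Tools TSP backend only needs node coordinates; the exact `node_feats` keys each problem requires are defined in the per-problem writer/solver classes:

```python
import numpy as np
from models.solvers.general_solver import GeneralSolver

# 10 random nodes in the unit square; scaling=True converts them to integers internally
node_feats = {"coords": np.random.rand(10, 2)}

solver = GeneralSolver(problem="tsp", solver_type="ortools")
tours = solver.solve(node_feats)        # 2d list: [num_vehicles x seq_length]

solver.change_solver(problem="tsp", solver_type="concorde")   # swap backend at runtime
tours = solver.solve(node_feats)
```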
models/solvers/lkh/lkh.py
ADDED
@@ -0,0 +1,21 @@
1 |
+
import torch.nn as nn
|
2 |
+
from models.solvers.lkh.lkh_tsp import LKHTSP
|
3 |
+
from models.solvers.lkh.lkh_tsptw import LKHTSPTW
|
4 |
+
from models.solvers.lkh.lkh_cvrp import LKHCVRP
|
5 |
+
from models.solvers.lkh.lkh_cvrptw import LKHCVRPTW
|
6 |
+
|
7 |
+
class LKH(nn.Module):
|
8 |
+
def __init__(self, problem, large_value=1e+6, scaling=False, max_trials=10, seed=1234, lkh_dir="models/solvers/lkh/src", io_dir="lkh_io_files"):
|
9 |
+
super().__init__()
|
10 |
+
self.problem = problem
|
11 |
+
if problem == "tsp":
|
12 |
+
self.lkh = LKHTSP(large_value, scaling, max_trials, seed, lkh_dir, io_dir)
|
13 |
+
elif problem == "tsptw":
|
14 |
+
self.lkh = LKHTSPTW(large_value, scaling, max_trials, seed, lkh_dir, io_dir)
|
15 |
+
elif problem == "cvrp":
|
16 |
+
self.lkh = LKHCVRP(large_value, scaling, max_trials, seed, lkh_dir, io_dir)
|
17 |
+
elif problem == "cvrptw":
|
18 |
+
self.lkh = LKHCVRPTW(large_value, scaling, max_trials, seed, lkh_dir, io_dir)
|
19 |
+
|
20 |
+
def solve(self, node_feats, fixed_paths=None, instance_name=None):
|
21 |
+
return self.lkh.solve(node_feats, fixed_paths, instance_name)
|
models/solvers/lkh/lkh_base.py
ADDED
@@ -0,0 +1,157 @@
1 |
+
import os
|
2 |
+
import datetime
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
from subprocess import check_call
|
6 |
+
|
7 |
+
NODE_ID_OFFSET = 1
|
8 |
+
|
9 |
+
class LKHBase(nn.Module):
|
10 |
+
def __init__(self, problem, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh", io_dir="lkh_io_files"):
|
11 |
+
super().__init__()
|
12 |
+
self.coord_dim = 2
|
13 |
+
self.problem = problem
|
14 |
+
self.large_value = large_value
|
15 |
+
self.scaling = scaling
|
16 |
+
self.max_trials = max_trials
|
17 |
+
self.seed = seed
|
18 |
+
|
19 |
+
# I/O file settings
|
20 |
+
self.lkh_dir = lkh_dir
|
21 |
+
self.io_dir = io_dir
|
22 |
+
self.instance_path = f"{io_dir}/{self.problem}/instance"
|
23 |
+
self.param_path = f"{io_dir}/{self.problem}/param"
|
24 |
+
self.tour_path = f"{io_dir}/{self.problem}/tour"
|
25 |
+
self.log_path = f"{io_dir}/{self.problem}/log"
|
26 |
+
os.makedirs(self.instance_path, exist_ok=True)
|
27 |
+
os.makedirs(self.param_path, exist_ok=True)
|
28 |
+
os.makedirs(self.tour_path, exist_ok=True)
|
29 |
+
os.makedirs(self.log_path, exist_ok=True)
|
30 |
+
|
31 |
+
def solve(self, node_feats, fixed_paths=None, instance_name=None):
|
32 |
+
instance_fname = self.write_instance(node_feats, fixed_paths, instance_name)
|
33 |
+
param_fname, tour_fname, log_fname = self.write_para(instance_fname, instance_name)
|
34 |
+
with open(log_fname, "w") as f:
|
35 |
+
check_call([f"{self.lkh_dir}/LKH", param_fname], stdout=f) # run LKH
|
36 |
+
tours = self.read_tour(node_feats, tour_fname)
|
37 |
+
# clean up intermediate files
|
38 |
+
try:
|
39 |
+
os.remove(instance_fname); os.remove(param_fname); os.remove(tour_fname); os.remove(log_fname)
|
40 |
+
except:
|
41 |
+
pass
|
42 |
+
return tours
|
43 |
+
|
44 |
+
def get_instance_name(self):
|
45 |
+
now = datetime.datetime.now()
|
46 |
+
instance_name = f"{os.getpid()}-{now.strftime('%Y%m%d_%H%M%S%f')}"
|
47 |
+
return instance_name
|
48 |
+
|
49 |
+
def write_instance(self, node_feats, fixed_paths=None, instance_name=None):
|
50 |
+
if instance_name is None:
|
51 |
+
instance_name = self.get_instance_name()
|
52 |
+
instance_fname = f"{self.instance_path}/{instance_name}.{self.problem}"
|
53 |
+
with open(instance_fname, "w") as f:
|
54 |
+
f.write(f"NAME : {instance_name}\n")
|
55 |
+
f.write(f"TYPE : {self.problem.upper()}\n")
|
56 |
+
f.write(f"DIMENSION : {len(node_feats['coords'])}\n")
|
57 |
+
self.write_data(node_feats, f)
|
58 |
+
if fixed_paths is not None and len(fixed_paths) > 1:
|
59 |
+
fixed_paths = fixed_paths.copy()
|
60 |
+
# FIXED_EDGE_SECTION works well in TSP, but it cannot fix edges in TSPTW
|
61 |
+
# EDGE_DATA_SECTION can fix edges in both TSP and TSPTW, but the obtained tour is very poor
|
62 |
+
f.write("FIXED_EDGES_SECTION\n")
|
63 |
+
fixed_paths += NODE_ID_OFFSET # offset node id (node id starts from 1 in TSPLIB)
|
64 |
+
for i in range(len(fixed_paths) - 1):
|
65 |
+
f.write(f"{fixed_paths[i]} {fixed_paths[i+1]}\n")
|
66 |
+
# f.write("EDGE_DATA_FORMAT : EDGE_LIST\n")
|
67 |
+
# f.write("EDGE_DATA_SECTION\n")
|
68 |
+
# avail_edges = self.get_avail_edges(node_feats, fixed_paths)
|
69 |
+
# avail_edges += 1 # offset node id (node id starts from 1 in TSPLIB)
|
70 |
+
# for i in range(len(avail_edges)):
|
71 |
+
# f.write(f"{avail_edges[i][0]} {avail_edges[i][1]}\n")
|
72 |
+
f.write("EOF\n")
|
73 |
+
return instance_fname
|
74 |
+
|
75 |
+
def write_data(self, node_feats, f):
|
76 |
+
raise NotImplementedError
|
77 |
+
|
78 |
+
def get_avail_edges(self, node_feats, fixed_paths):
|
79 |
+
num_nodes = len(node_feats["coords"])
|
80 |
+
avail_edges = []
|
81 |
+
# add fixed edges
|
82 |
+
for i in range(len(fixed_paths) - 1):
|
83 |
+
avail_edges.append([fixed_paths[i], fixed_paths[i + 1]])
|
84 |
+
|
85 |
+
# add the remaining available edges
|
86 |
+
visited = np.array([0] * num_nodes)
|
87 |
+
for id in fixed_paths:
|
88 |
+
visited[id] = 1
|
89 |
+
visited[fixed_paths[0]] = 0
|
90 |
+
visited[fixed_paths[-1]] = 0
|
91 |
+
mask = visited < 1
|
92 |
+
node_id = np.arange(num_nodes)
|
93 |
+
feasible_node_id = node_id[mask]
|
94 |
+
for j in range(len(feasible_node_id) - 1):
|
95 |
+
for i in range(j + 1, len(feasible_node_id)):
|
96 |
+
avail_edges.append([feasible_node_id[j], feasible_node_id[i]])
|
97 |
+
return np.array(avail_edges)
|
98 |
+
|
99 |
+
def write_para(self, instance_fname, instance_name=None):
|
100 |
+
if instance_name is None:
|
101 |
+
instance_name = self.get_instance_name()
|
102 |
+
param_fname = f"{self.param_path}/{instance_name}.param"
|
103 |
+
tour_fname = f"{self.tour_path}/{instance_name}.tour"
|
104 |
+
log_fname = f"{self.log_path}/{instance_name}.log"
|
105 |
+
with open(param_fname, "w") as f:
|
106 |
+
f.write(f"PROBLEM_FILE = {instance_fname}\n")
|
107 |
+
f.write(f"MAX_TRIALS = {self.max_trials}\n")
|
108 |
+
f.write("MOVE_TYPE = 5\nPATCHING_C = 3\nPATCHING_A = 2\nRUNS = 1\n")
|
109 |
+
f.write(f"SEED = {self.seed}\n")
|
110 |
+
f.write(f"OUTPUT_TOUR_FILE = {tour_fname}\n")
|
111 |
+
return param_fname, tour_fname, log_fname
|
112 |
+
|
113 |
+
def read_tour(self, node_feats, tour_fname):
|
114 |
+
"""
|
115 |
+
Parameters
|
116 |
+
----------
|
117 |
+
output_filename: str
|
118 |
+
path to a file where optimal tour is written
|
119 |
+
Returns
|
120 |
+
-------
|
121 |
+
tour: 2d list [num_vehicles x seq_length]
|
122 |
+
a set of node ids indicating visit order
|
123 |
+
"""
|
124 |
+
if not os.path.exists(tour_fname):
|
125 |
+
return # found no feasible solution
|
126 |
+
|
127 |
+
with open(tour_fname, "r") as f:
|
128 |
+
tour = []
|
129 |
+
is_tour_section = False
|
130 |
+
for line in f:
|
131 |
+
line = line.strip()
|
132 |
+
if line == "TOUR_SECTION":
|
133 |
+
is_tour_section = True
|
134 |
+
continue
|
135 |
+
if is_tour_section:
|
136 |
+
if line != "-1":
|
137 |
+
tour.append(int(line) - NODE_ID_OFFSET)
|
138 |
+
else:
|
139 |
+
tour.append(tour[0])
|
140 |
+
break
|
141 |
+
# convert 1d -> 2d list
|
142 |
+
num_nodes = len(node_feats["coords"])
|
143 |
+
tour = np.array(tour)
|
144 |
+
# NOTE: node_id >= num_nodes indicates the depot node.
|
145 |
+
# That is because LKH uses dummy nodes whose locations are the same as the depot and whose demands = -capacity(?).
|
146 |
+
# I'm not sure where the behavior is documented, but the author of NeuroLKH reads output files like that.
|
147 |
+
# please refer to https://github.com/liangxinedu/NeuroLKH/blob/main/CVRPTWdata_generate.py#L132
|
148 |
+
tour[tour >= num_nodes] = 0
|
149 |
+
# remove subsequent zeros
|
150 |
+
tour = tour[np.diff(np.concatenate(([1], tour))).nonzero()]
|
151 |
+
loc0 = (tour == 0).nonzero()[0]
|
152 |
+
num_vehicles = len(loc0) - 1
|
153 |
+
tours = []
|
154 |
+
for vehicle_id in range(num_vehicles):
|
155 |
+
vehicle_tour = tour[loc0[vehicle_id]:loc0[vehicle_id+1]+1].tolist()
|
156 |
+
tours.append(vehicle_tour)
|
157 |
+
return tours # offset to make the first index 0
|
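`write_instance`/`write_para` emit plain TSPLIB-style text files and `read_tour` parses the `-1`-terminated `TOUR_SECTION` that LKH writes back. A sketch of a generated parameter file and of the tour parsing; file contents and names are illustrative:

```python
# a parameter file produced by write_para looks roughly like this (paths are illustrative):
param_text = """PROBLEM_FILE = lkh_io_files/tsp/instance/12345-20240101_000000.tsp
MAX_TRIALS = 1000
MOVE_TYPE = 5
PATCHING_C = 3
PATCHING_A = 2
RUNS = 1
SEED = 1234
OUTPUT_TOUR_FILE = lkh_io_files/tsp/tour/12345-20240101_000000.tour
"""

# the tour file LKH returns contains a -1-terminated TOUR_SECTION:
tour_text = """NAME : example
TOUR_SECTION
1
3
2
4
-1
EOF
"""

# minimal parse mirroring read_tour: 1-based TSPLIB ids -> 0-based, then close the loop
tour, in_section = [], False
for line in tour_text.splitlines():
    line = line.strip()
    if line == "TOUR_SECTION":
        in_section = True
        continue
    if in_section:
        if line == "-1":
            tour.append(tour[0])
            break
        tour.append(int(line) - 1)
print(tour)   # [0, 2, 1, 3, 0]
```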
models/solvers/lkh/lkh_cvrp.py
ADDED
@@ -0,0 +1,41 @@
1 |
+
import numpy as np
|
2 |
+
import scipy
|
3 |
+
from models.solvers.lkh.lkh_base import LKHBase
|
4 |
+
|
5 |
+
class LKHCVRP(LKHBase):
|
6 |
+
def __init__(self, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh/", io_dir="lkh_io_files"):
|
7 |
+
problem = "cvrp"
|
8 |
+
super().__init__(problem, large_value, scaling, max_trials, seed, lkh_dir, io_dir)
|
9 |
+
|
10 |
+
def write_data(self, node_feats, f):
|
11 |
+
"""
|
12 |
+
Parameters
|
13 |
+
---------
|
14 |
+
node_feats: dict of np.array
|
15 |
+
coords: np.array [num_nodes x coord_dim]
|
16 |
+
demand: np.array [num_nodes x 1]
|
17 |
+
capacity: np.array [1]
|
18 |
+
"""
|
19 |
+
coords = node_feats["coords"]
|
20 |
+
demand = node_feats["demand"]
|
21 |
+
capacity = node_feats["capacity"][0]
|
22 |
+
num_nodes = len(coords)
|
23 |
+
if self.scaling:
|
24 |
+
coords = coords * self.large_value
|
25 |
+
# NOTE: In CVRP, LKH can automatically obtain optimal vehicle size.
|
26 |
+
# However it cannot in CVRPTW (please check lkh_cvrptw.py).
|
27 |
+
# EDGE_WEIGHT_SECTION
|
28 |
+
f.write("EDGE_WEIGHT_TYPE : EUC_2D\n")
|
29 |
+
# CAPACITY
|
30 |
+
f.write("CAPACITY : " + str(capacity) + "\n")
|
31 |
+
# NODE_COORD_SECTION
|
32 |
+
f.write("NODE_COORD_SECTION\n")
|
33 |
+
for i in range(num_nodes):
|
34 |
+
f.write(f" {i + 1} {str(coords[i][0])[:10]} {str(coords[i][1])[:10]}\n")
|
35 |
+
# DEMAND_SECTION
|
36 |
+
f.write("DEMAND_SECTION\n")
|
37 |
+
for i in range(num_nodes):
|
38 |
+
f.write(f" {i + 1} {str(demand[i])}\n")
|
39 |
+
# DEPOT SECTION
|
40 |
+
f.write("DEPOT_SECTION\n")
|
41 |
+
f.write("1\n")
|
models/solvers/lkh/lkh_cvrptw.py
ADDED
@@ -0,0 +1,59 @@
1 |
+
import numpy as np
|
2 |
+
import scipy
|
3 |
+
from models.solvers.lkh.lkh_base import LKHBase
|
4 |
+
|
5 |
+
class LKHCVRPTW(LKHBase):
|
6 |
+
def __init__(self, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh/", io_dir="lkh_io_files"):
|
7 |
+
problem = "cvrptw"
|
8 |
+
super().__init__(problem, large_value, scaling, max_trials, seed, lkh_dir, io_dir)
|
9 |
+
|
10 |
+
def write_data(self, node_feats, f):
|
11 |
+
"""
|
12 |
+
Parameters
|
13 |
+
---------
|
14 |
+
node_feats: dict of np.array
|
15 |
+
coords: np.array [num_nodes x coord_dim]
|
16 |
+
demand: np.array [num_nodes x 1]
|
17 |
+
capacity: np.array [1]
|
18 |
+
time_window: np.array [num_nodes x 2(start, end)]
|
19 |
+
"""
|
20 |
+
coords = node_feats["coords"]
|
21 |
+
demand = node_feats["demand"]
|
22 |
+
capacity = node_feats["capacity"][0]
|
23 |
+
time_window = node_feats["time_window"]
|
24 |
+
num_nodes = len(coords)
|
25 |
+
if self.scaling:
|
26 |
+
coords = coords * self.large_value
|
27 |
+
time_window = time_window * self.large_value
|
28 |
+
# VEHICLES
|
29 |
+
# As the number of unused vehicles is also included in the penalty by default,
|
30 |
+
# we have to modify Penalty_CVRPTW.c in the LKH SRC directory.
|
31 |
+
# Comment out the following part, which corresponds to the penalty of unused vehicles:
|
32 |
+
# 42 if (MTSPMinSize >= 1 && Size < MTSPMinSize)
|
33 |
+
# 43 P += MTSPMinSize - Size;
|
34 |
+
# 44 if (Size > MTSPMaxSize)
|
35 |
+
# 45 P += Size - MTSPMaxSize;
|
36 |
+
# After the modification, we can automatically obtain optimal vehicle size
|
37 |
+
# by setting large vehicle size (e.g. >20) here
|
38 |
+
f.write("VEHICLES : 20\n")
|
39 |
+
# CAPACITY
|
40 |
+
f.write("CAPACITY : " + str(capacity) + "\n")
|
41 |
+
# EDGE_WEIGHT_SECTION
|
42 |
+
f.write("EDGE_WEIGHT_TYPE : EUC_2D\n")
|
43 |
+
# NODE_COORD_SECTION
|
44 |
+
f.write("NODE_COORD_SECTION\n")
|
45 |
+
for i in range(num_nodes):
|
46 |
+
f.write(f" {i + 1} {str(coords[i][0])[:10]} {str(coords[i][1])[:10]}\n")
|
47 |
+
# DEMAND_SECTION
|
48 |
+
f.write("DEMAND_SECTION\n")
|
49 |
+
for i in range(num_nodes):
|
50 |
+
f.write(f" {i + 1} {str(demand[i])}\n")
|
51 |
+
# TIME_WINDOW_SECTION
|
52 |
+
f.write("TIME_WINDOW_SECTION\n")
|
53 |
+
f.write("\n".join([
|
54 |
+
"{}\t{}\t{}".format(i + 1, l, u)
|
55 |
+
for i, (l, u) in enumerate(time_window)
|
56 |
+
]))
f.write("\n")
|
57 |
+
# DEPOT SECTION
|
58 |
+
f.write("DEPOT_SECTION\n")
|
59 |
+
f.write("1\n")
|
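The writer above produces a CVRPTW instance in LKH-3's extended TSPLIB dialect. A truncated, illustrative example of the generated file (values are made up; as the comment notes, `VEHICLES : 20` only yields the optimal fleet size after patching the penalty code):

```python
# illustrative content of an instance file written by write_instance + LKHCVRPTW.write_data
instance_text = """NAME : 12345-20240101_000000
TYPE : CVRPTW
DIMENSION : 4
VEHICLES : 20
CAPACITY : 30
EDGE_WEIGHT_TYPE : EUC_2D
NODE_COORD_SECTION
 1 250000 410000
 2 770000 130000
 3 620000 980000
 4 150000 660000
DEMAND_SECTION
 1 0
 2 7
 3 12
 4 5
TIME_WINDOW_SECTION
1\t0\t1000000
2\t0\t300000
3\t150000\t600000
4\t0\t450000
DEPOT_SECTION
1
EOF
"""
```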
models/solvers/lkh/lkh_tsp.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
from models.solvers.lkh.lkh_base import LKHBase
|
2 |
+
|
3 |
+
class LKHTSP(LKHBase):
|
4 |
+
def __init__(self, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh/", io_dir="lkh_io_files"):
|
5 |
+
problem = "tsp"
|
6 |
+
super().__init__(problem, large_value, scaling, max_trials, seed, lkh_dir, io_dir)
|
7 |
+
|
8 |
+
def write_data(self, node_feats, f):
|
9 |
+
coords = node_feats["coords"]
|
10 |
+
if self.scaling:
|
11 |
+
coords = coords * self.large_value
|
12 |
+
f.write("EDGE_WEIGHT_TYPE : EUC_2D\n")
|
13 |
+
f.write("NODE_COORD_SECTION\n")
|
14 |
+
for i in range(len(coords)):
|
15 |
+
f.write(f" {i + 1} {str(coords[i][0])[:10]} {str(coords[i][1])[:10]}\n")
|
models/solvers/lkh/lkh_tsptw.py
ADDED
@@ -0,0 +1,32 @@
1 |
+
import numpy as np
|
2 |
+
import scipy
|
3 |
+
from models.solvers.lkh.lkh_base import LKHBase
|
4 |
+
|
5 |
+
class LKHTSPTW(LKHBase):
|
6 |
+
def __init__(self, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh/", io_dir="lkh_io_files"):
|
7 |
+
problem = "tsptw"
|
8 |
+
super().__init__(problem, large_value, scaling, max_trials, seed, lkh_dir, io_dir)
|
9 |
+
|
10 |
+
def write_data(self, node_feats, f):
|
11 |
+
coord_dim = 2
|
12 |
+
coords = node_feats["coords"]
|
13 |
+
if self.scaling:
|
14 |
+
coords = coords * self.large_value
|
15 |
+
time_window = node_feats["time_window"].astype(np.int64)
|
16 |
+
dist = scipy.spatial.distance.cdist(coords, coords).round().astype(np.int64)
|
17 |
+
f.write("EDGE_WEIGHT_TYPE : EXPLICIT\n")
|
18 |
+
f.write("EDGE_WEIGHT_FORMAT : FULL_MATRIX\n")
|
19 |
+
f.write("EDGE_WEIGHT_SECTION\n")
|
20 |
+
f.write("\n".join([
|
21 |
+
" ".join(map(str, row))
|
22 |
+
for row in dist
|
23 |
+
]))
|
24 |
+
f.write("\n")
|
25 |
+
f.write("TIME_WINDOW_SECTION\n")
|
26 |
+
f.write("\n".join([
|
27 |
+
"{}\t{}\t{}".format(i + 1, l, u)
|
28 |
+
for i, (l, u) in enumerate(time_window)
|
29 |
+
]))
|
30 |
+
f.write("\n")
|
31 |
+
f.write("DEPOT_SECTION\n")
|
32 |
+
f.write("1\n")
|
models/solvers/ortools/ortools.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
import torch.nn as nn
|
2 |
+
from models.solvers.ortools.ortools_tsp import ORToolsTSP
|
3 |
+
from models.solvers.ortools.ortools_tsptw import ORToolsTSPTW
|
4 |
+
from models.solvers.ortools.ortools_pctsp import ORToolsPCTSP
|
5 |
+
from models.solvers.ortools.ortools_pctsptw import ORToolsPCTSPTW
|
6 |
+
from models.solvers.ortools.ortools_cvrp import ORToolsCVRP
|
7 |
+
from models.solvers.ortools.ortools_cvrptw import ORToolsCVRPTW
|
8 |
+
|
9 |
+
class ORTools(nn.Module):
|
10 |
+
def __init__(self, problem, large_value=1e+6, scaling=False):
|
11 |
+
super().__init__()
|
12 |
+
self.coord_dim = 2
|
13 |
+
self.problem = problem
|
14 |
+
self.large_value = large_value
|
15 |
+
self.scaling = scaling
|
16 |
+
self.ortools = self.get_ortools(problem)
|
17 |
+
|
18 |
+
def get_ortools(self, problem):
|
19 |
+
"""
|
20 |
+
Parameters
|
21 |
+
----------
|
22 |
+
problem: str
|
23 |
+
problem type
|
24 |
+
|
25 |
+
Returns
|
26 |
+
-------
|
27 |
+
ortools: ortools for the specified problem
|
28 |
+
"""
|
29 |
+
if problem == "tsp":
|
30 |
+
return ORToolsTSP(self.large_value, self.scaling)
|
31 |
+
elif problem == "tsptw":
|
32 |
+
return ORToolsTSPTW(self.large_value, self.scaling)
|
33 |
+
elif problem == "pctsp":
|
34 |
+
return ORToolsPCTSP(self.large_value, self.scaling)
|
35 |
+
elif problem == "pctsptw":
|
36 |
+
return ORToolsPCTSPTW(self.large_value, self.scaling)
|
37 |
+
elif problem == "cvrp":
|
38 |
+
return ORToolsCVRP(self.large_value, self.scaling)
|
39 |
+
elif problem == "cvrptw":
|
40 |
+
return ORToolsCVRPTW(self.large_value, self.scaling)
|
41 |
+
else:
|
42 |
+
raise NotImplementedError
|
43 |
+
|
44 |
+
def solve(self, node_feats, fixed_paths=None, dist_matrix=None, instance_name=None):
|
45 |
+
"""
|
46 |
+
Parameters
|
47 |
+
----------
|
48 |
+
node_feats: np.array [num_nodes x node_dim]
|
49 |
+
fixed_paths: np.array [cf_step]
|
50 |
+
scaling: bool
|
51 |
+
whether or not coords are multiplied by a large value
|
52 |
+
to convert float coords into int coords
|
53 |
+
|
54 |
+
Returns
|
55 |
+
-------
|
56 |
+
tour: np.array [seq_length]
|
57 |
+
"""
|
58 |
+
return self.ortools.solve(node_feats, fixed_paths, dist_matrix, instance_name)
|
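Unlike the LKH/Concorde paths, the OR-Tools path can also receive a precomputed distance matrix, which `GeneralSolver.solve` forwards positionally. A hedged sketch; whether a given backend uses the matrix or recomputes distances from the coordinates is defined in the per-problem classes that follow:

```python
import numpy as np
from models.solvers.ortools.ortools import ORTools

node_feats = {"coords": np.random.rand(8, 2)}    # depot first, then customers
dist_matrix = np.linalg.norm(
    node_feats["coords"][:, None] - node_feats["coords"][None], axis=-1)

tsp_solver = ORTools(problem="tsp", scaling=True)
tours = tsp_solver.solve(node_feats, None, dist_matrix)   # fixed_paths=None
print(tours)   # one list of node ids per vehicle, closed at the depot
```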