daisuke.kikuta commited on
Commit
719d0db
0 Parent(s):

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +7 -0
  2. .streamlit/config.toml +2 -0
  3. Dockerfile +38 -0
  4. LICENSE +125 -0
  5. README.md +45 -0
  6. analyze_dataset.py +71 -0
  7. app.py +455 -0
  8. eval_classifier.py +241 -0
  9. eval_solvers.py +62 -0
  10. generate_cf_dataset.py +145 -0
  11. generate_dataset.py +117 -0
  12. install_solvers.py +68 -0
  13. models/cf_generator.py +155 -0
  14. models/classifiers/general_classifier.py +60 -0
  15. models/classifiers/ground_truth/ground_truth.py +35 -0
  16. models/classifiers/ground_truth/ground_truth_base.py +285 -0
  17. models/classifiers/ground_truth/ground_truth_cvrp.py +14 -0
  18. models/classifiers/ground_truth/ground_truth_cvrptw.py +15 -0
  19. models/classifiers/ground_truth/ground_truth_pctsp.py +16 -0
  20. models/classifiers/ground_truth/ground_truth_pctsptw.py +15 -0
  21. models/classifiers/ground_truth/ground_truth_tsptw.py +14 -0
  22. models/classifiers/meaningless_models.py +62 -0
  23. models/classifiers/nn_classifiers/attention_graph_encoder.py +95 -0
  24. models/classifiers/nn_classifiers/decoders/lstm_decoder.py +57 -0
  25. models/classifiers/nn_classifiers/decoders/mha_decoder.py +74 -0
  26. models/classifiers/nn_classifiers/decoders/mlp_decoder.py +50 -0
  27. models/classifiers/nn_classifiers/encoders/attn_edge_encoder.py +81 -0
  28. models/classifiers/nn_classifiers/encoders/concat_edge_encoder.py +63 -0
  29. models/classifiers/nn_classifiers/encoders/max_readout.py +57 -0
  30. models/classifiers/nn_classifiers/encoders/mean_readout.py +57 -0
  31. models/classifiers/nn_classifiers/encoders/mha_node_encoder.py +63 -0
  32. models/classifiers/nn_classifiers/encoders/mlp_node_encoder.py +63 -0
  33. models/classifiers/nn_classifiers/nn_classifier.py +156 -0
  34. models/classifiers/predictor.py +203 -0
  35. models/classifiers/rule_based_models.py +150 -0
  36. models/loss_functions.py +129 -0
  37. models/prompts/generate_explanation.py +110 -0
  38. models/prompts/identify_question.py +76 -0
  39. models/prompts/template_json_base.py +27 -0
  40. models/route_explainer.py +293 -0
  41. models/solvers/concorde/concorde.py +127 -0
  42. models/solvers/concorde/concorde_utils.py +93 -0
  43. models/solvers/general_solver.py +44 -0
  44. models/solvers/lkh/lkh.py +21 -0
  45. models/solvers/lkh/lkh_base.py +157 -0
  46. models/solvers/lkh/lkh_cvrp.py +41 -0
  47. models/solvers/lkh/lkh_cvrptw.py +59 -0
  48. models/solvers/lkh/lkh_tsp.py +15 -0
  49. models/solvers/lkh/lkh_tsptw.py +32 -0
  50. models/solvers/ortools/ortools.py +58 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.png
3
+ lkh_io_files/
4
+ concorde_io_files/
5
+ data/
6
+ data_generator/
7
+ checkpoints/
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [server]
2
+ enableStaticServing = true
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
2
+
3
+ RUN echo "Building docker image"
4
+
5
+ RUN apt-get -y update && \
6
+ apt-get install -y \
7
+ curl \
8
+ build-essential \
9
+ git \
10
+ vim \
11
+ tmux
12
+
13
+ # jupyter-notebook & lab
14
+ RUN python3 -m pip install jupyter
15
+ RUN python3 -m pip install jupyterlab
16
+
17
+ # LLMs
18
+ RUN python3 -m pip install openai
19
+ RUN python3 -m pip install tiktoken
20
+ RUN python3 -m pip install langchain
21
+
22
+ # Web app
23
+ RUN python3 -m pip install streamlit
24
+ RUN python3 -m pip install streamlit-folium
25
+ RUN python3 -m pip install folium
26
+
27
+ # Google Map API
28
+ RUN python3 -m pip install googlemaps
29
+
30
+ # OR-tools
31
+ RUN python3 -m pip install ortools
32
+
33
+ # other convenient packages
34
+ RUN python3 -m pip install torchmetrics
35
+ RUN python3 -m pip install scipy
36
+ RUN python3 -m pip install pandas
37
+ RUN python3 -m pip install matplotlib
38
+ RUN python3 -m pip install tqdm
LICENSE ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SOFTWARE LICENSE AGREEMENT FOR EVALUATION
2
+
3
+ This SOFTWARE LICENSE AGREEMENT FOR EVALUATION (this "Agreement") is a legal contract between a person
4
+ who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone corporation ("NTT").
5
+
6
+ READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR
7
+ OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS
8
+ AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER
9
+ THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE
10
+ SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER
11
+ UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY
12
+ TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD
13
+ TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR
14
+ USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE
15
+ ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE.
16
+
17
+
18
+ BACKGROUND
19
+ A. NTT is the owner of all rights, including all patent rights, copyrights and trade secret rights, in and to the Software and
20
+ related documentation except OSS listed in Exhibit A to this Agreement.
21
+
22
+ B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such
23
+ a license to User, pursuant and subject to the terms and conditions of this Agreement.
24
+
25
+ C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement.
26
+
27
+ In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows:
28
+
29
+ 1. Grant of Evaluation License. NTT hereby grants to User, and User hereby accepts, under the terms and
30
+ conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the
31
+ non-commercial purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in the research paper
32
+ submitted by NTT to a certain academy or technical contest, etc. ("academy"). User may make a reasonable number of
33
+ backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1.
34
+
35
+ 2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User
36
+ shall be solely responsible for proper installation of the Software.
37
+
38
+ 3. Term. This Agreement is effective whichever is earlier (i) upon User's acceptance of the Agreement, or (ii) upon User's
39
+ installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to
40
+ any other rights, NTT may terminate this Agreement without notice to User (i) if User breaches or fails to comply with any
41
+ of the limitations or other requirements described herein, and (ii) if NTT receives a notice from the academy stating that the
42
+ research paper would not be published, and in any such case User agrees that NTT may, in addition to any other remedies
43
+ it may have at law or in equity, remotely disable the Software. User may terminate this Agreement at any time by User's
44
+ decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this
45
+ Agreement for any reason, User agrees to uninstall the Software and either return to NTT the Software and all copies thereof,
46
+ or to destroy all such materials and provide written verification of such destruction to NTT.
47
+
48
+ 4. Proprietary Rights
49
+ (a) The Software is the valuable, confidential, and proprietary property of NTT, and NTT shall retain exclusive title to
50
+ this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges
51
+ that all patent rights, copyrights and trade secret rights in the Software except OSS shall remain the exclusive property of
52
+ NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software.
53
+
54
+ (b) NTT shall not be subject to the obligation of licensing the copyright, patent rights, etc. of author when user hope
55
+ commercial / noncommercial use of the published / provided software, etc.
56
+
57
+ (c) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT:
58
+ (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY;
59
+ (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER;
60
+ (iii) DISCLOSE THE SOFTWARE TO ANY THIRD PARTY, EXCEPT TO USER'S EMPLOYEES WHO REQUIRE ACCESS TO THE SOFTWARE FOR THE PURPOSES OF THIS AGREEMENT;
61
+ (iv) MODIFY, DISASSEMBLE, DECOMPILE, REVERSE ENGINEER OR TRANSLATE THE SOFTWARE;
62
+ OR (v) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (iv) ABOVE.
63
+
64
+ (d) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted
65
+ under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied.
66
+
67
+ 5. Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage,
68
+ or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE
69
+ RIGHT TO CONDUCT DEFEND ANY ACTION RELATING TO THE SOFTWARE.
70
+
71
+ 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT.
72
+ NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY,
73
+ OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES.
74
+ USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS,
75
+ AND UTILITY IN A PRODUCTION ENVIRONMENT.
76
+
77
+ 7. Limitation of Liability. IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL,
78
+ OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS,
79
+ ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES
80
+ PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE.
81
+ THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARDLESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT,
82
+ INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION
83
+ PURSUANT TO SECTION 3.
84
+
85
+ 8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned,
86
+ or otherwise transferred by User without NTT's prior written consent.
87
+
88
+ 9. OSS. The OSS included in the software is shown on the "OSS List" in Exhibit A.
89
+ User shall be subject to the license term of each OSS, when User uses the software.
90
+
91
+ 10. General
92
+ (d) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by
93
+ operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this
94
+ Agreement shall remain in full force and effect.
95
+
96
+ (e) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the
97
+ subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the
98
+ parties relating to that subject matter.
99
+
100
+ (f) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and
101
+ assigns of NTT and User.
102
+
103
+ (g) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this
104
+ Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not
105
+ as damages, its attorneys' fees and other costs associated with such action or proceeding.
106
+
107
+ (h) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles.
108
+ All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo
109
+ in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association.
110
+ The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final
111
+ and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof.
112
+
113
+ (f) NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT's obligation set
114
+ forth under this Agreement due to any cause beyond NTT's reasonable control.
115
+
116
+ EXHIBIT A
117
+ - Software
118
+ N/A
119
+
120
+ - OSS List
121
+ +----+---------+-----+
122
+ | No | License | OSS |
123
+ +----+---------+-----+
124
+ | 1 | N/A | N/A |
125
+ +----+---------+-----+
README.md ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RouteExplainer: An Explanation Framework for Vehicle Routing Problem
2
+ This repo is the official implementation of "RouteExplainer: An Explanation Framework for Vehicle Routing Problem" (PAKDD 2024). Please check more details at the project page https://ntt-dkiku.github.io/xai-vrp/.
3
+
4
+ ## Setup
5
+ We recommend using Docker to set up the development environment. Please use the [Dockerfile](./Dockerfile) in this repository.
6
+ ```
7
+ docker build -t route_explainer/route_explainer:1.0 .
8
+ ```
9
+ If you use LKH and Concorde, you need to install them by running the following command. LKH and Concorde are required for reproducing the experiments, but not for the demo.
10
+ ```
11
+ python install_solvers.py
12
+ ```
13
+ In the following, all commands are supposed to be typed inside the Docker container.
14
+
15
+ ## Reproducibility
16
+ <!-- Refer to [reproduce_experiments.ipynb](./reproduct_experiments.ipynb). -->
17
+ Coming Soon!
18
+
19
+ ## Training and evaluating edge classifiers
20
+ ### Generating synthetic data with labels
21
+ ```
22
+ python generate_dataset.py --problem tsptw --annotation --parallel
23
+ ```
24
+
25
+ ### Training
26
+ ```
27
+ python train.py
28
+ ```
29
+
30
+ ### Evaluation
31
+ ```
32
+ python eval_classifier.py
33
+ ```
34
+
35
+ ## Explanation generation (demo)
36
+ Go to http://localhost:8888 after launching the streamlit app with the following command. You may change the port number as you like.
37
+ ```
38
+ streamlit run app.py --server.port 8888
39
+ ```
40
+
41
+ ## Licence
42
+ Our code is licensed by NTT. Basically, the use of our code is limited to research purposes. See [LICENSE](./LICENSE) for more details.
43
+
44
+ ## Citation
45
+ Coming Soon!
analyze_dataset.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import matplotlib.cm as cm
4
+ from utils.util_data import load_dataset
5
+
6
+
7
def get_cmap(num_colors):
    """Return a qualitative matplotlib colormap able to hold ``num_colors`` colors.

    Args:
        num_colors: Number of distinct colors needed.

    Returns:
        The ``tab10`` colormap when ``num_colors <= 10``, the ``tab20``
        colormap when ``num_colors <= 20``.

    Raises:
        ValueError: If more than 20 colors are requested.
    """
    if num_colors <= 10:
        cm_name = "tab10"
    elif num_colors <= 20:
        cm_name = "tab20"
    else:
        # An explicit raise (instead of ``assert False``) still fires under
        # ``python -O``, where assertions are stripped.
        raise ValueError(
            f"num_colors must be at most 20, got {num_colors}"
        )
    # ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9;
    # ``pyplot.get_cmap`` is available in both old and new releases.
    return plt.get_cmap(cm_name)
15
+
16
def analyze_dataset(dataset_path, output_dir, label_names=None):
    """Plot the per-step label histogram and write the overall class ratios.

    Two files are written into ``output_dir``: ``hist.png`` (one histogram
    per label over the step index) and ``ratio.dat`` (share of each label
    over all labelled steps).

    Args:
        dataset_path: Path handed to ``load_dataset``.
        output_dir: Directory that receives the output files; it must
            already exist.
        label_names: Optional legend names, one per label. Defaults to
            ``label0``, ``label1``, ... (the naming used in ``ratio.dat``),
            so that each histogram gets its own legend entry.

    Raises:
        ValueError: If ``label_names`` does not contain one name per label.
    """
    dataset = load_dataset(dataset_path)

    #-----------------------------
    # Stepwise frequency analysis
    #-----------------------------
    max_steps = len(dataset[0][0])  # num_nodes
    num_labels = 2
    if label_names is None:
        label_names = [f"label{i}" for i in range(num_labels)]
    elif len(label_names) != num_labels:
        raise ValueError(
            f"expected {num_labels} label names, got {len(label_names)}"
        )

    # freq[label] collects every step index at which that label occurs.
    freq = [[] for _ in range(num_labels)]
    for instance in dataset:
        # The last field of an instance is iterated as (step, label) pairs.
        for step, label in instance[-1]:
            freq[label].append(step)

    # visualize histogram
    fig = plt.figure(figsize=(10, 10))
    binwidth = 1
    bins = np.arange(0, max_steps + binwidth, binwidth)
    cmap = get_cmap(num_labels)
    for i, steps in enumerate(freq):
        # Weighting each sample by 1/len(dataset) turns counts into a
        # per-instance frequency.
        weights = np.ones(len(steps)) / len(dataset)
        plt.hist(steps, bins=bins, alpha=0.5, weights=weights,
                 ec=cmap(i), color=cmap(i), label=label_names[i], align="left")
    plt.xlabel("Steps")
    plt.ylabel("Frequency (density)")
    if max_steps <= 20:
        plt.xticks(np.arange(0, max_steps+1, 1))
    plt.title(f"# of samples = {len(dataset)}\n# of nodes = {max_steps}")
    plt.legend()
    plt.savefig(f"{output_dir}/hist.png", dpi=150, bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    #-----------------------------
    # Overall ratio of each class
    #-----------------------------
    counts = np.array([len(steps) for steps in freq])
    ratio = counts / np.sum(counts)
    with open(f"{output_dir}/ratio.dat", "w") as f:
        for i, value in enumerate(ratio):
            print(f"label{i}, {value}", file=f)
55
+
56
if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(description='')
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default=None)
    args = parser.parse_args()

    # Without --output_dir, results go next to the dataset file.
    if args.output_dir is None:
        base_dir = os.path.split(args.dataset_path)[0]
    else:
        base_dir = args.output_dir
    # os.path.join keeps the result relative when base_dir is empty (dataset
    # given as a bare filename); plain "+ '/analysis'" would then point at
    # the filesystem root.
    output_dir = os.path.join(base_dir, "analysis")
    os.makedirs(output_dir, exist_ok=True)
    analyze_dataset(args.dataset_path, output_dir)
app.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # standard modules
2
+ import os
3
+ import pickle
4
+ import datetime
5
+ from PIL import Image
6
+ from typing import List, Union
7
+
8
+ # useful modules ("pip install" is required)
9
+ import numpy as np
10
+ import streamlit as st
11
+ import pandas as pd
12
+ import googlemaps
13
+ import langchain
14
+ from langchain.globals import set_verbose
15
+ from langchain.chat_models import ChatOpenAI
16
+ from langchain.schema import HumanMessage, AIMessage
17
+
18
+ # our defined modules
19
+ import utils.util_app as util_app
20
+ from models.solvers.general_solver import GeneralSolver
21
+ from models.cf_generator import CFTourGenerator
22
+ from models.classifiers.general_classifier import GeneralClassifier
23
+ from models.route_explainer import RouteExplainer
24
+
25
+ # general setting
26
+ SEED = 1234
27
+ TOUR_NAME = "static/kyoto_tour"
28
+ TOUR_PATH = TOUR_NAME + ".csv"
29
+ TOUR_LATLNG_PATH = TOUR_NAME + "_latlng.csv"
30
+ TOUR_DISTMAT_PATH = TOUR_NAME + "_distmat.pkl"
31
+ EXPANDED = False
32
+ DEBUG = True
33
+ ROUTE_EXPLAINER_ICON = np.array(Image.open("static/route_explainer_icon.png"))
34
+
35
+ # for debug
36
+ if DEBUG:
37
+ langchain.debug = True
38
+ set_verbose(True)
39
+
40
def _clock_to_seconds(clock: str) -> int:
    """Convert an ``"H:M"`` clock string to seconds; fields after minutes are ignored."""
    return sum(unit * value for unit, value in zip((3600, 60), map(int, clock.split(":"))))


def _require_googlemaps_client(missing_path: str):
    """Build a Google Maps client from the sidebar key, or fail with a clear message."""
    googleapi_key = st.session_state.googleapi_key
    if not googleapi_key:
        raise RuntimeError(
            f"'{missing_path}' is not cached and no Google Maps API key was "
            "provided, so it cannot be generated."
        )
    return googlemaps.Client(key=googleapi_key)


def _load_tour_with_latlng() -> pd.DataFrame:
    """Return the tour table with ``lat``/``lng`` columns, geocoding and caching if needed."""
    if os.path.isfile(TOUR_LATLNG_PATH):
        return pd.read_csv(TOUR_LATLNG_PATH)

    df_tour = pd.read_csv(TOUR_PATH)
    gmaps = _require_googlemaps_client(TOUR_LATLNG_PATH)
    lat_list = []
    lng_list = []
    for destination in df_tour["destination"]:
        location = gmaps.geocode(destination)[0]["geometry"]["location"]
        lat_list.append(location["lat"])
        lng_list.append(location["lng"])
    # add lat/lng
    df_tour["lat"] = lat_list
    df_tour["lng"] = lng_list
    df_tour.to_csv(TOUR_LATLNG_PATH)
    return df_tour


def _load_distance_matrix(df_tour: pd.DataFrame):
    """Return the pairwise driving-duration matrix, querying and caching if needed."""
    if os.path.isfile(TOUR_DISTMAT_PATH):
        # NOTE(security): pickle.load can execute arbitrary code; this is only
        # acceptable because the cache file is produced by this app itself.
        with open(TOUR_DISTMAT_PATH, "rb") as f:
            return pickle.load(f)

    gmaps = _require_googlemaps_client(TOUR_DISTMAT_PATH)
    distmat = []
    for origin in df_tour["destination"]:
        distrow = []
        for dest in df_tour["destination"]:
            if origin != dest:
                dist_result = gmaps.distance_matrix(origin, dest, mode="driving")
                distrow.append(dist_result["rows"][0]["elements"][0]["duration"]["value"])  # unit: seconds
            else:
                distrow.append(0)
        distmat.append(distrow)
    distmat = np.array(distmat)
    with open(TOUR_DISTMAT_PATH, "wb") as f:
        pickle.dump(distmat, f)
    return distmat


def load_tour_list():
    """Load the tour data and publish it in ``st.session_state``.

    Sets ``lat_mean``, ``lng_mean``, ``sw``, ``ne``, ``df_tour``,
    ``node_feats``, ``dist_matrix``, ``node_info`` and ``tour_list``.
    Lat/lng values and the distance matrix are read from their cache files
    when present; otherwise they are fetched through the Google Maps API and
    cached.

    Raises:
        RuntimeError: If a cache file is missing and no Google Maps API key
            has been entered (previously an opaque NameError/KeyError).
    """
    # get lat/lng
    df_tour = _load_tour_with_latlng()

    # get the central point and the bounding box
    st.session_state.lat_mean = np.mean(df_tour["lat"])
    st.session_state.lng_mean = np.mean(df_tour["lng"])
    st.session_state.sw = df_tour[["lat", "lng"]].min().tolist()
    st.session_state.ne = df_tour[["lat", "lng"]].max().tolist()
    st.session_state.df_tour = df_tour

    # get the distance matrix
    distmat = _load_distance_matrix(df_tour)

    # input features
    time_windows = np.array([
        [_clock_to_seconds(open_clock), _clock_to_seconds(close_clock)]
        for open_clock, close_clock in zip(df_tour["open"], df_tour["close"])
    ])
    # Express every time window relative to the opening time of the first row.
    time_windows -= time_windows[0, 0]
    st.session_state.node_feats = {
        "time_window": time_windows.clip(0),
        "service_time": df_tour["stay_duration (h)"].to_numpy() * 3600
    }
    st.session_state.dist_matrix = distmat
    st.session_state.node_info = {
        "open": df_tour["open"],
        "close": df_tour["close"],
        "stay": df_tour["stay_duration (h)"]
    }

    # tour list (logical ``and``: both operands are plain booleans)
    if os.path.isfile(TOUR_DISTMAT_PATH) and os.path.isfile(TOUR_LATLNG_PATH):
        st.session_state.tour_list = [
            {
                "name": df_tour["destination"][i],
                "latlng": (df_tour["lat"][i], df_tour["lng"][i]),
                "description": f"<font color='silver'>Hours: {df_tour['open'][i]} - {df_tour['close'][i]}<br>Duration of stay: {df_tour['stay_duration (h)'][i]}h<br>Remarks: {df_tour['remarks'][i]}</font>"
            }
            for i in range(len(df_tour))
        ]
116
+
117
def solve_vrp() -> None:
    """Solve the loaded tour as a TSPTW and store the routes and edge labels.

    Does nothing unless both ``node_feats`` and ``dist_matrix`` are present
    in ``st.session_state``. On success it writes ``routes``, ``labels`` and
    sets ``generated_actual_route`` to True.
    """
    state = st.session_state
    if "node_feats" not in state or "dist_matrix" not in state:
        return

    solver = GeneralSolver("tsptw", "ortools", scaling=False)
    classifier = GeneralClassifier("tsptw", "gt(ortools)")

    routes = solver.solve(node_feats=state.node_feats,
                          dist_matrix=state.dist_matrix)
    # NOTE(review): the positional 0 is passed through to get_inputs as-is;
    # its meaning is defined by GeneralClassifier -- confirm there.
    inputs = classifier.get_inputs(routes, 0, state.node_feats, state.dist_matrix)
    labels = classifier(inputs)

    # Store copies so later changes to the local results do not leak into the state.
    state.routes = routes.copy()
    state.labels = labels.copy()
    state.generated_actual_route = True
131
+
132
+ #----------
133
+ # LLM
134
+ #----------
135
def load_route_explainer(llm_type: str) -> None:
    """Build a ``RouteExplainer`` for ``llm_type`` and store it in the session state.

    Nothing is created while the OpenAI key field in the session state is empty.

    Args:
        llm_type: Model name forwarded to ``ChatOpenAI``.
    """
    if not st.session_state.openai_key:
        return

    # define llm (temperature 0 and a fixed seed are passed to the model)
    llm = ChatOpenAI(model=llm_type,
                     temperature=0,
                     streaming=True,
                     model_kwargs={"seed": SEED})

    # define RouteExplainer
    cf_solver = GeneralSolver("tsptw", "ortools", scaling=False)
    cf_generator = CFTourGenerator(cf_solver=cf_solver)
    classifier = GeneralClassifier("tsptw", "gt(ortools)")
    explainer = RouteExplainer(llm=llm,
                               cf_generator=cf_generator,
                               classifier=classifier)
    st.session_state.route_explainer = explainer
150
+
151
+ #----------
152
+ # UI
153
+ #----------
154
+ # css settings
155
+ st.set_page_config(layout="wide")
156
+ util_app.apply_responsible_map_css()
157
+ util_app.apply_centerize_icon_css()
158
+ util_app.apply_red_code_css()
159
+ util_app.apply_remove_sidebar_topspace()
160
+
161
+ #------------------
162
+ # side bar setting
163
+ #------------------
164
+ with st.sidebar:
165
+ #-------
166
+ # Title
167
+ #-------
168
+ icon_col, name_col = st.columns((1,10))
169
+ with icon_col:
170
+ util_app.apply_html('<a href="https://ntt-dkiku.github.io/xai-vrp" target="_blank"><img src="./app/static/route_explainer_icon.png" alt="RouteExplainer" width="30" height="30" style="margin-top: 20px;"></a>')
171
+ with name_col:
172
+ st.title("RouteExplainer")
173
+
174
+ #----------
175
+ # API keys
176
+ #----------
177
+ st.subheader("API keys")
178
+ openai_key_col1, openai_key_col2 = st.columns((1,10))
179
+ with openai_key_col1:
180
+ util_app.apply_html('<a href="https://openai.com/blog/openai-api" target="_blank"> <img src="./app/static/openai_logo.png" alt="OpenAI API" width="30" height="30"> </a>')
181
+ with openai_key_col2:
182
+ openai_key = st.text_input(label="API keys",
183
+ key="openai_key",
184
+ placeholder="OpenAI API key",
185
+ type="password",
186
+ label_visibility="collapsed")
187
# The key counts as changed when the entered value differs from the one
# currently exported; compare before overwriting the environment variable.
# (The original used ``==``, which flagged a change exactly when nothing changed.)
changed_key = openai_key != os.environ.get('OPENAI_API_KEY')
os.environ['OPENAI_API_KEY'] = openai_key
189
+
190
+ google_key_col1, google_key_col2 = st.columns((1, 10))
191
+ with google_key_col1:
192
+ util_app.apply_html('<a href="https://developers.google.com/maps?hl=en" target="_blank"> <img src="./app/static/googlemap_logo.png" alt="GoogleMap API" width="30" height="30"> </a>')
193
+ with google_key_col2:
194
+ st.text_input(label="GoogleMap API key",
195
+ key="googleapi_key",
196
+ placeholder="NOT required in this demo",
197
+ type="password",
198
+ label_visibility="collapsed")
199
+
200
+ #----------------
201
+ # Foundation LLM
202
+ #----------------
203
+ st.subheader("Foundation LLM")
204
+ llm_type = st.selectbox("LLM", ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo"], key="llm_type", label_visibility="collapsed")
205
+
206
+ #-----------
207
+ # Tour plan
208
+ #-----------
209
+ st.subheader("Tour plan")
210
+ col1, col2 = st.columns((2, 1))
211
+ with col1:
212
+ # Comming soon: "Taipei Tour (for PAKDD2024)"
213
+ tour_plan = st.selectbox("Tour plan", ["Kyoto Tour"], key="tour_type", label_visibility="collapsed")
214
+ with col2:
215
+ st.button("Generate", on_click=solve_vrp, use_container_width=True)
216
+
217
+ # list destinations
218
+ load_tour_list()
219
+ with st.container():
220
+ if "routes" in st.session_state: # rearranage destinations in the route order if a route was derivied
221
+ # re-ordered destinations
222
+ reordered_tour_list = [st.session_state.tour_list[i] for i in st.session_state.routes[0][:-1]] if "routes" in st.session_state else st.session_state.tour_list
223
+ arr_time = datetime.datetime.strptime(st.session_state.node_info["open"][0], "%H:%M")
224
+ for step in range(len(reordered_tour_list)):
225
+ curr = reordered_tour_list[step]
226
+ next = reordered_tour_list[step+1] if step != len(reordered_tour_list) - 1 else reordered_tour_list[0]
227
+ curr_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, curr["name"])
228
+ next_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, next["name"])
229
+ open_time = datetime.datetime.strptime(st.session_state.node_info["open"][curr_node_id], "%H:%M")
230
+ # destination info
231
+ dep_time = max(arr_time, open_time) + datetime.timedelta(hours=st.session_state.node_info["stay"][curr_node_id])
232
+ dep_time_str = dep_time.strftime("%H:%M")
233
+ arr_time_str = arr_time.strftime("%H:%M")
234
+ arr_dep = f"Arr {arr_time_str} - Dep {dep_time_str}" if step != 0 else f"⭐ Dep {dep_time_str}"
235
+ with st.expander(f"{arr_dep} | {curr['name']}", expanded=EXPANDED):
236
+ st.write(curr["description"], unsafe_allow_html=True)
237
+ # travel time
238
+ travel_time = st.session_state.dist_matrix[curr_node_id][next_node_id].item()
239
+ col1, col2, col3 = st.columns(3)
240
+ with col1:
241
+ st.markdown(f"<center>{util_app.add_time_unit(travel_time)}</center>", unsafe_allow_html=True)
242
+ with col2:
243
+ st.markdown("<center>|</center>", unsafe_allow_html=True)
244
+ st.write("")
245
+ arr_time = dep_time + datetime.timedelta(seconds=travel_time)
246
+ # return to the origin
247
+ destination = reordered_tour_list[0]
248
+ arr_time_str = arr_time.strftime("%H:%M")
249
+ with st.expander(f"⭐ Arr {arr_time_str} | {destination['name']}", expanded=EXPANDED):
250
+ st.write(destination["description"], unsafe_allow_html=True)
251
+ else: # just list destinations
252
+ for destination in st.session_state.tour_list:
253
+ with st.expander(destination['name'], expanded=EXPANDED):
254
+ st.write(destination["description"], unsafe_allow_html=True)
255
+
256
+ #----------------------
257
+ # state initialization
258
+ #----------------------
259
+ if "count" not in st.session_state:
260
+ st.session_state.count = 0
261
+ if "chat_history" not in st.session_state:
262
+ st.session_state.chat_history = []
263
+ if "generated_actual_route" not in st.session_state:
264
+ st.session_state.generated_actual_route = False
265
+ if "generated_cf_route" not in st.session_state:
266
+ st.session_state.generated_cf_route = False
267
+ if "curr_route" not in st.session_state:
268
+ st.session_state.curr_route = "Actual Route" # once the CF route is selected, this will be "Current Route"
269
+ if "flag_example" not in st.session_state:
270
+ st.session_state.flag_example = False
271
+ if "selected_example" not in st.session_state:
272
+ st.session_state.selected_example = None
273
+ if "close_chat" not in st.session_state:
274
+ st.session_state.close_chat = False
275
+ if "route_explainer" not in st.session_state or llm_type != st.session_state.curr_llm_type or changed_key:
276
+ load_route_explainer(llm_type)
277
+ st.session_state.curr_llm_type = llm_type
278
+
279
+ #--------------------------------
280
+ # The following is the main page
281
+ #--------------------------------
282
+
283
+ #----------
284
+ # Greeding
285
+ #----------
286
+ if "routes" not in st.session_state:
287
+ util_app.apply_html('<center> <img src="./app/static/route_explainer_icon.png" alt="OpenAI API" width="120" height="120"> </center>')
288
+ greeding = "Hi, I'm RouteExplainer :)<br>Choose a tour and hit the <code>Generate</code> button to generate your initial route!"
289
+ if st.session_state.count == 0:
290
+ util_app.stream_words(greeding, prefix="<center><h4>", suffix="</h4></center>", sleep_time=0.02)
291
+ else:
292
+ util_app.apply_html(f"<center><h4>{greeding}</h4></center>")
293
+
294
+ #--------------
295
+ # chat history
296
+ #--------------
297
def find_last_map(lst: List[Union[str, tuple]]) -> Union[int, None]:
    """Return the index of the last tuple entry in ``lst``.

    Tuple entries are the map records stored in the chat history; every
    other entry is ignored.

    Returns:
        The index of the right-most tuple, or ``None`` when ``lst`` holds no
        tuple at all (including when it is empty).
    """
    # Scan from the end so the first hit is the most recent map record.
    for idx in reversed(range(len(lst))):
        if isinstance(lst[idx], tuple):
            return idx
    return None
302
+ last_map_idx = find_last_map(st.session_state.chat_history)
303
+ for i, msg in enumerate(st.session_state.chat_history):
304
+ if isinstance(msg, tuple): # if the history type is a tuple of maps
305
+ map1, map2 = (0, 1) if i == last_map_idx else (2, 3)
306
+ actual_route, cf_route = st.columns(2)
307
+ if msg[map1] is not None:
308
+ with actual_route:
309
+ util_app.visualize_actual_route(msg[map1])
310
+ if msg[map2] is not None:
311
+ with cf_route:
312
+ util_app.visualize_cf_route(msg[map2])
313
+ else: # if the history type is string
314
+ if isinstance(msg, AIMessage):
315
+ st.chat_message(msg.type, avatar=ROUTE_EXPLAINER_ICON).write(msg.content)
316
+ else:
317
+ st.chat_message(msg.type).write(msg.content)
318
+
319
+ # examples
320
if "cf_routes" not in st.session_state and st.session_state.flag_example:
    def pickup_example(example: str):
        """Button callback: remember which example question was clicked."""
        st.session_state.selected_example = example

    examples = [
        "Why do we visit Ginkaku-ji Temple from Fushimi-Inari Shrine and why not Kiyomizu-dera Temple?",
        "What if we visit Kinkaku-ji directly from Kyoto Geishinkan, instead of Nijo-jo Castle?",
        "Why was the edge from Kinkaku-ji to Kiyomizu-dera selected and why not the edge from Kinkaku-ji to Hanamikoji Dori?"
    ]
    # One full-width button per column; each passes its own question to the callback.
    for column, question in zip(st.columns(3), examples):
        with column:
            st.button(question,
                      use_container_width=True,
                      on_click=pickup_example,
                      args=(question, ))
345
+
346
+ #----------
347
+ # chat box
348
+ #----------
349
def answer(prompt: str):
    """Record a user question and render/record the assistant's reply.

    The question is appended to ``st.session_state.chat_history`` and echoed
    as a user chat message. The reply is one of:

    * an error message when no OpenAI API key is available,
    * an error message when ``util_app.validate_openai_api_key`` rejects the key,
    * otherwise the explanation returned by
      ``st.session_state.route_explainer.generate_explanation``.

    Error messages are always appended to the chat history; an explanation is
    appended only when it is non-empty. When ``generated_cf_route`` is set
    after generating the explanation, the script is rerun via ``st.rerun()``.

    Args:
        prompt: The why-not question typed (or picked from the examples) by the user.
    """
    st.session_state.chat_history.append(HumanMessage(content=prompt))
    st.chat_message("user").write(prompt)

    api_key = os.environ.get('OPENAI_API_KEY')
    error_msg = None
    if not api_key:
        # ``not api_key`` covers both an empty value and a variable that was
        # never set (``None``); the latter used to be sent to the validator.
        error_msg = "An OpenAI API key has not been set yet :( Please enter a valid key in the side bar!"
    elif not util_app.validate_openai_api_key(api_key):
        error_msg = "The input OpenAI API key appears to be invalid :( Please enter a valid key again in the side bar!"

    if error_msg is not None:
        with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
            st.write(error_msg)
        st.session_state.chat_history.append(AIMessage(content=error_msg))
        return

    with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
        explanation = st.session_state.route_explainer.generate_explanation(tour_list=st.session_state.tour_list,
                                                                            whynot_question=prompt,
                                                                            actual_routes=st.session_state.routes,
                                                                            actual_labels=st.session_state.labels,
                                                                            node_feats=st.session_state.node_feats,
                                                                            dist_matrix=st.session_state.dist_matrix)
        if len(explanation) > 0:
            st.session_state.chat_history.append(AIMessage(content=explanation))
        if st.session_state.generated_cf_route:
            st.rerun()
374
+
375
+ if st.session_state.selected_example is not None:
376
+ example = st.session_state.selected_example
377
+ st.session_state.selected_example = None
378
+ answer(example)
379
+ else:
380
+ if "routes" in st.session_state and not st.session_state.close_chat:
381
+ if prompt := st.chat_input(placeholder="Ask a why-not question", key="chat_input"):
382
+ answer(prompt)
383
+
384
+ #---------------------
385
+ # route visualization
386
+ #---------------------
387
+ if "tour_list" in st.session_state: # if tour info is loaded
388
+ # first message
389
+ if st.session_state.generated_actual_route: # when an actual route is generated
390
+ with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
391
+ msg = "Here is your initial route. Please ask me a why and why-not question for a specfic edge!"
392
+ util_app.stream_words(msg, sleep_time=0.01)
393
+ st.session_state.flag_example = True
394
+ st.session_state.chat_history.append(AIMessage(content=msg))
395
+
396
+ # visualize the actual & CF routes
397
+ actual_route, cf_route = st.columns(2)
398
+ m = None; m2 = None; m_ = None; m2_ = None
399
+ if st.session_state.generated_actual_route or st.session_state.generated_cf_route:
400
+ m = util_app.initialize_map() # overwrite m
401
+ m_ = util_app.initialize_map()
402
+ if "labels" in st.session_state:
403
+ cf_step = st.session_state.cf_step-1 if st.session_state.generated_cf_route else -1
404
+ util_app.vis_route("routes", st.session_state.labels, m, cf_step, "actual")
405
+ util_app.vis_route("routes", st.session_state.labels, m_, cf_step, "actual", ant_path=False)
406
+ with actual_route:
407
+ util_app.visualize_actual_route(m)
408
+ if st.session_state.generated_cf_route:
409
+ m2 = util_app.initialize_map() # overwrite m2
410
+ m2_ = util_app.initialize_map()
411
+ if "cf_labels" in st.session_state:
412
+ util_app.vis_route("cf_routes", st.session_state.cf_labels, m2, st.session_state.cf_step-1, "cf")
413
+ util_app.vis_route("cf_routes", st.session_state.cf_labels, m2_, st.session_state.cf_step-1, "cf", ant_path=False)
414
+ with cf_route:
415
+ util_app.visualize_cf_route(m2)
416
+
417
+ # update states related to maps
418
+ if m is not None:
419
+ st.session_state.chat_history.append((m, m2, m_, m2_))
420
+
421
+ # route selection button
422
+ if len(st.session_state.chat_history) > 0:
423
+ last_msg = st.session_state.chat_history[-1]
424
+ if isinstance(last_msg, tuple):
425
+ if (last_msg[0] is not None) and (last_msg[1] is not None):
426
+ col1, col2 = st.columns(2)
427
+ with col1:
428
+ st.button("Stay this route", on_click=util_app.select_actual_route)
429
+ with col2:
430
+ st.button("Replace with this route", on_click=util_app.select_cf_route)
431
+ util_app.change_hover_color("button", "Replace with this route", "#1e90ff")
432
+
433
+ # for displaying examples
434
+ if st.session_state.generated_actual_route:
435
+ st.session_state.generated_actual_route = False
436
+ st.rerun()
437
+
438
+ st.session_state.generated_actual_route = False
439
+ st.session_state.generated_cf_route = False
440
+
441
+ # update session count
442
+ st.session_state.count += 1
443
+
444
+ js = f"""
445
+ <script>
446
+ function scroll(dummy_var_to_force_repeat_execution){{
447
+ var textAreas = parent.document.querySelectorAll('section.main');
448
+ for (let index = 0; index < textAreas.length; index++) {{
449
+ textAreas[index].scrollTop = textAreas[index].scrollHeight;
450
+ }}
451
+ }}
452
+ scroll({len(st.session_state.chat_history)})
453
+ </script>
454
+ """
455
+ st.components.v1.html(js, height=0, width=0)
eval_classifier.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import json
4
+ import multiprocessing
5
+ import torch
6
+ import time
7
+ from tqdm import tqdm
8
+ from torch.utils.data import DataLoader
9
+ from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
10
+ from utils.util_calc import TemporalConfusionMatrix
11
+ from models.classifiers.nn_classifiers.nn_classifier import NNClassifier
12
+ from models.classifiers.ground_truth.ground_truth import GroundTruth
13
+ from models.classifiers.ground_truth.ground_truth_base import FAIL_FLAG
14
+ from utils.data_utils.tsptw_dataset import TSPTWDataloader
15
+ from utils.data_utils.pctsp_dataset import PCTSPDataloader
16
+ from utils.data_utils.pctsptw_dataset import PCTSPTWDataloader
17
+ from utils.data_utils.cvrp_dataset import CVRPDataloader
18
+ from utils.utils import set_device
19
+ from utils.utils import load_dataset
20
+
21
def pad_seq_length(batch):
    """Collate per-instance dicts into one post-padded batch dict.

    Every key of the first sample is padded along dim 0 to the longest
    sequence in the batch (``batch_first=True``). The ``"mask"`` entry is
    padded with ``True``, every other entry with ``0.0``. An extra
    ``"pad_mask"`` entry is added that is ``True`` on real steps and
    ``False`` on padded ones.

    Defined at module level (not as a closure) so that it can be pickled
    when the DataLoader spawns worker processes.
    """
    data = {}
    for key in batch[0].keys():
        padding_value = True if key == "mask" else 0.0
        # post-padding
        data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch],
                                                    batch_first=True,
                                                    padding_value=padding_value)
    data["pad_mask"] = torch.nn.utils.rnn.pad_sequence(
        [torch.full((d["mask"].size(0), ), True) for d in batch],
        batch_first=True,
        padding_value=False)
    return data


def load_eval_dataset(dataset_path, problem, model_type, batch_size, num_workers, parallel, num_cpus):
    """Load the evaluation data in the form expected by ``model_type``.

    Args:
        dataset_path: Path to the dataset file.
        problem: One of "tsptw", "pctsp", "pctsptw", "cvrp" (only checked for "nn").
        model_type: "nn" builds a padded, unshuffled DataLoader; any other value
            returns the output of ``load_dataset(dataset_path)``.
        batch_size: DataLoader batch size ("nn" only).
        num_workers: DataLoader worker count ("nn" only).
        parallel: Forwarded to the problem-specific dataloader class ("nn" only).
        num_cpus: Forwarded to the problem-specific dataloader class ("nn" only).

    Raises:
        NotImplementedError: If ``model_type`` is "nn" and ``problem`` is unknown.
    """
    if model_type != "nn":
        return load_dataset(dataset_path)

    if problem == "tsptw":
        dataset_cls = TSPTWDataloader
    elif problem == "pctsp":
        dataset_cls = PCTSPDataloader
    elif problem == "pctsptw":
        dataset_cls = PCTSPTWDataloader
    elif problem == "cvrp":
        dataset_cls = CVRPDataloader
    else:
        raise NotImplementedError
    eval_dataset = dataset_cls(dataset_path, sequential=True, parallel=parallel, num_cpus=num_cpus)

    #------------
    # dataloader
    #------------
    return DataLoader(eval_dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      collate_fn=pad_seq_length,
                      num_workers=num_workers)
55
+
56
def eval_classifier(problem: str,
                    dataset,
                    model_type: str,
                    model_dir: str = None,
                    gpu: int = -1,
                    num_workers: int = 4,
                    batch_size: int = 128,
                    parallel: bool = True,
                    solver: str = "ortools",
                    num_cpus: int = 1):
    """Evaluate an edge classifier on ``dataset`` with macro-averaged F1.

    Note that every "accuracy" value below is a ``MulticlassF1Score`` with
    ``average="macro"``.

    Args:
        problem: Problem name; "pctsptw" uses three classes, anything else two.
        dataset: For ``model_type="nn"`` an iterable of batch dicts holding at
            least "labels" and "pad_mask" (plus "curr_node_id" for
            non-sequential models). For "ground_truth" a sequence of instance
            dicts with "tour" and "labels".
        model_type: "nn" or "ground_truth".
        model_dir: Directory with ``cmd_args.dat``, ``best_epoch.dat`` and
            ``model_epoch{N}.pth`` (required for "nn").
        gpu: Passed to ``set_device``.
        num_workers, batch_size, parallel: Not used inside this function.
        solver: Solver name handed to ``GroundTruth``.
        num_cpus: Process-pool size for the ground-truth model.

    Returns:
        For "nn": ``(overall_accuracy, total_eval_accuracy, temporal_accuracy,
        calc_time, temporal_confmat, num_nodes_dist_dict)`` where the dicts are
        keyed by sequence length. For "ground_truth": ``(macro_f1, calc_time)``.

    Raises:
        ValueError: If ``model_type`` is neither "nn" nor "ground_truth".
    """
    #--------------
    # gpu settings
    #--------------
    use_cuda, device = set_device(gpu)

    #-------
    # model
    #-------
    num_classes = 3 if problem == "pctsptw" else 2
    if model_type == "nn":
        assert model_dir is not None, "please specify model_path when model_type is nn."
        with open(f"{model_dir}/cmd_args.dat", "r") as f:
            params = argparse.Namespace(**json.load(f))
        assert params.problem == problem, "problem of the trained model should match that of the dataset"
        model = NNClassifier(problem=params.problem,
                             node_enc_type=params.node_enc_type,
                             edge_enc_type=params.edge_enc_type,
                             dec_type=params.dec_type,
                             emb_dim=params.emb_dim,
                             num_enc_mlp_layers=params.num_enc_mlp_layers,
                             num_dec_mlp_layers=params.num_dec_mlp_layers,
                             num_classes=num_classes,
                             dropout=params.dropout,
                             pos_encoder=params.pos_encoder)
        # load trained weights (the best epoch)
        with open(f"{model_dir}/best_epoch.dat", "r") as f:
            best_epoch = int(f.read())
        print(f"loaded {model_dir}/model_epoch{best_epoch}.pth.")
        # map_location lets a checkpoint saved on a GPU be evaluated on CPU (and vice versa)
        model.load_state_dict(torch.load(f"{model_dir}/model_epoch{best_epoch}.pth", map_location=device))
        if use_cuda:
            model.to(device)
        is_sequential = model.is_sequential
    elif model_type == "ground_truth":
        model = GroundTruth(problem=problem, solver_type=solver)
        is_sequential = False
    else:
        # an explicit raise survives ``python -O`` (asserts are stripped there)
        raise ValueError(f"Invalid model type: {model_type}")

    #---------
    # Metrics
    #---------
    overall_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
    eval_accuracy_dict = {}      # seq_length -> MulticlassF1Score
    temp_confmat_dict = {}       # seq_length -> TemporalConfusionMatrix
    temporal_accuracy_dict = {}  # seq_length -> [MulticlassF1Score per step]
    num_nodes_dist_dict = {}     # seq_length -> number of samples

    #------------
    # Evaluation
    #------------
    if model_type == "nn":
        model.eval()
        eval_time = 0.0
        print("Evaluating models ...", end="")
        start_time = time.perf_counter()
        with torch.no_grad():  # inference only: no autograd graph needed
            for data in dataset:
                if use_cuda:
                    data = {key: value.to(device) for key, value in data.items()}
                if not is_sequential:
                    # fold the sequence axis into the batch axis for step-wise models
                    shp = data["curr_node_id"].size()
                    data = {key: value.flatten(0, 1) for key, value in data.items()}
                probs = model(data) # [batch_size x num_classes] or [batch_size x max_seq_length x num_classes]
                if not is_sequential:
                    probs = probs.view(*shp, -1) # [batch_size x max_seq_length x num_classes]
                    data["labels"] = data["labels"].view(*shp)
                    data["pad_mask"] = data["pad_mask"].view(*shp)
                #------------
                # evaluation
                #------------
                start_eval_time = time.perf_counter()
                preds = probs.argmax(-1)
                # overall score: updated once per batch so that every sample is
                # counted exactly once, however many sequence lengths the batch mixes
                mask = data["pad_mask"].view(-1)
                overall_accuracy(preds.view(-1)[mask], data["labels"].view(-1)[mask])
                # per-sequence-length scores
                seq_lengths = data["pad_mask"].sum(-1) # [batch_size]
                for seq_length in torch.unique(seq_lengths).tolist():
                    if seq_length not in eval_accuracy_dict:
                        eval_accuracy_dict[seq_length] = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
                        temp_confmat_dict[seq_length] = TemporalConfusionMatrix(num_classes=num_classes, seq_length=seq_length, device=device)
                        temporal_accuracy_dict[seq_length] = [MulticlassF1Score(num_classes=num_classes, average="macro").to(device) for _ in range(seq_length)]
                        num_nodes_dist_dict[seq_length] = 0
                    seq_length_mask = (seq_lengths == seq_length) # [batch_size]
                    extracted_labels = data["labels"][seq_length_mask]
                    extracted_probs = probs[seq_length_mask]
                    extracted_mask = data["pad_mask"][seq_length_mask].view(-1) # [batch_size x max_seq_length] -> [(batch_size*max_seq_length)]
                    eval_accuracy_dict[seq_length](extracted_probs.argmax(-1).view(-1)[extracted_mask], extracted_labels.view(-1)[extracted_mask])
                    # confusion matrix (receives the whole batch, as before)
                    temp_confmat_dict[seq_length].update(preds, data["labels"], data["pad_mask"])
                    # temporal accuracy
                    for step in range(seq_length):
                        temporal_accuracy_dict[seq_length][step](extracted_probs[:, step, :], extracted_labels[:, step])
                    # number of samples whose sequence length is seq_length
                    num_nodes_dist_dict[seq_length] += len(extracted_labels)
                eval_time += time.perf_counter() - start_eval_time
        # time spent on loading + inference, excluding metric bookkeeping
        calc_time = time.perf_counter() - start_time - eval_time
        total_eval_accuracy = {key: value.compute().item() for key, value in eval_accuracy_dict.items()}
        overall_accuracy = overall_accuracy.compute()
        temporal_confmat = {key: value.compute() for key, value in temp_confmat_dict.items()}
        temporal_accuracy = {key: [value.compute().item() for value in values] for key, values in temporal_accuracy_dict.items()}
        print("done")
        return overall_accuracy, total_eval_accuracy, temporal_accuracy, calc_time, temporal_confmat, num_nodes_dist_dict
    else:
        eval_accuracy = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
        print("Loading data ...", end=" ")
        with multiprocessing.Pool(num_cpus) as pool:
            input_list = list(pool.starmap(model.get_inputs, [(instance["tour"], 0, instance) for instance in dataset]))
        print("done")

        print("Infering labels ...", end="")
        # the context manager releases the workers even if inference raises
        with multiprocessing.Pool(num_cpus) as pool:
            start_time = time.perf_counter()
            prob_list = list(pool.starmap(model, tqdm([(inputs, False, False) for inputs in input_list])))
            calc_time = time.perf_counter() - start_time
        print("done")

        print("Evaluating models ...", end="")
        for i, instance in enumerate(dataset):
            labels = instance["labels"]
            for vehicle_id in range(len(labels)):
                for step, label in labels[vehicle_id]:
                    pred_label = prob_list[i][vehicle_id][step-1]
                    if pred_label == FAIL_FLAG:
                        # a failed prediction is scored as a wrong label
                        pred_label = label - 1 if label != 0 else label + 1
                    # keep the inputs on the same device as the metric
                    eval_accuracy(torch.tensor([[pred_label]], dtype=torch.long, device=device),
                                  torch.tensor([[label]], dtype=torch.long, device=device))
        total_eval_accuracy = eval_accuracy.compute()
        print("done")
        return total_eval_accuracy.item(), calc_time
196
+
197
if __name__ == "__main__":
    # Command-line entry point: the problem name is taken from the name of the
    # directory that contains the dataset file.
    parser = argparse.ArgumentParser()
    #-----------------
    # general settings
    #-----------------
    parser.add_argument("--gpu", default=-1, type=int, help="Used GPU Number: gpu=-1 indicates using cpu")
    parser.add_argument("--num_workers", default=4, type=int, help="Number of workers in dataloader")
    # NOTE: "--parallel" must be registered exactly once (see the model settings
    # below); registering the same option twice makes argparse raise ArgumentError.

    #-------------
    # data setting
    #-------------
    parser.add_argument("--dataset_path", type=str, help="Path to a dataset", required=True)

    #----------------
    # model settings
    #----------------
    parser.add_argument("--model_type", type=str, default="nn", help="Select from [nn, ground_truth]")
    # nn classifier
    parser.add_argument("--model_dir", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--parallel", action="store_true")
    # ground truth
    parser.add_argument("--solver", type=str, default="ortools")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    args = parser.parse_args()

    problem = str(os.path.basename(os.path.dirname(args.dataset_path)))

    dataset = load_eval_dataset(args.dataset_path, problem, args.model_type, args.batch_size, args.num_workers, args.parallel, args.num_cpus)
    eval_classifier(problem=problem,
                    dataset=dataset,
                    model_type=args.model_type,
                    model_dir=args.model_dir,
                    gpu=args.gpu,
                    num_workers=args.num_workers,
                    batch_size=args.batch_size,
                    parallel=args.parallel,
                    solver=args.solver,
                    num_cpus=args.num_cpus)
eval_solvers.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+ import multiprocessing
4
+ import numpy as np
5
+ from utils.utils import load_dataset, calc_tour_length
6
+ from models.solvers.general_solver import GeneralSolver
7
+ from models.classifiers.ground_truth.ground_truth import GroundTruth
8
+
9
def eval_solver(solver, instance):
    """Solve ``instance`` with ``solver`` and return the length of its first tour."""
    routes = solver.solve(instance)
    return calc_tour_length(routes[0], instance["coords"])
13
+
14
def eval(data_path, problem, solver_name, fix_edges, parallel):
    """Solve every instance of a dataset and summarise the tour lengths.

    Args:
        data_path: Path handed to ``load_dataset``.
        problem: Problem name handed to the solver.
        solver_name: Solver name handed to the solver.
        fix_edges: If true, ``GroundTruth.solve`` is run for every
            (instance, vehicle, step >= 1) combination and the "tsp" entry of
            each result dict is collected; otherwise ``GeneralSolver`` solves
            each instance once and ``eval_solver`` measures its first tour.
        parallel: Use ``os.cpu_count()`` worker processes instead of one.

    Returns:
        ``(avg_tour_length, std_tour_length, feasible_ratio, penalty)``; the
        last two are always ``0.0`` (placeholders).
    """
    dataset = load_dataset(data_path)
    num_cpus = os.cpu_count() if parallel else 1
    if fix_edges:
        solver = GroundTruth(problem, solver_name)
        jobs = [(step, instance["tour"][vehicle_id], instance, f"{i}-{vehicle_id}-{step}")
                for i, instance in enumerate(dataset)
                for vehicle_id in range(len(instance["tour"]))
                for step in range(1, len(instance["tour"][vehicle_id]))]
        if parallel:
            with multiprocessing.Pool(num_cpus) as pool:
                tours = list(tqdm(pool.starmap(solver.solve, jobs), desc=f"Solving {data_path} with {solver_name}"))
        else:
            tours = [solver.solve(*job) for job in jobs]
        # each result is a dict; only its "tsp" entry feeds the statistics
        # NOTE(review): presumably "tsp" is the plain tour length -- confirm against GroundTruth.solve
        lengths = [tour["tsp"] for tour in tours]
    else:
        solver = GeneralSolver(problem, solver_name)
        with multiprocessing.Pool(num_cpus) as pool:
            # eval_solver already returns one length per instance, so this is a
            # flat list and must not be indexed with a "tsp" key
            lengths = list(tqdm(pool.starmap(eval_solver, [(solver, instance) for instance in dataset]), total=len(dataset), desc="Solving instances"))

    feasible_ratio = 0.0
    penalty = 0.0
    avg_tour_length = np.mean(lengths)
    std_tour_length = np.std(lengths)
    return avg_tour_length, std_tour_length, feasible_ratio, penalty
45
+
46
if __name__ == "__main__":
    # Command-line entry point: solve a dataset and print mean +/- std of the tour length.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--problem", default="tsptw", type=str, help="Problem type: [tsptw, pctsp, pctsptw, cvrp]")
    parser.add_argument("--solver_name", type=str, default="ortools", help="Select from ")
    parser.add_argument("--data_path", type=str, help="Path to a dataset", required=True)
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--fix_edges", action="store_true")
    args = parser.parse_args()

    eval_kwargs = {
        "data_path": args.data_path,
        "problem": args.problem,
        "solver_name": args.solver_name,
        "fix_edges": args.fix_edges,
        "parallel": args.parallel,
    }
    # feasible_ratio and penalty are returned as well but not reported
    avg_tour_length, std_tour_length, feasible_ratio, penalty = eval(**eval_kwargs)
    print(f"tour_length: {avg_tour_length} +/- {std_tour_length}")
generate_cf_dataset.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ from multiprocessing import Pool
5
+ from utils.utils import load_dataset, save_dataset
6
+ from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_tw_mask, get_cap_mask
7
+ from models.classifiers.ground_truth.ground_truth import GroundTruth
8
+ from models.solvers.general_solver import GeneralSolver
9
+ from models.cf_generator import CFTourGenerator
10
+
11
+
12
class CFDatasetBase():
    """Build a counterfactual (CF) dataset from an existing routing dataset.

    For each base instance one visit of one route is replaced by a randomly
    chosen feasible alternative, the CF tour generator completes the routes,
    and the ground-truth classifier labels the result.
    """

    def __init__(self, problem, cf_generator, classifier, base_dataset, num_samples, random_seed, parallel, num_cpus):
        """
        Args:
            problem: Problem name (selects solver, classifier and node mask).
            cf_generator: Solver name used by the CF tour generator.
            classifier: Solver name used by the ground-truth classifier.
            base_dataset: Path of the dataset to start from.
            num_samples: Number of CF samples to produce; ``None`` means one per base instance.
            random_seed: Seed given to ``random.seed`` before generation.
            parallel: Annotate with a process pool instead of serially.
            num_cpus: Pool size when ``parallel`` is true.
        """
        self.problem = problem
        self.parallel = parallel
        self.num_cpus = num_cpus
        self.seed = random_seed
        self.cf_generator = CFTourGenerator(cf_solver=GeneralSolver(problem, cf_generator))
        self.classifier = GroundTruth(problem, classifier)
        self.node_mask = NodeMask(problem)
        self.dataset = load_dataset(base_dataset)
        self.num_samples = len(self.dataset) if num_samples is None else num_samples

    def generate_cf_dataset(self):
        """Return a list of ``self.num_samples`` annotated CF instances.

        Instances for which no CF sample could be built are dropped and the
        next instances of the (rotated) base dataset are tried instead, until
        enough samples have been collected.
        """
        random.seed(self.seed)
        cf_dataset = []
        num_required_samples = self.num_samples
        print("Data generation started.", flush=True)
        while True:
            dataset = self.dataset[:num_required_samples]
            # rotate so that the next round starts with not-yet-tried instances
            self.dataset = np.roll(self.dataset, -num_required_samples)
            if self.parallel:
                instances = self.generate_labeldata_para(dataset, self.num_cpus)
            else:
                instances = self.generate_labeldata(dataset)
            cf_dataset.extend(filter(None, instances))  # drop failed (None) annotations
            num_required_samples = self.num_samples - len(cf_dataset)
            if num_required_samples == 0:
                break
            print(f"No feasible tour was not found in {num_required_samples} instances. Trying other {num_required_samples} instances.", flush=True)
        print("Data generation completed.", flush=True)
        return cf_dataset

    def generate_labeldata(self, dataset):
        """Annotate ``dataset`` serially; failed instances yield ``None``."""
        return [self.annotate(instance) for instance in tqdm(dataset, desc="Annotating instances")]

    def generate_labeldata_para(self, dataset, num_cpus):
        """Annotate ``dataset`` with ``num_cpus`` worker processes; failed instances yield ``None``."""
        with Pool(num_cpus) as pool:
            annotation_data = list(tqdm(pool.imap(self.annotate, dataset), total=len(dataset), desc="Annotating instances"))
        return annotation_data

    def annotate(self, instance):
        """Create one labelled CF sample from ``instance``.

        Returns:
            A shallow copy of ``instance`` whose "tour" holds the CF routes and
            whose "labels" holds their annotation, or ``None`` when the chosen
            route is too short, no alternative node is feasible, or the CF
            generator finds no route. ``instance`` itself is left untouched,
            so base instances that are revisited after the dataset wraps
            around are still the original ones.
        """
        # generate a counterfactual route randomly
        routes = instance["tour"]
        vehicle_id = random.randint(0, len(routes) - 1)
        route = routes[vehicle_id]
        if len(route) - 2 <= 2:
            return None
        cf_step = random.randint(2, len(route) - 2)
        mask = self.node_mask.get_mask(route, cf_step, instance)
        node_id = np.arange(len(instance["coords"]))
        feasible_node_id = node_id[mask]
        # the node that is actually visited at cf_step is not a counterfactual
        feasible_node_id = feasible_node_id[feasible_node_id != route[cf_step]].tolist()
        if len(feasible_node_id) == 0:
            return None
        cf_visit = random.choice(feasible_node_id)
        cf_routes = self.cf_generator(routes, vehicle_id, cf_step, cf_visit, instance)
        if cf_routes is None:
            return None

        # annotate each edge
        inputs = self.classifier.get_inputs(cf_routes, 0, instance)
        labels = self.classifier(inputs, annotation=True)

        # store tours and labels on a copy instead of mutating the base instance
        annotated = dict(instance)
        annotated["tour"] = cf_routes
        annotated["labels"] = labels
        return annotated
81
+
82
class NodeMask():
    """Dispatch to the problem-specific feasibility mask function."""

    def __init__(self, problem):
        """
        Args:
            problem: One of "tsptw", "pctsp", "pctsptw", "cvrp".

        Raises:
            NotImplementedError: If ``problem`` is none of the above.
        """
        self.problem = problem

        if self.problem == "tsptw":
            self.mask_func = get_tsptw_mask
        elif self.problem == "pctsp":
            self.mask_func = get_pctsp_mask
        elif self.problem == "pctsptw":
            self.mask_func = get_pctsptw_mask
        elif self.problem == "cvrp":
            self.mask_func = get_cvrp_mask
        else:
            # fail here rather than with an AttributeError on the first get_mask call
            raise NotImplementedError(f"unsupported problem: {problem}")

    def get_mask(self, route, step, instance):
        """Return the mask computed by the selected function for ``route`` at ``step``."""
        return self.mask_func(route, step, instance)
99
+
100
def get_tsptw_mask(route, step, instance):
    """TSPTW mask: not-yet-visited entries combined (AND) with the time-window mask."""
    unvisited = ~get_visited_mask(route, step, instance)
    within_tw = get_tw_mask(route, step, instance)
    return unvisited & within_tw
104
+
105
def get_pctsp_mask(route, step, instance):
    """PCTSP mask: the negation of the visited mask."""
    return ~get_visited_mask(route, step, instance)
108
+
109
def get_pctsptw_mask(route, step, instance):
    """PCTSPTW mask: not-yet-visited entries combined (AND) with the time-window mask."""
    unvisited = ~get_visited_mask(route, step, instance)
    within_tw = get_tw_mask(route, step, instance)
    return unvisited & within_tw
113
+
114
def get_cvrp_mask(route, step, instance):
    """CVRP mask: not-yet-visited entries combined (AND) with the capacity mask."""
    unvisited = ~get_visited_mask(route, step, instance)
    within_cap = get_cap_mask(route, step, instance)
    return unvisited & within_cap
118
+
119
if __name__ == "__main__":
    # Command-line entry point: build a CF dataset from a base dataset and save it.
    import os
    import argparse
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--base_dataset", type=str, required=True)
    parser.add_argument("--cf_generator", type=str, default="ortools")
    parser.add_argument("--classifier", type=str, default="ortools")
    parser.add_argument("--num_samples", type=int, default=None)
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=4)
    parser.add_argument("--output_dir", type=str, default="data")
    args = parser.parse_args()

    generator_args = (args.problem,
                      args.cf_generator,
                      args.classifier,
                      args.base_dataset,
                      args.num_samples,
                      args.random_seed,
                      args.parallel,
                      args.num_cpus)
    dataset_gen = CFDatasetBase(*generator_args)
    cf_dataset = dataset_gen.generate_cf_dataset()

    # the file name records the sample count, the seed and the base dataset's file name
    base_name = os.path.basename(args.base_dataset)
    output_fname = f"{args.output_dir}/{args.problem}/cf_{dataset_gen.num_samples}samples_seed{args.random_seed}_base_{base_name}.pkl"
    save_dataset(cf_dataset, output_fname)
generate_dataset.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.data_utils.tsptw_dataset import TSPTWDataset
2
+ from utils.data_utils.pctsp_dataset import PCTSPDataset
3
+ from utils.data_utils.pctsptw_dataset import PCTSPTWDataset
4
+ from utils.data_utils.cvrp_dataset import CVRPDataset
5
+ from utils.data_utils.cvrptw_dataset import CVRPTWDataset
6
+ from utils.utils import save_dataset
7
+
8
def generate_dataset(num_samples, args):
    """Instantiate the dataset generator for ``args.problem`` and run it.

    Args:
        num_samples: Total number of instances to generate.
        args: Namespace providing ``problem``, ``coord_dim``, ``num_nodes``,
            ``random_seed``, ``solver``, ``classifier``, ``annotation``,
            ``parallel`` and ``num_cpus``; plus ``distribution`` for "tsptw"
            and ``penalty_factor`` for "pctsp"/"pctsptw".

    Returns:
        Whatever the selected generator's ``generate_dataset()`` returns.

    Raises:
        NotImplementedError: If ``args.problem`` is not a supported problem.
    """
    problem = args.problem
    if problem not in ("tsptw", "pctsp", "pctsptw", "cvrp", "cvrptw"):
        raise NotImplementedError

    # keyword arguments every generator takes
    shared = dict(coord_dim=args.coord_dim,
                  num_samples=num_samples,
                  num_nodes=args.num_nodes,
                  random_seed=args.random_seed,
                  solver=args.solver,
                  classifier=args.classifier,
                  annotation=args.annotation,
                  parallel=args.parallel,
                  num_cpus=args.num_cpus)

    if problem == "tsptw":
        data_generator = TSPTWDataset(distribution=args.distribution, **shared)
    elif problem == "pctsp":
        data_generator = PCTSPDataset(penalty_factor=args.penalty_factor, **shared)
    elif problem == "pctsptw":
        data_generator = PCTSPTWDataset(penalty_factor=args.penalty_factor, **shared)
    elif problem == "cvrp":
        data_generator = CVRPDataset(**shared)
    else:  # "cvrptw"
        data_generator = CVRPTWDataset(**shared)

    return data_generator.generate_dataset()
66
+
67
if __name__ == "__main__":
    # Command-line entry point: generate one dataset and save it split by data type.
    import argparse
    import os
    import numpy as np
    parser = argparse.ArgumentParser(description='')
    # common settings
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--data_type", type=str, nargs="*", default=["all"], help="data type: 'all' or combo. of ['train', 'valid', 'test'].")
    parser.add_argument("--num_samples", type=int, nargs="*", default=[1000, 100, 100])
    parser.add_argument("--num_nodes", type=int, default=20)
    parser.add_argument("--coord_dim", type=int, default=2, help="only coord_dim=2 is supported for now.")
    parser.add_argument("--solver", type=str, default="ortools", help="solver that outputs a tour")
    parser.add_argument("--classifier", type=str, default="ortools", help="classifier for annotation")
    parser.add_argument("--annotation", action="store_true")
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--num_cpus", type=int, default=os.cpu_count())
    parser.add_argument("--output_dir", type=str, default="data")
    # for TSPTW
    parser.add_argument("--distribution", type=str, default="da_silva")
    # for PCTSP
    parser.add_argument("--penalty_factor", type=float, default=3.)
    args = parser.parse_args()

    # 3d problems are not supported
    assert args.coord_dim == 2, "only coord_dim=2 is supported for now."

    # calc num. of total samples (train + valid + test samples)
    split_all = args.data_type[0] == "all"
    if split_all:
        assert len(args.num_samples) == 3, "please specify # samples for each of the three types (train/valid/test) when you set data_type 'all'. (e.g., --num_samples 1280000 1000 1000)"
    else:
        assert len(args.data_type) == len(args.num_samples), "please match # data_types and # elements in num_samples-arg"
    num_samples = np.sum(args.num_samples)

    # generate a dataset
    dataset = generate_dataset(num_samples, args)

    # split the dataset
    split_names = ["train", "valid", "eval"] if split_all else args.data_type
    # a leading 0 makes sizes[idx] the size of the previous split
    sizes = args.num_samples
    sizes.insert(0, 0)
    offset = 0
    for idx, split_name in enumerate(split_names):
        offset += sizes[idx]
        count = sizes[idx + 1]
        subset = dataset[offset:offset + count]
        output_fname = f"{args.output_dir}/{args.problem}/{split_name}_{args.problem}_{args.num_nodes}nodes_{count}samples_seed{args.random_seed}.pkl"
        save_dataset(subset, output_fname)
install_solvers.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib.request
3
+ import subprocess
4
+
5
+ def _run(cmd, cwd):
6
+ subprocess.check_call(cmd, shell=True, cwd=cwd)
7
+
8
def install_concorde():
    """Download QSopt and the Concorde TSP sources, then build Concorde.

    Everything is placed under ``models/solvers/concorde/src``; the QSopt
    archive and header go to the ``qsopt`` sub-directory, which is also
    handed to ``./configure`` as both ``--prefix`` and ``--with-qsopt``.
    """
    QSOPT_A_URL = "https://www.math.uwaterloo.ca/~bico/qsopt/beta/codes/PIC/qsopt.PIC.a"
    QSOPT_H_URL = "https://www.math.uwaterloo.ca/~bico/qsopt/beta/codes/PIC/qsopt.h"
    CONCORDE_URL = "https://www.math.uwaterloo.ca/tsp/concorde/downloads/codes/src/co031219.tgz"

    concorde_path = "models/solvers/concorde"
    concorde_src_path = f"{concorde_path}/src"
    qsopt_path = f"{concorde_src_path}/qsopt"
    os.makedirs(concorde_src_path, exist_ok=True)

    # download qsopt, which is a dependency library
    print("Downloading QSOPT...", end=" ", flush=True)
    os.makedirs(qsopt_path, exist_ok=True)
    qsopt_downloads = (
        (QSOPT_A_URL, f"{qsopt_path}/qsopt.a"),
        (QSOPT_H_URL, f"{qsopt_path}/qsopt.h"),
    )
    for url, destination in qsopt_downloads:
        urllib.request.urlretrieve(url, destination)
    print("done")

    # download concorde tsp
    print("Downloading Concorde TSP...", end=" ", flush=True)
    urllib.request.urlretrieve(CONCORDE_URL, f"{concorde_src_path}/concorde.tgz")
    print("done")

    # build concorde: unpack, flatten the extracted directory, configure, make
    cflags = "-fPIC -O2 -g"
    datadir = os.path.abspath(qsopt_path)
    build_commands = (
        "tar -xzf concorde.tgz",
        "mv concorde/* .",
        "rm -r concorde.tgz concorde",
        f"CFLAGS='{cflags}' ./configure --prefix {datadir} --with-qsopt={datadir}",
        "make",
    )
    for command in build_commands:
        _run(command, concorde_src_path)
41
+
42
def install_lkh():
    """Download the LKH-3.0.8 sources into ``models/solvers/lkh/src`` and build them."""
    LKH_URL = "http://webhotel4.ruc.dk/~keld/research/LKH-3/LKH-3.0.8.tgz"

    lkh_path = "models/solvers/lkh"
    lkh_src_path = f"{lkh_path}/src"
    os.makedirs(lkh_src_path, exist_ok=True)

    # download LKH
    urllib.request.urlretrieve(LKH_URL, f"{lkh_src_path}/LKH-3.0.8.tgz")

    # build LKH: unpack, flatten the extracted directory, drop leftovers, make
    build_commands = (
        "tar -xzf LKH-3.0.8.tgz",
        "mv LKH-3.0.8/* .",
        "rm -r LKH-3.0.8.tgz LKH-3.0.8",
        "make",
    )
    for command in build_commands:
        _run(command, lkh_src_path)
57
+
58
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # `choices` makes argparse reject unknown names; without it a typo such as
    # "concord" matched neither branch below and the script exited silently.
    parser.add_argument("--installed_solvers", default="all", type=str,
                        choices=["all", "concorde", "lkh"],
                        help="Solvers: [all, concorde, lkh]")
    args = parser.parse_args()

    if args.installed_solvers in ("all", "concorde"):
        install_concorde()

    if args.installed_solvers in ("all", "lkh"):
        install_lkh()
models/cf_generator.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import numpy as np
3
+
4
class CFTourGenerator(nn.Module):
    """Generate counterfactual tours by re-solving with a forced prefix.

    The wrapped solver must expose a ``problem`` attribute and a
    ``solve(node_feats, fixed_paths, dist_matrix=...)`` method whose result
    is either ``None`` or an indexable collection of tours.
    """

    def __init__(self, cf_solver):
        super().__init__()
        self.solver = cf_solver
        self.problem = cf_solver.problem

    def forward(self, factual_tour, vehicle_id, cf_step, cf_next_node_id, node_feats, dist_matrix=None):
        """
        solve an input instance with visited edges fixed

        Parameters
        ----------
        factual_tour: list of tours, indexed as factual_tour[vehicle_id]
        vehicle_id: int
            index of the factual tour whose prefix is kept
        cf_step: int
            number of leading visits of that tour that stay fixed
        cf_next_node_id: int
            node forced to be visited right after the fixed prefix
        node_feats:
            forwarded unchanged to the solver
        dist_matrix:
            forwarded unchanged to the solver (optional)

        Returns
        -------
        cf_tours: tours returned by the solver, or None when the solver
            returns None. The tour containing cf_next_node_id is reversed
            when needed so that it runs in the same direction as the fixed
            prefix.
        """
        fixed_paths = self.get_fixed_paths(factual_tour, vehicle_id, cf_step, cf_next_node_id)
        cf_tours = self.solver.solve(node_feats, fixed_paths, dist_matrix=dist_matrix)
        if cf_tours is None:
            return None
        if cf_step > 0:
            # A dedicated loop variable is used so the `vehicle_id` argument is
            # not overwritten: the direction check must always look at the
            # factual tour the caller selected, whatever position the matching
            # tour has in the solver output.
            for cf_vehicle_id, cf_tour in enumerate(cf_tours):
                if cf_next_node_id not in cf_tour:
                    continue
                if cf_step == 1:
                    # only the start node is fixed -> the forced node comes second
                    expected_second = cf_next_node_id
                else:
                    # the second factual visit belongs to the fixed prefix
                    expected_second = factual_tour[vehicle_id][1]
                if cf_tour[1] != expected_second:
                    # make direction of the cf tour the same as factual one
                    cf_tours[cf_vehicle_id] = np.flipud(cf_tour)
                break
        return cf_tours

    def get_fixed_paths(self, factual_tour, vehicle_id, cf_step, cf_next_node_id):
        """Return the first ``cf_step`` visits of the selected factual tour followed by ``cf_next_node_id``."""
        visited_paths = np.append(factual_tour[vehicle_id][:cf_step], cf_next_node_id)
        return visited_paths
46
+
47
+ # def get_avail_edges(self, factual_tour, cf_step, cf_next_node_id):
48
+ # visited_paths = np.append(factual_tour[:cf_step], cf_next_node_id)
49
+ # avail_edges = []
50
+ # # add fixed edges
51
+ # for i in range(len(visited_paths) - 1):
52
+ # avail_edges.append([visited_paths[i], visited_paths[i + 1]])
53
+ # print(avail_edges)
54
+
55
+ # # add rest avaialbel edges
56
+ # num_nodes = np.max(factual_tour) + 1
57
+ # visited = np.array([0] * num_nodes)
58
+ # for id in visited_paths:
59
+ # visited[id] = 1
60
+ # visited[factual_tour[0]] = 0
61
+ # visited[cf_next_node_id] = 0
62
+ # mask = visited < 1
63
+ # node_id = np.arange(num_nodes)
64
+ # feasible_node_id = node_id[mask]
65
+ # for j in range(len(feasible_node_id) - 1):
66
+ # for i in range(j + 1, len(feasible_node_id)):
67
+ # if ((feasible_node_id[j] == factual_tour[0]) and (feasible_node_id[i] == cf_next_node_id)) or ((feasible_node_id[i] == factual_tour[0]) and (feasible_node_id[j] == cf_next_node_id)):
68
+ # continue
69
+ # avail_edges.append([feasible_node_id[j], feasible_node_id[i]])
70
+ # return np.array(avail_edges)
71
+
72
+ #-----------
73
+ # unit test
74
+ #-----------
75
if __name__ == "__main__":
    import argparse
    import random
    import matplotlib.pyplot as plt
    # FYI:
    # - https://yu-nix.com/archives/python-path-get/
    # - https://www.delftstack.com/ja/howto/python/python-get-parent-directory/
    # - https://stackoverflow.com/questions/2817264/how-to-get-the-parent-dir-location
    import os
    import sys
    CURR_DIR = os.path.dirname(os.path.abspath(__file__))
    PARENT_DIR = os.path.abspath(os.path.join(CURR_DIR, os.pardir))
    sys.path.append(PARENT_DIR)
    from utils.util_vis import visualize_factual_and_cf_tours
    from lkh.lkh import LKH
    from models.ortools.ortools import ORTools
    from data_generator.tsptw.tsptw_dataset import generate_tsptw_instance

    parser = argparse.ArgumentParser(description='')
    # general settings
    parser.add_argument("--problem", type=str, default="tsptw")
    parser.add_argument("--random_seed", type=int, default=1234)
    parser.add_argument("--num_samples", type=int, default=5)
    parser.add_argument("--num_nodes", type=int, default=100)
    parser.add_argument("--coord_dim", type=int, default=2)
    # LKH settings
    parser.add_argument("--max_trials", type=int, default=1000)
    parser.add_argument("--lkh_dir", type=str, default="lkh", help="Path to the binary of LKH")
    parser.add_argument("--io_dir", type=str, default="lkh_io_files")
    args = parser.parse_args()

    # models
    # cf_solver = LKH(args.problem, args.max_trials, args.random_seed, lkh_dir=args.lkh_dir, io_dir=args.io_dir)
    cf_solver = ORTools(args.problem)
    cf_generator = CFTourGenerator(cf_solver)

    # dataset
    if args.problem == "tsp":
        np.random.seed(args.random_seed)
        node_feats = np.random.uniform(size=[args.num_samples, args.num_nodes, args.coord_dim])
    elif args.problem == "tsptw":
        coords, time_window, grid_size = generate_tsptw_instance(num_nodes=args.num_nodes, grid_size=100, max_tw_gap=10, max_tw_size=1000, is_integer_instance=True, da_silva_style=True)
        node_feats = np.concatenate([coords, time_window], -1)
        node_feats = node_feats[None, :, :]
    else:
        # previously fell through and failed later with a NameError on node_feats
        raise NotImplementedError(f"unsupported problem for this test: {args.problem}")

    # function to automatically generate a counterfactual visit
    def get_random_cf_visit(factual_tour, random_seed=1234):
        # NOTE: random_seed is accepted but not applied (seeding is disabled).
        # random.seed(random_seed)
        num_nodes = np.max(factual_tour) + 1
        step = random.randrange(len(factual_tour) - 2) # remove the last step (returning to the start-point)
        visited = np.array([0] * num_nodes)
        for i in range(step+1):
            visited[factual_tour[i]] = 1
        mask = visited < 1
        node_id = np.arange(num_nodes)
        feasible_node_id = node_id[mask]
        cf_next_id = random.choice(feasible_node_id) # select counterfactual
        return step, cf_next_id

    # NOTE(review): CFTourGenerator.forward requires a vehicle_id and indexes
    # factual_tour[vehicle_id]; it also iterates the solver output as a list of
    # tours. The solver result is therefore treated here as per-vehicle tours
    # and the first vehicle is explained -- confirm against the solver's API.
    vehicle_id = 0

    for i in range(len(node_feats)):
        # obtain factual tours
        factual_tours = cf_solver.solve(node_feats[i])
        factual_tour = factual_tours[vehicle_id]

        # counterfactual visit
        cf_step, cf_next_node_id = get_random_cf_visit(factual_tour, random_seed=args.random_seed)

        print(cf_step, cf_next_node_id)
        # obtain counterfactual tours (the call previously omitted vehicle_id,
        # which raised a TypeError for the missing node_feats argument)
        cf_tours = cf_generator(factual_tours, vehicle_id, cf_step, cf_next_node_id, node_feats[i])

        print(factual_tour)
        print(cf_tours)

        # pick the counterfactual tour that contains the forced node
        cf_tour = None
        if cf_tours is not None:
            cf_tour = next((tour for tour in cf_tours if cf_next_node_id in tour), None)
        if cf_tour is None:
            print("no counterfactual tour was found")
            break

        # visualize the factual and counterfactual tours
        if args.problem == "tsp":
            coords = node_feats[i]
        elif args.problem == "tsptw":
            coord_dim = 2
            coords = node_feats[i, :, :coord_dim]
        visualize_factual_and_cf_tours(factual_tour, cf_tour, coords, cf_step, f"test{i}.png")
        break
models/classifiers/general_classifier.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import argparse
4
+ import json
5
+ import os
6
+ from models.classifiers.predictor import DecisionPredictor
7
+ from models.classifiers.meaningless_models import FixedClassPredictor, RandomPredictor
8
+ from models.classifiers.rule_based_models import kNearestPredictor
9
+ from models.classifiers.ground_truth.ground_truth import GroundTruth
10
+
11
class GeneralClassifier(nn.Module):
    """Thin wrapper that builds one of several edge classifiers by name.

    Supported ``model_type`` values: "gnn", "gt(ortools)", "gt(lkh)",
    "gt(concorde)", "random", "fixed" and "knn".
    """

    # model_type -> solver_type handed to GroundTruth
    _GROUND_TRUTH_SOLVERS = {
        "gt(ortools)": "ortools",
        "gt(lkh)": "lkh",
        "gt(concorde)": "concorde",
    }

    def __init__(self, problem, model_type):
        super().__init__()
        self.model_type = model_type
        self.problem = problem
        self.model = self.get_model(problem, model_type)

    def change_model(self, problem, model_type):
        """Rebuild the wrapped model only when the problem or the model type changes."""
        if self.model_type != model_type or self.problem != problem:
            self.model_type = model_type
            self.problem = problem
            self.model = self.get_model(problem, model_type)

    def get_model(self, problem, model_type):
        """Instantiate the classifier selected by ``model_type``.

        Raises
        ------
        ValueError
            when ``model_type`` is not one of the supported names
        """
        if model_type == "gnn":
            model_path = "checkpoints/model_20230309_101058/model_epoch4.pth"
            model_dir = os.path.split(model_path)[0]
            # The training arguments are stored as JSON next to the checkpoint;
            # expose them through a plain Namespace for attribute access.
            with open(f"{model_dir}/cmd_args.dat", "r") as f:
                params = argparse.Namespace(**json.load(f))
            # NOTE(review): the problem stored with the checkpoint is used here,
            # not the `problem` argument -- confirm this is intended.
            model = DecisionPredictor(params.problem,
                                      params.emb_dim,
                                      params.num_mlp_layers,
                                      params.num_classes,
                                      params.dropout)
            # map_location lets a checkpoint saved on a GPU load on a CPU-only host
            model.load_state_dict(torch.load(model_path, map_location="cpu"))
            return model
        if model_type in self._GROUND_TRUTH_SOLVERS:
            return GroundTruth(problem, solver_type=self._GROUND_TRUTH_SOLVERS[model_type])
        if model_type == "random":
            return RandomPredictor(num_classes=2)
        if model_type == "fixed":
            predicted_class = 0
            return FixedClassPredictor(predicted_class=predicted_class, num_classes=2)
        if model_type == "knn":
            k = 5
            k_type = "num"
            return kNearestPredictor(problem, k, k_type)
        # an explicit exception survives `python -O`, unlike `assert False`
        raise ValueError(f"Invalid model type: {model_type}")

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Delegate input construction to the wrapped model."""
        return self.model.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def forward(self, inputs):
        """Delegate the prediction to the wrapped model."""
        return self.model(inputs)
models/classifiers/ground_truth/ground_truth.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from models.classifiers.ground_truth.ground_truth_tsptw import GroundTruthTSPTW
5
+ from models.classifiers.ground_truth.ground_truth_pctsp import GroundTruthPCTSP
6
+ from models.classifiers.ground_truth.ground_truth_pctsptw import GroundTruthPCTSPTW
7
+ from models.classifiers.ground_truth.ground_truth_cvrp import GroundTruthCVRP
8
+ from models.classifiers.ground_truth.ground_truth_cvrptw import GroundTruthCVRPTW
9
+
10
class GroundTruth(nn.Module):
    """Facade that forwards every call to the ground-truth model of one problem."""

    def __init__(self, problem, solver_type):
        super().__init__()
        self.problem = problem
        self.solver_type = solver_type
        self.ground_truth = self._create_backend(problem, solver_type)

    @staticmethod
    def _create_backend(problem, solver_type):
        """Return the problem-specific ground-truth model; unknown problems raise NotImplementedError."""
        if problem == "tsptw":
            return GroundTruthTSPTW(solver_type)
        if problem == "pctsp":
            return GroundTruthPCTSP(solver_type)
        if problem == "pctsptw":
            return GroundTruthPCTSPTW(solver_type)
        if problem == "cvrp":
            return GroundTruthCVRP(solver_type)
        if problem == "cvrptw":
            return GroundTruthCVRPTW(solver_type)
        raise NotImplementedError

    def forward(self, inputs, annotation=False, parallel=False):
        """Label the given inputs with the wrapped ground-truth model."""
        backend = self.ground_truth
        return backend(inputs, annotation, parallel)

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Build the input dictionary expected by ``forward``."""
        backend = self.ground_truth
        return backend.get_inputs(tour, first_explained_step, node_feats, dist_matrix)

    def solve(self, step, input_tour, node_feats, instance_name=None):
        """Forward to the wrapped model's ``solve``."""
        backend = self.ground_truth
        return backend.solve(step, input_tour, node_feats, instance_name)
models/classifiers/ground_truth/ground_truth_base.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import os
5
+ import multiprocessing
6
+ from models.solvers.general_solver import GeneralSolver
7
+ from utils.utils import calc_tour_length
8
+
9
def get_visited_mask(tour, step, node_feats, dist_matrix=None):
    """
    Flag the nodes that appear among the first `step` visits of `tour`.

    Visited nodes -> True (feasible), unvisited nodes -> False.
    When a problem is solved with the visited path fixed, that path has to be
    part of the solution, hence visited nodes are kept as feasible nodes.
    The number of nodes is taken from dist_matrix when it is given and from
    node_feats["coords"] otherwise.
    """
    num_nodes = len(dist_matrix) if dist_matrix is not None else len(node_feats["coords"])
    all_node_ids = np.arange(num_nodes)
    already_visited = tour[:step]
    return np.isin(all_node_ids, already_visited)
21
+
22
def get_tw_mask(tour, step, node_feats, dist_matrix=None):
    """
    Nodes whose time window would already be closed on arrival -> infeasible,
    otherwise -> feasible.

    The clock starts at 0.0 and is advanced along the first `step` visits of
    `tour`; arriving before a window opens moves the clock to the opening
    time. Travel times come from dist_matrix when it is given and from the
    Euclidean distance between node_feats["coords"] otherwise.

    Parameters
    ----------
    tour: list [seq_length]
    step: int
    node_feats: dict of np.array
    dist_matrix: np.array [num_nodes x num_nodes] (optional)

    Returns
    -------
    mask_tw: np.array [num_nodes]
    """
    time_window = node_feats["time_window"]

    if dist_matrix is not None:
        num_nodes = len(dist_matrix)

        def travel_between(src, dst):
            return dist_matrix[src, dst]

        def travel_from(src):
            return dist_matrix[src]  # [num_nodes]
    else:
        coords = node_feats["coords"]
        num_nodes = len(coords)

        def travel_between(src, dst):
            return np.linalg.norm(coords[src] - coords[dst])

        def travel_from(src):
            return np.linalg.norm(coords[src][None, :] - coords, axis=-1)  # [num_nodes]

    # replay the fixed part of the tour to obtain the current time
    clock = 0.0
    for pos in range(1, step):
        node = tour[pos]
        arrival = clock + travel_between(tour[pos - 1], node)
        opens_at = time_window[node, 0]
        # wait for the window to open when arriving early
        clock = opens_at.copy() if arrival < opens_at else arrival

    # arrival time at every node when leaving from the last fixed node
    arrival_everywhere = clock + travel_from(tour[step - 1])  # [num_nodes] TODO: check
    feasible = np.ones(num_nodes, dtype=bool)
    feasible[arrival_everywhere > time_window[:, 1]] = False
    return feasible
70
+
71
def get_cap_mask(tour, step, node_feats):
    """
    Nodes whose demand exceeds the remaining capacity -> infeasible,
    otherwise -> feasible.

    The remaining capacity is node_feats["capacity"] minus the demands of the
    first `step` visits of `tour`; the capacity array itself is not modified.

    Returns
    -------
    less_than_cap: np.array (bool) [num_nodes]
    """
    num_nodes = len(node_feats["coords"])
    demands = node_feats["demand"]
    remaining_cap = node_feats["capacity"].copy()  # work on a copy
    for pos in range(step):
        visited_node = tour[pos]
        remaining_cap -= demands[visited_node]
    fits = np.ones(num_nodes, dtype=bool)
    fits[remaining_cap < demands] = False
    return fits
81
+
82
+ def get_pc_mask(tour, step, node_feats):
83
+ """
84
+ Mask for Price collecting problems (e.g., PCTSP, PCTSPTW, PCCVRP, PCCVRPTW, ...)
85
+
86
+ Returns
87
+ -------
88
+ not_exceed_max_length
89
+ """
90
+ large_value = 1e+5
91
+ coords = node_feats["coords"]
92
+ max_length = (node_feats["max_length"] * large_value).astype(np.int64)
93
+ tour_length = 0
94
+ for i in range(1, step):
95
+ prev_id = tour[i - 1]
96
+ curr_id = tour[i]
97
+ tour_length += (np.linalg.norm(coords[prev_id] - coords[curr_id]) * large_value).astype(np.int64)
98
+ curr_to_next = (np.linalg.norm(coords[tour[step-1]][None, :] - coords, axis=-1) * large_value).astype(np.int64) # [num_nodes]
99
+ next_to_depot = (np.linalg.norm(coords[tour[0]][None, :] - coords, axis=-1) * large_value).astype(np.int64) # [num_nodes]
100
+ not_exceed_max_length = (tour_length + curr_to_next + next_to_depot) <= max_length # [num_nodes]
101
+ return not_exceed_max_length
102
+
103
def analyze_tour(tour, node_feats):
    """
    Debug helper: print one line per leg of `tour` with the Euclidean travel
    time, the arrival time, the time window of the destination and whether
    the arrival is strictly before the window closes. Arriving before a
    window opens advances the clock to the opening time.
    """
    coords = node_feats["coords"]
    time_window = node_feats["time_window"]
    curr_time = 0
    for i, (prev_id, curr_id) in enumerate(zip(tour[:-1], tour[1:]), start=1):
        travel_time = np.linalg.norm(coords[prev_id] - coords[curr_id])
        valid = curr_time + travel_time < time_window[curr_id, 1]
        print(f"visit #{i}: {prev_id} -> {curr_id}, travel_time: {travel_time}, arrival_time: {curr_time + travel_time}, time_window: {time_window[curr_id]}, valid: {valid}")
        window_open = time_window[curr_id, 0]
        arrived_early = curr_time + travel_time < window_open
        curr_time = window_open if arrived_early else curr_time + travel_time
117
+
118
FAIL_FLAG = -1  # label returned by label_path when a solver yields no tour


class GroundTruthBase(nn.Module):
    """Label each step of a tour by comparing it with re-solved simpler problems.

    For a step, the visited prefix of the tour is fixed, the remaining
    feasible nodes are re-solved as each problem in ``compared_problems``,
    and the label is the index of the first compared problem whose tour
    visits the same node at that step (``num_compared_problems`` when none
    does). Subclasses provide ``get_mask``.
    """

    def __init__(self, problem, compared_problems, solver_type):
        """
        Parameters
        ----------
        problem: str
            name of the problem being explained
        compared_problems: list of str
            problems whose solutions the tour is compared with, in label order
        solver_type: str
            solver name passed to GeneralSolver
        """
        super().__init__()
        self.problem = problem
        self.compared_problems = compared_problems
        self.num_compared_problems = len(compared_problems)
        self.solver_type = solver_type
        # TODO: (kept from the original source)
        self.solvers = [
            GeneralSolver(compared_problem, self.solver_type, scaling=False)
            for compared_problem in self.compared_problems
        ]

    def forward(self, inputs, annotation=False, parallel=True):
        """
        Parameters
        ----------
        inputs: dict
            tour: 2d list [num_vehicles x seq_length]
            first_explained_step: int
            node_feats: dict of np.array
            dist_matrix: distance matrix or None
        annotation: bool
            please set it True when annotating data
        parallel: bool
            label the steps in a multiprocessing pool (ignored when annotating)

        Returns
        -------
        labels:
            annotation=True : per vehicle, a list of (step, label) tuples,
                              or None as soon as one step cannot be solved
            annotation=False: per vehicle, a list of labels, one for every
                              step after first_explained_step
        """
        input_tours = inputs["tour"]
        node_feats = inputs["node_feats"]
        dist_matrix = inputs["dist_matrix"]
        first_explained_step = inputs["first_explained_step"]
        num_vehicles = len(input_tours)
        first_step = first_explained_step + 1

        if annotation:
            labels = [[] for _ in range(num_vehicles)]
            for vehicle_id in range(num_vehicles):
                input_tour = input_tours[vehicle_id]
                # analyze_tour(input_tour, node_feats)
                for step in range(first_step, len(input_tour)):
                    # dist_matrix is forwarded so that annotation uses the same
                    # travel times as the non-annotation branch below
                    _, __, label = self.label_path(vehicle_id, step, input_tour, node_feats, dist_matrix)
                    if label == FAIL_FLAG:
                        return None
                    labels[vehicle_id].append((step, label))
            return labels

        labels = [[-1] * len(range(first_step, len(input_tours[vehicle_id]))) for vehicle_id in range(num_vehicles)]
        jobs = [(vehicle_id, step, input_tours[vehicle_id], node_feats, dist_matrix)
                for vehicle_id in range(num_vehicles)
                for step in range(first_step, len(input_tours[vehicle_id]))]
        if parallel:
            num_cpus = os.cpu_count()
            with multiprocessing.Pool(num_cpus) as pool:
                results = pool.starmap(self.label_path, jobs)
        else:
            results = [self.label_path(*job) for job in jobs]
        for vehicle_id, step, label in results:
            labels[vehicle_id][step - first_step] = label

        # validate labels: one label per explained step (the previous check
        # assumed first_explained_step == 0 and failed for any other value)
        for vehicle_id in range(num_vehicles):
            expected = len(range(first_step, len(input_tours[vehicle_id])))
            assert expected == len(labels[vehicle_id]), f"vehicle_id={vehicle_id}, {input_tours}, {labels}"
        return labels
        # labels = [torch.LongTensor(label) for label in labels] # [num_vehicles x seq_length]
        # labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) # [num_vehicles x max_seq_length]
        # probs = torch.zeros((labels.size(0), labels.size(1), self.num_compared_problems+1)) # [num_vehicles x max_seq_length x (num_compared_problems+1)]
        # probs.scatter_(-1, labels.unsqueeze(-1).expand_as(probs), 1.0)
        # return probs

    def label_path(self, vehicle_id, step, input_tour, node_feats, dist_matrix=None):
        """
        Label the visit made at `step` of `input_tour`.

        Returns
        -------
        (vehicle_id, step, label): label is FAIL_FLAG when a solver returns None
        """
        compared_tour_list = [[] for _ in range(self.num_compared_problems)]
        visited_path = input_tour[:step].copy()
        new_node_id, new_node_feats, new_dist_matrix = self.get_feasible_nodes(input_tour, step, node_feats, dist_matrix)
        # translate original node ids into positions within the reduced instance
        new_visited_path = np.array([np.where(new_node_id == x)[0].item() for x in visited_path])
        for i in range(self.num_compared_problems):
            # TODO: in CVRPTW / PCCVRPTW, need to modify classification of the first and last paths
            compared_tours = self.solvers[i].solve(new_node_feats, new_visited_path, new_dist_matrix)
            if compared_tours is None:
                return vehicle_id, step, FAIL_FLAG
            # pick the tour that contains the last fixed node
            compared_tour = None
            for compared_tour_tmp in compared_tours:
                if new_visited_path[-1] in compared_tour_tmp:
                    compared_tour = compared_tour_tmp
                    break
            assert compared_tour is not None, f"Found no appropriate vhiecle. {compared_tours}, {new_visited_path}"
            # back to original node ids
            compared_tour = np.array([new_node_id[x] for x in compared_tour])
            if (step > 0) and (compared_tour[1] != input_tour[1]):
                compared_tour = np.flipud(compared_tour) # make direction of the cf tour the same as factual one
            compared_tour_list[i] = compared_tour
        # annotation
        label = self.get_label(input_tour, compared_tour_list, step)
        return vehicle_id, step, label

    def solve(self, step, input_tour, node_feats, instance_name=None):
        """
        Re-solve the instance as every compared problem with the first `step`
        visits fixed.

        Returns
        -------
        compared_tours: dict
            compared problem -> list with calc_tour_length(...) of each tour,
            or None when the solver returns no tours for that problem
        """
        compared_tours = {}
        visited_path = input_tour[:step].copy()
        # get_feasible_nodes returns three values; the reduced distance matrix
        # is always None here because no dist_matrix is supplied
        new_node_id, new_node_feats, _ = self.get_feasible_nodes(input_tour, step, node_feats)
        new_visited_path = np.array([np.where(new_node_id == x)[0].item() for x in visited_path])
        for i, compared_problem in enumerate(self.compared_problems):
            # NOTE(review): label_path passes a distance matrix as the third
            # positional argument of the solver while instance_name is passed
            # here -- confirm the solver's signature.
            tours = self.solvers[i].solve(new_node_feats, new_visited_path, instance_name)
            if tours is None:
                compared_tours[compared_problem] = None
                continue
            tours = [[new_node_id[x] for x in tour] for tour in tours]
            compared_tours[compared_problem] = [calc_tour_length(tour, node_feats["coords"]) for tour in tours]
        return compared_tours

    def get_label(self, input_tour, compared_tours, step):
        """Return the index of the first compared tour that visits the same node at `step`, else num_compared_problems."""
        for i in range(self.num_compared_problems):
            compared_tour = compared_tours[i]
            if input_tour[step] == compared_tour[step]:
                return i
        return self.num_compared_problems

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Bundle the arguments into the dict consumed by ``forward``."""
        input_features = {
            "tour": tour,
            "first_explained_step": first_explained_step,
            "node_feats": node_feats,
            "dist_matrix": dist_matrix
        }
        return input_features

    def get_feasible_nodes(self, tour, step, node_feats, dist_matrix=None):
        """
        Reduce the instance to the nodes selected by ``get_mask``.

        Parameters
        ----------
        tour: np.array [seq_length]
        step: int
        node_feats: dict of np.array
        dist_matrix: np.array [num_nodes x num_nodes] (optional)

        Returns
        -------
        new_node_id: np.array [num_feasible_nodes]
            original ids of the kept nodes
        new_node_feats: dict of np.array
            per-node entries (coords, time_window, demand, penalties, prizes)
            are filtered by the mask, every other entry is copied as is
        new_dist_matrix: np.array [num_feasible_nodes x num_feasible_nodes] or None
        """
        if dist_matrix is not None:
            num_nodes = len(dist_matrix)
        else:
            num_nodes = len(node_feats["coords"])
        mask = self.get_mask(tour, step, node_feats, dist_matrix)
        node_id = np.arange(num_nodes)
        new_node_id = node_id[mask].copy()
        new_node_feats = {
            key: node_feat[mask].copy()
                 if key in ["coords", "time_window", "demand", "penalties", "prizes"] else
                 node_feat.copy()
            for key, node_feat in node_feats.items()
        }
        if dist_matrix is not None:
            # drop the rows and columns of the removed nodes
            delete_id = node_id[~mask]
            new_dist_matrix = np.delete(np.delete(dist_matrix, delete_id, 0), delete_id, 1)
        else:
            new_dist_matrix = None
        return new_node_id, new_node_feats, new_dist_matrix

    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        """Return a boolean mask [num_nodes] of the nodes to keep; implemented by subclasses."""
        raise NotImplementedError

    def check_feasibility(self, tour, node_feats):
        """Implemented by subclasses."""
        raise NotImplementedError
models/classifiers/ground_truth/ground_truth_cvrp.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
2
+ from models.classifiers.ground_truth.ground_truth_base import get_cap_mask, get_visited_mask
3
+
4
class GroundTruthCVRP(GroundTruthBase):
    """Ground truth for CVRP; tours are compared with TSP solutions."""

    def __init__(self, solver_type):
        problem = "cvrp"
        compared_problems = ["tsp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        """Keep visited nodes and nodes whose demand fits the remaining capacity.

        `dist_matrix` is accepted because the base class always passes it to
        get_mask; without the parameter every call raised a TypeError.
        """
        visited = get_visited_mask(tour, step, node_feats, dist_matrix)
        less_than_cap = get_cap_mask(tour, step, node_feats)
        return visited | less_than_cap
models/classifiers/ground_truth/ground_truth_cvrptw.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
2
+ from models.classifiers.ground_truth.ground_truth_base import get_cap_mask, get_visited_mask, get_tw_mask
3
+
4
class GroundTruthCVRPTW(GroundTruthBase):
    """Ground truth for CVRPTW; tours are compared with TSP and CVRP solutions."""

    def __init__(self, solver_type):
        problem = "cvrptw"
        compared_problems = ["tsp", "cvrp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        """Keep visited nodes and nodes that satisfy both capacity and time window.

        `dist_matrix` is accepted because the base class always passes it to
        get_mask; without the parameter every call raised a TypeError.
        """
        visited = get_visited_mask(tour, step, node_feats, dist_matrix)
        less_than_cap = get_cap_mask(tour, step, node_feats)
        not_exceed_tw = get_tw_mask(tour, step, node_feats, dist_matrix)
        return visited | (less_than_cap & not_exceed_tw)
models/classifiers/ground_truth/ground_truth_pctsp.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
3
+ from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_pc_mask
4
+
5
class GroundTruthPCTSP(GroundTruthBase):
    """Ground truth for PCTSP; tours are compared with TSP solutions."""

    def __init__(self, solver_type):
        problem = "pctsp"
        compared_problems = ["tsp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        """Keep every node (no node is filtered out for PCTSP).

        `dist_matrix` is accepted because the base class always passes it to
        get_mask; without the parameter every call raised a TypeError. The
        mask length follows dist_matrix when it is given, as the base class
        does when it sizes the instance.
        """
        # visited = get_visited_mask(tour, step, node_feats)
        # not_exceed_max_length = get_pc_mask(tour, step, node_feats)
        if dist_matrix is not None:
            num_nodes = len(dist_matrix)
        else:
            num_nodes = len(node_feats["coords"])
        return np.full(num_nodes, True)
models/classifiers/ground_truth/ground_truth_pctsptw.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
3
+ from models.classifiers.ground_truth.ground_truth_base import get_visited_mask, get_tw_mask
4
+
5
class GroundTruthPCTSPTW(GroundTruthBase):
    """Ground truth for PCTSPTW; tours are compared with TSP and PCTSP solutions."""

    def __init__(self, solver_type):
        problem = "pctsptw"
        compared_problems = ["tsp", "pctsp"]
        super().__init__(problem, compared_problems, solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        """Keep visited nodes and nodes whose time window can still be met.

        `dist_matrix` is accepted because the base class always passes it to
        get_mask; without the parameter every call raised a TypeError.
        """
        visited = get_visited_mask(tour, step, node_feats, dist_matrix)
        not_exceed_tw = get_tw_mask(tour, step, node_feats, dist_matrix)
        return visited | not_exceed_tw
models/classifiers/ground_truth/ground_truth_tsptw.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.classifiers.ground_truth.ground_truth_base import GroundTruthBase
2
+ from models.classifiers.ground_truth.ground_truth_base import get_tw_mask, get_visited_mask
3
+
4
class GroundTruthTSPTW(GroundTruthBase):
    """Ground truth for TSPTW; tours are compared with TSP solutions."""

    def __init__(self, solver_type):
        super().__init__("tsptw", ["tsp"], solver_type)

    # @override
    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        """Keep visited nodes and nodes whose time window can still be met."""
        mask_args = (tour, step, node_feats, dist_matrix)
        visited = get_visited_mask(*mask_args)
        not_exceed_tw = get_tw_mask(*mask_args)
        return visited | not_exceed_tw
models/classifiers/meaningless_models.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class RandomPredictor(nn.Module):
    """Baseline that predicts a uniformly random class for every sample."""

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: int or dict
            batch_size or dict of input features

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            one-hot rows, the hot column being drawn uniformly at random
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        # One index per row: with an index of shape [batch_size x num_classes]
        # scatter_ wrote several 1.0 entries into the same row, so the rows
        # were not valid one-hot distributions.
        random_index = torch.randint(self.num_classes, (batch_size, 1))
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float)
        probs.scatter_(-1, random_index, 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Return the batch size, i.e. the number of visits from first_explained_step up to (excluding) the last one.

        `dist_matrix` is unused; it is accepted so the predictor can be called
        with the same four arguments as the other classifiers.
        """
        return len(tour[first_explained_step:-1])
28
+
29
class FixedClassPredictor(nn.Module):
    """Baseline that always predicts the same class."""

    def __init__(self, predicted_class, num_classes):
        """
        Parameters
        ----------
        predicted_class: int
            a class that this predictor always predicts
        num_classes: int
            number of classes

        Raises
        ------
        ValueError
            when predicted_class is not in [0, num_classes - 1]
        """
        super().__init__()
        # explicit check: an assert would be stripped under `python -O`, and the
        # previous one did not reject negative classes
        if not 0 <= predicted_class < num_classes:
            raise ValueError(f"predicted_class should be 0 - {num_classes - 1}.")
        self.predicted_class = predicted_class
        self.num_classes = num_classes

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: int or dict
            batch_size or dict of input features

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            one-hot rows with the hot column at predicted_class
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        # a single index per row is enough to build the one-hot encoding
        index = torch.full((batch_size, 1), self.predicted_class, dtype=torch.long)
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float)
        probs.scatter_(-1, index, 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Return the batch size, i.e. the number of visits from first_explained_step up to (excluding) the last one.

        `dist_matrix` is unused; it is accepted so the predictor can be called
        with the same four arguments as the other classifiers.
        """
        return len(tour[first_explained_step:-1])
models/classifiers/nn_classifiers/attention_graph_encoder.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
class AttentionGraphEncoder(nn.Module):
    """Single-head attention encoder producing one embedding per (current node, next node) step.

    The query is built from the embeddings of the current node, the next node
    and (optionally) the state; keys/values are built from every node embedding
    concatenated with the current-node embedding.
    """

    def __init__(self, coord_dim, node_dim, state_dim, emb_dim, dropout):
        """
        Parameters
        ----------
        coord_dim: int
            number of leading node features used to embed the depot (node 0)
        node_dim: int
            number of features used to embed every non-depot node
        state_dim: int
            number of state features (0 disables the state embedding)
        emb_dim: int
            embedding dimension
        dropout: float
            dropout probability applied to the initial node embeddings
        """
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.state_dim = state_dim
        self.norm_factor = 1 / math.sqrt(emb_dim)

        # initial embedding
        self.init_linear_node = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # An attention layer (values are filled in by reset_parameters)
        self.w_q = nn.Parameter(torch.empty((2 + int(state_dim > 0)) * emb_dim, emb_dim))
        self.w_k = nn.Parameter(torch.empty(2 * emb_dim, emb_dim))
        self.w_v = nn.Parameter(torch.empty(2 * emb_dim, emb_dim))

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter uniformly in +-1/sqrt(size of its last dimension)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            nn.init.uniform_(param, -stdv, stdv)

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size]
            next_node_id: torch.LongTensor [batch_size]
            node_feats: torch.FloatTensor [batch_size x num_nodes x node_dim]
            mask: torch.BoolTensor (0/1 integer tensors are also accepted) [batch_size x num_nodes]
                True marks nodes that may be attended to
            state: torch.FloatTensor [batch_size x state_dim]
                must be provided when state_dim > 0, because w_q is sized for it

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            graph embeddings
        """
        #----------------
        # input features
        #----------------
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        node_feat = inputs["node_feats"]
        # `~` is only a logical NOT on bool tensors; on 0/1 integer masks it is
        # a bitwise NOT and would silently produce integer indices instead.
        mask = inputs["mask"].bool()
        state = inputs.get("state")

        #---------------------------
        # initial linear projection
        #---------------------------
        node_emb = self.init_linear_node(node_feat[:, 1:, :])  # [batch_size x num_loc x emb_dim]
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :self.coord_dim])  # [batch_size x 1 x emb_dim]
        new_node_feat = torch.cat((depot_emb, node_emb), 1)  # [batch_size x num_nodes x emb_dim]
        new_node_feat = self.dropout(new_node_feat)

        #---------------
        # preprocessing
        #---------------
        batch_size = curr_node_id.size(0)
        curr_emb = new_node_feat.gather(1, curr_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        next_emb = new_node_feat.gather(1, next_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state)  # [batch_size x emb_dim]
            input_q = torch.cat((curr_emb, next_emb, state_emb[:, None, :]), -1)  # [batch_size x 1 x (3*emb_dim)]
        else:
            input_q = torch.cat((curr_emb, next_emb), -1)  # [batch_size x 1 x (2*emb_dim)]
        input_kv = torch.cat((curr_emb.expand_as(new_node_feat), new_node_feat), -1)  # [batch_size x num_nodes x (2*emb_dim)]

        #--------------------
        # An attention layer
        #--------------------
        q = torch.matmul(input_q, self.w_q)  # [batch_size x 1 x emb_dim]
        k = torch.matmul(input_kv, self.w_k)  # [batch_size x num_nodes x emb_dim]
        v = torch.matmul(input_kv, self.w_v)  # [batch_size x num_nodes x emb_dim]
        compatibility = self.norm_factor * torch.matmul(q, k.transpose(-2, -1))  # [batch_size x 1 x num_nodes]
        # Out-of-place masking; rows with no attendable node still yield NaN after softmax.
        compatibility = compatibility.masked_fill(~mask.unsqueeze(1), -math.inf)
        attn = torch.softmax(compatibility, dim=-1)
        h = torch.matmul(attn, v)  # [batch_size x 1 x emb_dim]
        h = h.squeeze(1)  # [batch_size x emb_dim]

        return h
models/classifiers/nn_classifiers/decoders/lstm_decoder.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
class LSTMDecoder(nn.Module):
    """Sequence decoder: an LSTM followed by an MLP classification head."""

    def __init__(self, emb_dim, num_mlp_layers, num_classes, dropout):
        """
        Parameters
        ----------
        emb_dim: int
            input and hidden dimension
        num_mlp_layers: int
            number of hidden (emb_dim -> emb_dim) layers before the output layer
        num_classes: int
            number of classes
        dropout: float
            dropout probability applied before every linear layer
        """
        super().__init__()
        self.num_mlp_layers = num_mlp_layers

        # LSTM
        self.lstm = nn.LSTM(emb_dim, emb_dim, batch_first=True)

        # Decoder (MLP): hidden layers first, the output projection last
        hidden_layers = [nn.Linear(emb_dim, emb_dim, bias=True) for _ in range(num_mlp_layers)]
        output_layer = nn.Linear(emb_dim, num_classes, bias=True)
        self.mlp = nn.ModuleList(hidden_layers + [output_layer])

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter uniformly in +-1/sqrt(size of its last dimension)."""
        for weight in self.parameters():
            bound = 1. / math.sqrt(weight.size(-1))
            nn.init.uniform_(weight, -bound, bound)

    def forward(self, graph_emb):
        """
        Parameters
        ----------
        graph_emb: torch.tensor [batch_size x max_seq_length x emb_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x max_seq_length x num_classes]
            log-probabilities of classes (output of log_softmax)
        """
        #---------------
        # LSTM encoding
        #---------------
        hidden, _ = self.lstm(graph_emb)  # [batch_size x max_seq_length x emb_dim]

        #----------
        # Decoding
        #----------
        for layer_id in range(self.num_mlp_layers):
            hidden = torch.relu(self.mlp[layer_id](self.dropout(hidden)))
        logits = self.mlp[-1](self.dropout(hidden))
        return F.log_softmax(logits, dim=-1)
models/classifiers/nn_classifiers/decoders/mha_decoder.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
class PositionalEncoding(nn.Module):
    """Adds fixed sine/cosine positional encodings to a batch-first sequence."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Arguments:
            d_model: embedding dimension (even or odd)
            dropout: dropout probability applied after adding the encoding
            max_len: maximum supported sequence length
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len x ceil(d_model / 2)]
        pe = torch.zeros(1, max_len, d_model)  # batch_first
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns;
        # without the slice the assignment fails with a shape mismatch.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size x max_seq_length x embedding_dim]``
                (max_seq_length must not exceed ``max_len``)

        Returns:
            Tensor of the same shape: ``x`` plus the positional encoding, after dropout.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
25
+
26
class SelfMHADecoder(nn.Module):
    """Sequence decoder: causal self-attention (Transformer encoder) plus a linear classification head."""

    def __init__(self, emb_dim: int, num_heads: int, num_mha_layers: int, num_classes: int, dropout: float, pos_encoder: str = None, max_len: int = 100):
        """
        Parameters
        ----------
        emb_dim: int
            embedding dimension (must be divisible by num_heads)
        num_heads: int
            number of attention heads
        num_mha_layers: int
            number of Transformer encoder layers
        num_classes: int
            number of classes
        dropout: float
            dropout probability
        pos_encoder: str
            "sincos" adds sine/cosine positional encodings; any other value disables them
        max_len: int
            maximum sequence length supported by the positional encoding
        """
        super().__init__()
        self.num_mha_layers = num_mha_layers

        # positional encoding
        self.pos_encoder_type = pos_encoder
        if pos_encoder == "sincos":
            self.pos_encoder = PositionalEncoding(d_model=emb_dim, dropout=dropout, max_len=max_len)

        # MHA blocks
        mha_layer = nn.TransformerEncoderLayer(d_model=emb_dim,
                                               nhead=num_heads,
                                               dim_feedforward=emb_dim,
                                               dropout=dropout,
                                               batch_first=True)
        self.mha = nn.TransformerEncoder(mha_layer, num_layers=num_mha_layers)

        # linear projection for adjusting out_dim to num_classes
        self.out_linear = nn.Linear(emb_dim, num_classes, bias=True)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter (layer-norm parameters included) uniformly in +-1/sqrt(size of its last dimension)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            nn.init.uniform_(param, -stdv, stdv)

    def forward(self, edge_emb):
        """
        Parameters
        ----------
        edge_emb: torch.tensor [batch_size x max_seq_length x emb_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x max_seq_length x num_classes]
            log-probabilities of classes (output of log_softmax)
        """
        #---------------
        # MHA decoding
        #---------------
        if self.pos_encoder_type == "sincos":
            edge_emb = self.pos_encoder(edge_emb)
        # `is_causal=True` is only a hint that the supplied mask is causal; on
        # its own it does not create a mask (depending on the torch version the
        # hint is ignored or rejected). Build the mask explicitly so that step t
        # can only attend to steps <= t.
        seq_length = edge_emb.size(1)
        causal_mask = torch.triu(
            torch.full((seq_length, seq_length), -math.inf, dtype=edge_emb.dtype, device=edge_emb.device),
            diagonal=1,
        )
        h = self.mha(edge_emb, mask=causal_mask)  # [batch_size x max_seq_length x emb_dim]
        logits = self.out_linear(h)
        probs = F.log_softmax(logits, dim=-1)
        return probs
models/classifiers/nn_classifiers/decoders/mlp_decoder.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
class MLPDecoder(nn.Module):
    """Per-step MLP classification head."""

    def __init__(self, emb_dim, num_mlp_layers, num_classes, dropout):
        """
        Parameters
        ----------
        emb_dim: int
            input and hidden dimension
        num_mlp_layers: int
            number of hidden (emb_dim -> emb_dim) layers before the output layer
        num_classes: int
            number of classes
        dropout: float
            dropout probability applied before every linear layer
        """
        super().__init__()
        self.num_mlp_layers = num_mlp_layers

        # Decoder (MLP): hidden layers first, the output projection last
        hidden_layers = [nn.Linear(emb_dim, emb_dim, bias=True) for _ in range(num_mlp_layers)]
        output_layer = nn.Linear(emb_dim, num_classes, bias=True)
        self.mlp = nn.ModuleList(hidden_layers + [output_layer])

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter uniformly in +-1/sqrt(size of its last dimension)."""
        for weight in self.parameters():
            bound = 1. / math.sqrt(weight.size(-1))
            nn.init.uniform_(weight, -bound, bound)

    def forward(self, graph_emb):
        """
        Parameters
        ----------
        graph_emb: torch.tensor [batch_size x emb_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            log-probabilities of classes (output of log_softmax)
        """
        #----------
        # Decoding
        #----------
        hidden = graph_emb
        for layer_id in range(self.num_mlp_layers):
            hidden = torch.relu(self.mlp[layer_id](self.dropout(hidden)))
        logits = self.mlp[-1](self.dropout(hidden))
        return F.log_softmax(logits, dim=-1)
models/classifiers/nn_classifiers/encoders/attn_edge_encoder.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class AttentionEdgeEncoder(nn.Module):
    """Single-head attention that turns node embeddings into one embedding per (current -> next) edge.

    The query is built from the current node, the next node and (optionally)
    the state; keys/values are built from every node embedding concatenated
    with the current-node embedding. The query is added back as a residual.
    """

    def __init__(self, state_dim, emb_dim, dropout):
        """
        Parameters
        ----------
        state_dim: int
            number of state features (0 disables the state embedding)
        emb_dim: int
            embedding dimension
        dropout: float
            dropout probability applied to the incoming node embeddings
        """
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim
        self.norm_factor = 1 / math.sqrt(emb_dim)

        # initial embedding for state
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # An attention layer (values are filled in by reset_parameters)
        self.w_q = nn.Parameter(torch.empty((2 + int(state_dim > 0)) * emb_dim, emb_dim))
        self.w_k = nn.Parameter(torch.empty(2 * emb_dim, emb_dim))
        self.w_v = nn.Parameter(torch.empty(2 * emb_dim, emb_dim))

        # out linear layer
        self.out_linear = nn.Linear(emb_dim, emb_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter uniformly in +-1/sqrt(size of its last dimension)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            nn.init.uniform_(param, -stdv, stdv)

    def forward(self, inputs, node_emb):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size]
            next_node_id: torch.LongTensor [batch_size]
            mask: torch.BoolTensor (0/1 integer tensors are also accepted) [batch_size x num_nodes]
                True marks nodes that may be attended to
            state: torch.FloatTensor [batch_size x state_dim]
                must be provided when state_dim > 0, because w_q is sized for it
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings obtained from the node encoder

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            edge embeddings
        """
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        # `~` is only a logical NOT on bool tensors; on 0/1 integer masks it is
        # a bitwise NOT and would silently produce integer indices instead.
        mask = inputs["mask"].bool()
        state = inputs.get("state")
        batch_size = curr_node_id.size(0)

        #--------------------------------
        # generate queries, keys, values
        #--------------------------------
        node_emb = self.dropout(node_emb)
        curr_emb = node_emb.gather(1, curr_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        next_emb = node_emb.gather(1, next_node_id[:, None, None].expand(batch_size, 1, self.emb_dim))
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state)  # [batch_size x emb_dim]
            input_q = torch.cat((curr_emb, next_emb, state_emb[:, None, :]), -1)  # [batch_size x 1 x (3*emb_dim)]
        else:
            input_q = torch.cat((curr_emb, next_emb), -1)  # [batch_size x 1 x (2*emb_dim)]
        input_kv = torch.cat((curr_emb.expand_as(node_emb), node_emb), -1)  # [batch_size x num_nodes x (2*emb_dim)]

        #--------------------
        # An attention layer
        #--------------------
        q = torch.matmul(input_q, self.w_q)  # [batch_size x 1 x emb_dim]
        k = torch.matmul(input_kv, self.w_k)  # [batch_size x num_nodes x emb_dim]
        v = torch.matmul(input_kv, self.w_v)  # [batch_size x num_nodes x emb_dim]
        compatibility = self.norm_factor * torch.matmul(q, k.transpose(-2, -1))  # [batch_size x 1 x num_nodes]
        # Out-of-place masking; rows with no attendable node still yield NaN after softmax.
        compatibility = compatibility.masked_fill(~mask.unsqueeze(1), -math.inf)
        attn = torch.softmax(compatibility, dim=-1)
        h = torch.matmul(attn, v)  # [batch_size x 1 x emb_dim]
        h = h.squeeze(1)  # [batch_size x emb_dim]
        return self.out_linear(h) + q.squeeze(1)
models/classifiers/nn_classifiers/encoders/concat_edge_encoder.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class ConcatEdgeEncoder(nn.Module):
    """Edge encoder that concatenates current-node, next-node and (optional) state embeddings."""

    def __init__(self, state_dim, emb_dim, dropout):
        """
        Parameters
        ----------
        state_dim: int
            number of state features (0 disables the state embedding)
        emb_dim: int
            embedding dimension
        dropout: float
            dropout probability applied to the incoming node embeddings
        """
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim
        self.norm_factor = 1 / math.sqrt(emb_dim)

        uses_state = state_dim > 0

        # initial embedding for state
        if uses_state:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # out linear layer: projects the concatenated pieces back to emb_dim
        num_pieces = 2 + int(uses_state)
        self.out_linear = nn.Linear(num_pieces * emb_dim, emb_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter uniformly in +-1/sqrt(size of its last dimension)."""
        for weight in self.parameters():
            bound = 1. / math.sqrt(weight.size(-1))
            nn.init.uniform_(weight, -bound, bound)

    def forward(self, inputs, node_emb):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size]
            next_node_id: torch.LongTensor [batch_size]
            state: torch.FloatTensor [batch_size x state_dim]
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings obtained from the node encoder

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            edge embeddings
        """
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        state = inputs["state"]
        gather_shape = (curr_node_id.size(0), 1, self.emb_dim)

        #----------------------------------------
        # pick the embeddings of both end points
        #----------------------------------------
        node_emb = self.dropout(node_emb)
        pieces = [
            node_emb.gather(1, curr_node_id[:, None, None].expand(*gather_shape)),
            node_emb.gather(1, next_node_id[:, None, None].expand(*gather_shape)),
        ]
        if state is not None and self.state_dim > 0:
            pieces.append(self.init_linear_state(state)[:, None, :])  # [batch_size x 1 x emb_dim]

        # [batch_size x 1 x (k*emb_dim)] -> [batch_size x (k*emb_dim)], k = 2 or 3
        edge_emb = torch.cat(pieces, -1).squeeze(1)
        return self.out_linear(edge_emb)
models/classifiers/nn_classifiers/encoders/max_readout.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class MaxReadout(nn.Module):
    """Graph readout: masked max-pooling over node embeddings, optionally combined with the state."""

    def __init__(self, state_dim, emb_dim, dropout):
        """
        Parameters
        ----------
        state_dim: int
            number of state features (0 disables the state embedding)
        emb_dim: int
            embedding dimension
        dropout: float
            dropout probability applied to the incoming node embeddings
        """
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim

        # initial embedding for state
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # out linear layer
        self.out_linear = nn.Linear((1 + int(state_dim > 0))*emb_dim, emb_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter uniformly in +-1/sqrt(size of its last dimension)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            nn.init.uniform_(param, -stdv, stdv)

    def forward(self, inputs, node_emb):
        """
        Parameters
        ----------
        inputs: dict
            mask: torch.BoolTensor (0/1 integer tensors are also accepted) [batch_size x num_nodes]
                True marks nodes that take part in the pooling
            state: torch.FloatTensor [batch_size x state_dim]
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings obtained from the node encoder

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            graph embeddings
        """
        mask = inputs["mask"]
        state = inputs.get("state")
        node_emb = self.dropout(node_emb)

        # pooling with a mask
        # Masked nodes are set to -inf, not 0: multiplying by the mask let the
        # zeros of masked nodes win the max whenever every valid activation in
        # a dimension was negative.
        valid = mask.bool().unsqueeze(-1)  # [batch_size x num_nodes x 1]
        h, _ = torch.max(node_emb.masked_fill(~valid, -math.inf), dim=1)  # [batch_size x emb_dim]
        # Samples without any valid node fall back to zeros (as before) instead of -inf.
        h = h.masked_fill(~valid.any(dim=1), 0.0)

        # out linear layer
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state)  # [batch_size x emb_dim]
            h = torch.cat((h, state_emb), -1)  # [batch_size x (2*emb_dim)]

        return self.out_linear(h)
models/classifiers/nn_classifiers/encoders/mean_readout.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class MeanReadout(nn.Module):
    """Graph readout: masked mean-pooling over node embeddings, optionally combined with the state."""

    def __init__(self, state_dim, emb_dim, dropout):
        """
        Parameters
        ----------
        state_dim: int
            number of state features (0 disables the state embedding)
        emb_dim: int
            embedding dimension
        dropout: float
            dropout probability applied to the incoming node embeddings
        """
        super().__init__()
        self.state_dim = state_dim
        self.emb_dim = emb_dim

        # initial embedding for state
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # out linear layer
        self.out_linear = nn.Linear((1 + int(state_dim > 0))*emb_dim, emb_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter uniformly in +-1/sqrt(size of its last dimension)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            nn.init.uniform_(param, -stdv, stdv)

    def forward(self, inputs, node_emb):
        """
        Parameters
        ----------
        inputs: dict
            mask: torch.BoolTensor (0/1 integer tensors are also accepted) [batch_size x num_nodes]
                True marks nodes that take part in the pooling
            state: torch.FloatTensor [batch_size x state_dim]
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings obtained from the node encoder

        Returns
        -------
        h: torch.tensor [batch_size x emb_dim]
            graph embeddings
        """
        mask = inputs["mask"]
        state = inputs.get("state")
        node_emb = self.dropout(node_emb)

        # pooling with a mask
        # Average over the valid nodes only: dividing the masked sum by the
        # total number of nodes shrank the mean by the fraction of masked nodes.
        valid = mask.unsqueeze(-1).to(node_emb.dtype)  # [batch_size x num_nodes x 1]
        num_valid = valid.sum(dim=1).clamp(min=1.0)  # avoid 0/0 when nothing is valid
        h = (node_emb * valid).sum(dim=1) / num_valid  # [batch_size x emb_dim]

        # out linear layer
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state)  # [batch_size x emb_dim]
            h = torch.cat((h, state_emb), -1)  # [batch_size x (2*emb_dim)]

        return self.out_linear(h)
models/classifiers/nn_classifiers/encoders/mha_node_encoder.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class SelfMHANodeEncoder(nn.Module):
    """Node encoder: linear embeddings of depot and locations followed by self-attention layers."""

    def __init__(self, coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout):
        """
        Parameters
        ----------
        coord_dim: int
            number of leading node features used to embed the depot (node 0)
        node_dim: int
            number of features used to embed every non-depot node
        emb_dim: int
            embedding dimension (must be divisible by num_heads)
        num_heads: int
            number of attention heads
        num_mha_layers: int
            number of Transformer encoder layers
        dropout: float
            dropout probability inside the Transformer layers
        """
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.num_mha_layers = num_mha_layers

        # initial embedding
        self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)

        # MHA Encoder (w/o positional encoding)
        mha_layer = nn.TransformerEncoderLayer(d_model=emb_dim,
                                               nhead=num_heads,
                                               dim_feedforward=emb_dim,
                                               dropout=dropout,
                                               batch_first=True)
        self.mha = nn.TransformerEncoder(mha_layer, num_layers=num_mha_layers)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every parameter (layer-norm parameters included) uniformly in +-1/sqrt(size of its last dimension)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            nn.init.uniform_(param, -stdv, stdv)

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            node_feats: torch.FloatTensor [batch_size x num_nodes x node_dim]
                node 0 is treated as the depot

        Returns
        -------
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings
        """
        #----------------
        # input features
        #----------------
        node_feat = inputs["node_feats"]

        #------------------------------------------------------------------------
        # initial linear projection for adjusting dimensions of locs & the depot
        #------------------------------------------------------------------------
        loc_emb = self.init_linear_nodes(node_feat[:, 1:, :])  # [batch_size x num_loc x emb_dim]
        # slice by coord_dim (was hard-coded to 2) so it always matches init_linear_depot
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :self.coord_dim])  # [batch_size x 1 x emb_dim]
        node_emb = torch.cat((depot_emb, loc_emb), 1)  # [batch_size x num_nodes x emb_dim]

        #--------------
        # MHA encoding
        #--------------
        node_emb = self.mha(node_emb)  # [batch_size x num_nodes x emb_dim]

        return node_emb
models/classifiers/nn_classifiers/encoders/mlp_node_encoder.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class MLPNodeEncoder(nn.Module):
    """Node encoder: linear embeddings of depot and locations followed by a node-wise MLP."""

    def __init__(self, coord_dim, node_dim, emb_dim, num_mlp_layers, dropout):
        """
        Parameters
        ----------
        coord_dim: int
            number of leading node features used to embed the depot (node 0)
        node_dim: int
            number of features used to embed every non-depot node
        emb_dim: int
            embedding dimension
        num_mlp_layers: int
            number of (emb_dim -> emb_dim) layers applied after the initial embedding
        dropout: float
            dropout probability (the module is created but not applied in forward)
        """
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.num_mlp_layers = num_mlp_layers

        # initial embedding
        self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)

        # MLP Encoder
        self.mlp = nn.ModuleList()
        for _ in range(num_mlp_layers):
            self.mlp.append(nn.Linear(emb_dim, emb_dim, bias=True))

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # NOTE(review): reset_parameters() is defined but not called here, so the
        # layers keep PyTorch's default initialisation -- confirm this is intended.

    def reset_parameters(self):
        """Initialise every parameter uniformly in +-1/sqrt(size of its last dimension)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            nn.init.uniform_(param, -stdv, stdv)

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            node_feats: torch.FloatTensor [batch_size x num_nodes x node_dim]
                node 0 is treated as the depot

        Returns
        -------
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings
        """
        #----------------
        # input features
        #----------------
        node_feat = inputs["node_feats"]

        #------------------------------------------------------------------------
        # initial linear projection for adjusting dimensions of locs & the depot
        #------------------------------------------------------------------------
        loc_emb = self.init_linear_nodes(node_feat[:, 1:, :])  # [batch_size x num_loc x emb_dim]
        # slice by coord_dim (was hard-coded to 2) so it always matches init_linear_depot
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :self.coord_dim])  # [batch_size x 1 x emb_dim]
        node_emb = torch.cat((depot_emb, loc_emb), 1)  # [batch_size x num_nodes x emb_dim]

        #--------------
        # MLP encoding
        #--------------
        last_layer = self.num_mlp_layers - 1
        for i in range(self.num_mlp_layers):
            node_emb = self.mlp[i](node_emb)
            # no non-linearity after the final layer
            if i != last_layer:
                node_emb = torch.relu(node_emb)
        return node_emb
models/classifiers/nn_classifiers/nn_classifier.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ # Node encoder
7
+ from models.classifiers.nn_classifiers.encoders.mlp_node_encoder import MLPNodeEncoder
8
+ from models.classifiers.nn_classifiers.encoders.mha_node_encoder import SelfMHANodeEncoder
9
+ # Edge encoder
10
+ from models.classifiers.nn_classifiers.encoders.attn_edge_encoder import AttentionEdgeEncoder
11
+ from models.classifiers.nn_classifiers.encoders.concat_edge_encoder import ConcatEdgeEncoder
12
+ # Decoder
13
+ from models.classifiers.nn_classifiers.decoders.lstm_decoder import LSTMDecoder
14
+ from models.classifiers.nn_classifiers.decoders.mlp_decoder import MLPDecoder
15
+ from models.classifiers.nn_classifiers.decoders.mha_decoder import SelfMHADecoder
16
+
17
+ # data loader
18
+ from utils.data_utils.tsptw_dataset import load_tsptw_sequentially
19
+ from utils.data_utils.pctsp_dataset import load_pctsp_sequentially
20
+ from utils.data_utils.pctsptw_dataset import load_pctsptw_sequentially
21
+ from utils.data_utils.cvrp_dataset import load_cvrp_sequentially
22
+
23
NODE_ENC_LIST = ["mlp", "mha"]
EDGE_ENC_LIST = ["concat", "attn"]
DEC_LIST = ["mlp", "lstm", "mha"]


class NNClassifier(nn.Module):
    """Edge classifier assembled from a node encoder, an edge encoder (readout) and a decoder."""

    def __init__(self,
                 problem: str,
                 node_enc_type: str,
                 edge_enc_type: str,
                 dec_type: str,
                 emb_dim: int,
                 num_enc_mlp_layers: int,
                 num_dec_mlp_layers: int,
                 num_classes: int,
                 dropout: float,
                 pos_encoder: str = "sincos"):
        """
        Parameters
        ----------
        problem: str
            one of "tsptw", "pctsp", "pctsptw", "cvrp"
        node_enc_type: str
            one of NODE_ENC_LIST
        edge_enc_type: str
            one of EDGE_ENC_LIST
        dec_type: str
            one of DEC_LIST ("lstm" and "mha" decode whole sequences)
        emb_dim: int
            embedding dimension
        num_enc_mlp_layers: int
            number of MLP layers of the "mlp" node encoder
        num_dec_mlp_layers: int
            number of MLP layers of the "mlp"/"lstm" decoders
        num_classes: int
            number of classes
        dropout: float
            dropout probability
        pos_encoder: str
            positional encoding type passed to the "mha" decoder

        Raises
        ------
        NotImplementedError
            if problem or edge_enc_type is not supported
        """
        super().__init__()
        self.problem = problem
        self.node_enc_type = node_enc_type
        self.edge_enc_type = edge_enc_type
        self.dec_type = dec_type
        assert node_enc_type in NODE_ENC_LIST, f"Invalid enc_type. select from {NODE_ENC_LIST}"
        assert dec_type in DEC_LIST, f"Invalid dec_type. select from {DEC_LIST}"
        self.is_sequential = dec_type in ["lstm", "mha"]
        coord_dim = 2  # only support 2d problem
        if problem == "tsptw":
            node_dim = 4  # coords (2) + time window (2)
            state_dim = 1  # current time (1)
        elif problem == "pctsp":
            node_dim = 4  # coords (2) + prize (1) + penalty (1)
            state_dim = 2  # current prize (1) + current penalty (1)
        elif problem == "pctsptw":
            node_dim = 6  # coords (2) + prize (1) + penalty (1) + time window (2)
            state_dim = 3  # current prize (1) + current penalty (1) + current time (1)
        elif problem == "cvrp":
            node_dim = 3  # coords (2) + demand (1)
            state_dim = 1  # remaining capacity (1)
        else:
            # The bare `NotImplementedError` expression was a no-op, so an
            # unsupported problem fell through to a NameError on node_dim below.
            raise NotImplementedError

        #----------------
        # Graph encoding
        #----------------
        # Node encoder
        if node_enc_type == "mlp":
            self.node_enc = MLPNodeEncoder(coord_dim, node_dim, emb_dim, num_enc_mlp_layers, dropout)
        elif node_enc_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.node_enc = SelfMHANodeEncoder(coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout)
        else:
            raise NotImplementedError

        # Readout
        if edge_enc_type == "concat":
            self.readout = ConcatEdgeEncoder(state_dim, emb_dim, dropout)
        elif edge_enc_type == "attn":
            self.readout = AttentionEdgeEncoder(state_dim, emb_dim, dropout)
        else:
            raise NotImplementedError

        #------------------------
        # Classification Decoder
        #------------------------
        if dec_type == "mlp":
            self.decoder = MLPDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "lstm":
            self.decoder = LSTMDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.decoder = SelfMHADecoder(emb_dim, num_heads, num_mha_layers, num_classes, dropout, pos_encoder)
        else:
            raise NotImplementedError

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            next_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            node_feats: torch.FloatTensor [batch_size x max_seq_length x num_nodes x node_dim] if self.is_sequential else [batch_size x num_nodes x node_dim]
            mask: torch.BoolTensor [batch_size x max_seq_length x num_nodes] if self.is_sequential else [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x max_seq_length x state_dim] if self.is_sequential else [batch_size x state_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x seq_length x num_classes] if self.is_sequential else [batch_size x num_classes]
            output of the decoder (the decoders in this package return log-probabilities)
        """
        #-----------------
        # Encoding graphs
        #-----------------
        if self.is_sequential:
            # fold the sequence axis into the batch axis so that the encoders see one step per row
            shp = inputs["curr_node_id"].size()
            inputs = {key: value.flatten(0, 1) for key, value in inputs.items()}
        node_emb = self.node_enc(inputs)  # [(batch_size*max_seq_length) x num_nodes x emb_dim] if self.is_sequential else [batch_size x num_nodes x emb_dim]
        graph_emb = self.readout(inputs, node_emb)
        if self.is_sequential:
            graph_emb = graph_emb.view(*shp, -1)  # [batch_size x max_seq_length x emb_dim]

        #----------
        # Decoding
        #----------
        probs = self.decoder(graph_emb)

        return probs

    def get_inputs(self, routes, first_explained_step, node_feats):
        """Build a padded batch of model inputs from routes and node features.

        `node_feats` is shallow-copied, `routes` is stored under the key "tour",
        and the problem-specific sequential loader is applied. `first_explained_step`
        is not used by this implementation.

        Raises
        ------
        NotImplementedError
            if self.problem has no sequential loader
        """
        node_feats_ = node_feats.copy()
        node_feats_["tour"] = routes
        if self.problem == "tsptw":
            seq_data = load_tsptw_sequentially(node_feats_)
        elif self.problem == "pctsp":
            seq_data = load_pctsp_sequentially(node_feats_)
        elif self.problem == "pctsptw":
            seq_data = load_pctsptw_sequentially(node_feats_)
        elif self.problem == "cvrp":
            seq_data = load_cvrp_sequentially(node_feats_)
        else:
            # was a bare expression; seq_data would be undefined below
            raise NotImplementedError

        def pad_seq_length(batch):
            """Post-pad every field to the longest sequence and add a `pad_mask` (True = real step)."""
            data = {}
            for key in batch[0].keys():
                # padded mask entries are True; all other fields are padded with zeros
                padding_value = True if key == "mask" else 0.0
                # post-padding
                data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch], batch_first=True, padding_value=padding_value)
            pad_mask = torch.nn.utils.rnn.pad_sequence([torch.full((d["mask"].size(0), ), True) for d in batch], batch_first=True, padding_value=False)
            data.update({"pad_mask": pad_mask})
            return data
        instance = pad_seq_length(seq_data)
        return instance
models/classifiers/predictor.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+ #------------
8
+ # base class
9
+ #------------
10
class DecisionPredictorBase(nn.Module):
    """
    Attention-based classifier of the decision (edge) taken at one step of a route.

    The pair (current node, next node), optionally with a state vector, forms a
    single query that attends over all nodes; the attended vector is decoded by
    an MLP into log-probabilities over classes.
    """
    def __init__(self, coord_dim, node_dim, state_dim, emb_dim, num_mlp_layers, num_classes, dropout):
        """
        Parameters
        ----------
        coord_dim: int
            number of leading node-feature columns that are coordinates (depot input size)
        node_dim: int
            number of feature columns of non-depot nodes
        state_dim: int
            size of the state vector; no state embedding is created when 0
        emb_dim: int
            embedding size
        num_mlp_layers: int
            number of hidden layers of the MLP decoder
        num_classes: int
            number of output classes
        dropout: float
            dropout probability
        """
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.state_dim = state_dim
        self.num_mlp_layers = num_mlp_layers
        self.norm_factor = 1 / math.sqrt(emb_dim)

        # initial embedding
        self.init_linear_node = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)
        if state_dim > 0:
            self.init_linear_state = nn.Linear(state_dim, emb_dim)

        # An attention layer (the query also carries the state embedding when state_dim > 0)
        self.w_q = nn.Parameter(torch.FloatTensor((2 + int(state_dim > 0)) * emb_dim, emb_dim))
        self.w_k = nn.Parameter(torch.FloatTensor(2 * emb_dim, emb_dim))
        self.w_v = nn.Parameter(torch.FloatTensor(2 * emb_dim, emb_dim))

        # MLP
        self.mlp = nn.ModuleList()
        for _ in range(self.num_mlp_layers):
            self.mlp.append(nn.Linear(emb_dim, emb_dim, bias=True))
        self.mlp.append(nn.Linear(emb_dim, num_classes, bias=True))

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every parameter uniformly in [-1/sqrt(d), 1/sqrt(d)], d being its last dimension."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size] (or [batch_size x 1])
            next_node_id: torch.LongTensor [batch_size] (or [batch_size x 1])
            node_feats: torch.FloatTensor [batch_size x num_nodes x node_dim]
            mask: torch.LongTensor or torch.BoolTensor [batch_size x num_nodes] (feasible -> 1/True, infeasible -> 0/False)
            state: torch.FloatTensor [batch_size x state_dim] (may be missing or None)

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            log-probabilities of classes (output of log_softmax)
        """
        #----------------
        # input features
        #----------------
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        node_feat = inputs["node_feats"]
        mask = inputs["mask"]
        state = inputs.get("state")

        #---------------------------
        # initial linear projection
        #---------------------------
        node_emb = self.init_linear_node(node_feat[:, 1:, :])  # [batch_size x num_loc x emb_dim]
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :self.coord_dim])  # [batch_size x 1 x emb_dim]
        new_node_feat = torch.cat((depot_emb, node_emb), 1)  # [batch_size x num_nodes x emb_dim]
        new_node_feat = self.dropout(new_node_feat)

        #---------------
        # preprocessing
        #---------------
        batch_size = curr_node_id.size(0)
        # reshape (not unsqueeze) so that ids of shape [batch_size] work for batch_size > 1
        curr_index = curr_node_id.reshape(batch_size, 1, 1).expand(batch_size, 1, self.emb_dim)
        next_index = next_node_id.reshape(batch_size, 1, 1).expand(batch_size, 1, self.emb_dim)
        curr_emb = new_node_feat.gather(1, curr_index)  # [batch_size x 1 x emb_dim]
        next_emb = new_node_feat.gather(1, next_index)  # [batch_size x 1 x emb_dim]
        if state is not None and self.state_dim > 0:
            state_emb = self.init_linear_state(state)  # [batch_size x emb_dim]
            input_q = torch.cat((curr_emb, next_emb, state_emb[:, None, :]), -1)  # [batch_size x 1 x (3*emb_dim)]
        else:
            input_q = torch.cat((curr_emb, next_emb), -1)  # [batch_size x 1 x (2*emb_dim)]
        input_kv = torch.cat((curr_emb.expand_as(new_node_feat), new_node_feat), -1)  # [batch_size x num_nodes x (2*emb_dim)]

        #--------------------
        # An attention layer
        #--------------------
        q = torch.matmul(input_q, self.w_q)  # [batch_size x 1 x emb_dim]
        k = torch.matmul(input_kv, self.w_k)  # [batch_size x num_nodes x emb_dim]
        v = torch.matmul(input_kv, self.w_v)  # [batch_size x num_nodes x emb_dim]
        compatibility = self.norm_factor * torch.matmul(q, k.transpose(-2, -1))  # [batch_size x 1 x num_nodes]
        # `~` on an integer mask is a bitwise NOT (1 -> -2, 0 -> -1) and would be
        # used as an index tensor, so convert to bool before inverting.
        infeasible = ~mask.bool()  # [batch_size x num_nodes]
        compatibility = compatibility.masked_fill(infeasible.unsqueeze(1), -math.inf)
        attn = torch.softmax(compatibility, dim=-1)
        h = torch.matmul(attn, v)  # [batch_size x 1 x emb_dim]
        h = h.squeeze(1)  # [batch_size x emb_dim]

        #---------------
        # MLP (decoder)
        #---------------
        for i in range(self.num_mlp_layers):
            h = self.dropout(h)
            h = torch.relu(self.mlp[i](h))
        h = self.dropout(h)
        logits = self.mlp[-1](h)
        probs = F.log_softmax(logits, dim=-1)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """
        For TSPTW
        TODO: refactoring

        NOTE(review): "node_feats" is returned as a dict of tensors, whereas
        forward() slices inputs["node_feats"] as a single tensor -- confirm how
        callers bridge the two.

        Parameters
        ----------
        tour: list or np.array [seq_length]
        first_explained_step: int
        node_feats: dict
            must contain "grid_size", "coords" [num_nodes x coord_dim] and
            "time_window" [num_nodes x 2(start, end)];
            np.array values are converted to float tensors, other values are wrapped as torch.tensor([value])

        Returns
        -------
        out: dict (key: data type [data_size])
            curr_node_id: torch.tensor [num_explained_paths]
            next_node_id: torch.tensor [num_explained_paths]
            node_feats: dict of torch.tensor [num_explained_paths x ...]
            mask: torch.tensor [num_explained_paths x num_nodes]
            state: torch.tensor [num_explained_paths x state_dim]
        """
        node_feats = {
            key: torch.from_numpy(node_feat.astype(np.float32).copy()).clone()
                 if isinstance(node_feat, np.ndarray) else
                 torch.tensor([node_feat])
            for key, node_feat in node_feats.items()
        }
        if isinstance(tour, np.ndarray):
            # np.int64 (not np.long): np.long is missing on NumPy 1.24-1.26 and is
            # a platform-dependent C long elsewhere, while indices must be int64.
            tour = torch.from_numpy(tour.astype(np.int64).copy()).clone()
        else:
            tour = torch.LongTensor(tour)

        out = {"curr_node_id": [], "next_node_id": [], "mask": [], "state": []}
        for step in range(first_explained_step, len(tour) - 1):
            # node ids
            curr_node_id = tour[step]
            next_node_id = tour[step + 1]
            # mask & state
            max_coord = node_feats["grid_size"]
            coord = node_feats["coords"] / max_coord  # [num_nodes x coord_dim]
            time_window = node_feats["time_window"]  # [num_nodes x 2(start, end)]
            time_window = (time_window - time_window[1:].min()) / time_window[1:].max()  # min-max normalization
            curr_time = torch.FloatTensor([0.0])
            raw_coord = node_feats["coords"]
            raw_time_window = node_feats["time_window"]
            raw_curr_time = torch.FloatTensor([0.0])
            num_nodes = len(node_feats["coords"])
            mask = torch.ones(num_nodes, dtype=torch.long)  # feasible -> 1, infeasible -> 0
            # replay the tour from its start up to the current step
            for i in range(step + 1):
                curr_id = tour[i]
                if i > 0:
                    prev_id = tour[i - 1]
                    raw_curr_time += torch.norm(raw_coord[curr_id] - raw_coord[prev_id])
                    curr_time += torch.norm(coord[curr_id] - coord[prev_id])
                # visited?
                mask[curr_id] = 0
                # curr_time exceeds the time window?
                mask[curr_time > time_window[:, 1]] = 0
            curr_time = (raw_curr_time - raw_time_window[1:].min()) / raw_time_window[1:].max()  # min-max normalization
            out["curr_node_id"].append(curr_node_id)
            out["next_node_id"].append(next_node_id)
            out["mask"].append(mask)
            out["state"].append(curr_time)
        out = {key: torch.stack(value, 0) for key, value in out.items()}
        # share the (step-independent) node features across all explained steps
        node_feats = {
            key: node_feat.unsqueeze(0).expand(out["mask"].size(0), *node_feat.size())
            for key, node_feat in node_feats.items()
        }
        out.update({"node_feats": node_feats})
        return out
184
+
185
+ #---------------
186
+ # general class
187
+ #---------------
188
class DecisionPredictor(DecisionPredictorBase):
    """Decision predictor whose feature sizes are derived from the problem name."""
    def __init__(self, problem, emb_dim, num_mlp_layers, num_classes, drop):
        coord_dim = 2
        self.problem = problem
        # problem -> (number of non-coordinate node features, state_dim)
        feature_layout = {
            "tsptw": (2, 1),       # time_window(start, end) | current_time
            "cvrp": (1, 1),        # demand | used_capacity
            "cvrptw": (1 + 2, 2),  # demand + time_window(start, end) | used_capacity + current_time
        }
        if problem not in feature_layout:
            assert False, f"problem {problem} is not supported!"
        num_extra_feats, state_dim = feature_layout[problem]
        node_dim = coord_dim + num_extra_feats
        super().__init__(coord_dim, node_dim, state_dim, emb_dim, num_mlp_layers, num_classes, drop)
models/classifiers/rule_based_models.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ TOUR_LENGTH = 0
5
+ TIME_WINDOW = 1
6
+
7
class kNearestPredictor(nn.Module):
    """
    Rule-based classifier: a visit to one of the k nearest feasible nodes is
    labeled as class 0, any other visit as class 1.
    """
    def __init__(self, problem, k, k_type):
        """
        Parameters
        ---------
        problem: str
            problem type
        k: float
            if the vehicle visits one of the k (k_type="num") or k x 100% (k_type="ratio")
            nearest nodes, this model labels the visit as prioritizing tour length
        k_type: str
            "num" (k is a number of nodes) or "ratio" (k is a ratio of the feasible nodes)
        """
        super().__init__()
        self.problem = problem
        self.num_classes = 2
        self.k_type = k_type
        if k_type == "num":
            self.k = int(k)
        elif k_type == "ratio":
            self.k = k
        else:
            assert False, "Invalid k_type. select from [num, ratio]"

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size]
            next_node_id: torch.LongTensor [batch_size]
            node_feats: torch.FloatTensor [batch_size x num_nodes x node_dim] (first two columns are coordinates)
            mask: torch.tensor [batch_size x num_nodes] (feasible -> positive, infeasible -> 0)

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            one-hot: column 0 if the next node is among the k nearest feasible nodes, column 1 otherwise
        """
        #----------------
        # input features
        #----------------
        curr_node_id = inputs["curr_node_id"]
        next_node_id = inputs["next_node_id"]
        node_feat = inputs["node_feats"]
        mask = inputs["mask"]

        coord_dim = 2
        batch_size = curr_node_id.size(0)
        coords = node_feat[:, :, :coord_dim]  # [batch_size x num_nodes x coord_dim]
        is_candidate = mask > 0  # [batch_size x num_nodes]
        num_candidates = is_candidate.sum(dim=-1)  # [batch_size]
        topk = torch.round(num_candidates * self.k).to(torch.long)  # [batch_size]
        curr_index = curr_node_id.reshape(batch_size, 1, 1).expand(batch_size, 1, coord_dim)
        curr_coord = coords.gather(1, curr_index)  # [batch_size x 1 x coord_dim]
        dist_from_curr_node = torch.norm(curr_coord - coords, dim=-1)  # [batch_size x num_nodes]
        # Infeasible nodes (including the current node itself, at distance 0)
        # must never be counted among the nearest neighbours.
        dist_from_curr_node = dist_from_curr_node.masked_fill(~is_candidate, float("inf"))
        next_ids = next_node_id.reshape(batch_size)
        visit_topk = []
        for i in range(batch_size):
            k = self.k if self.k_type == "num" else topk[i].item()
            # never ask for more neighbours than there are feasible nodes
            k = max(0, min(k, num_candidates[i].item()))
            # largest=False -> the k smallest distances, i.e. the k nearest nodes
            nearest_ids = torch.topk(input=dist_from_curr_node[i], k=k, dim=-1, largest=False)[1]
            visit_topk.append(torch.isin(next_ids[i], nearest_ids))
        visit_topk = torch.stack(visit_topk, 0)  # [batch_size]
        idx = (1 - visit_topk.int()).to(torch.long)  # nearest -> 0, otherwise -> 1
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float, device=idx.device)
        probs.scatter_(-1, idx.unsqueeze(-1), 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """
        For TSPTW
        TODO: refactoring

        Parameters
        ----------
        tour: list [seq_length]
        first_explained_step: int
        node_feats: np.array or torch.tensor [num_nodes x node_dim]
            columns: coordinates (2), then time window (start, end)

        Returns
        -------
        out: dict (key: data type [data_size])
            curr_node_id: torch.tensor [num_explained_paths]
            next_node_id: torch.tensor [num_explained_paths]
            node_feats: torch.tensor [num_explained_paths x num_nodes x node_dim]
            mask: torch.tensor [num_explained_paths x num_nodes]
            state: torch.tensor [num_explained_paths x state_dim]
        """
        # local import: this file does not import numpy at module level
        import numpy as np

        if isinstance(node_feats, np.ndarray):
            node_feats = torch.from_numpy(node_feats.astype(np.float32)).clone()
        tour = torch.LongTensor(tour)
        coord_dim = 2
        out = {"curr_node_id": [], "next_node_id": [], "mask": [], "state": []}
        for step in range(first_explained_step, len(tour) - 1):
            # node ids
            curr_node_id = tour[step]
            next_node_id = tour[step + 1]
            # mask & state
            max_coord = 100
            # `:coord_dim` keeps both coordinate columns; indexing with
            # `coord_dim` alone would pick the time-window start column.
            coord = node_feats[:, :coord_dim] / max_coord  # [num_nodes x coord_dim]
            time_window = node_feats[:, coord_dim:]  # [num_nodes x 2(start, end)]
            time_window = (time_window - time_window[1:].min()) / time_window[1:].max()  # min-max normalization
            curr_time = torch.FloatTensor([0.0])
            raw_coord = node_feats[:, :coord_dim]
            raw_time_window = node_feats[:, coord_dim:]
            raw_curr_time = torch.FloatTensor([0.0])
            mask = torch.ones(node_feats.size(0), dtype=torch.long)  # feasible -> 1, infeasible -> 0
            # replay the tour from its start up to the current step
            for i in range(step + 1):
                curr_id = tour[i]
                if i > 0:
                    prev_id = tour[i - 1]
                    raw_curr_time += torch.norm(raw_coord[curr_id] - raw_coord[prev_id])
                    curr_time += torch.norm(coord[curr_id] - coord[prev_id])
                # visited?
                mask[curr_id] = 0
                # curr_time exceeds the time window?
                mask[curr_time > time_window[:, 1]] = 0
            curr_time = (raw_curr_time - raw_time_window[1:].min()) / raw_time_window[1:].max()  # min-max normalization
            out["curr_node_id"].append(curr_node_id)
            out["next_node_id"].append(next_node_id)
            out["mask"].append(mask)
            out["state"].append(curr_time)
        out = {key: torch.stack(value, 0) for key, value in out.items()}
        node_feats = node_feats.unsqueeze(0).expand(out["mask"].size(0), node_feats.size(-2), node_feats.size(-1))
        out.update({"node_feats": node_feats})
        return out

    def get_topk_ids(self, input, k, dim, largest):
        """
        Per-sample top-k indices, padded so that they can be stacked.

        NOTE(review): the padding below measures size(0) of the per-sample
        result, which equals k only when input[i] is 1-D -- confirm the
        intended input rank against callers.

        Parameters
        ----------
        input: torch.tensor [batch_size x ...]
        k: torch.tensor [batch_size]
        dim: int
        largest: bool

        Returns
        -------
        topk_ids: torch.tensor [batch_size x max(k)]
            rows with fewer than max(k) ids are padded with their first id;
            rows with no id are filled with -1000
        """
        batch_size = input.size(0)
        max_k = int(k.max())
        ids = []
        for i in range(batch_size):
            topk_id = torch.topk(input=input[i], k=k[i].item(), dim=dim, largest=largest)[1]

            # adjust tensor size
            if topk_id.size(0) == 0:
                topk_id = torch.full((max_k, ), -1000)
            elif topk_id.size(0) < max_k:
                # torch.full expects plain numbers for both size and fill value
                padding = torch.full((max_k - topk_id.size(0), ), topk_id[0].item())
                topk_id = torch.cat((topk_id, padding), -1)
            ids.append(topk_id)
        return torch.stack(ids, 0)
models/loss_functions.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from utils.utils import batched_bincount
4
+ import torch.nn.functional as F
5
+
6
+
7
class GeneralCrossEntropy(nn.Module):
    """
    Front-end that selects a cross-entropy variant by name and forwards to it.

    Parameters
    ----------
    weight_type: str
        "seq_cbce" (SeqCBCrossEntropy, sequential only), "cbce" (CBCrossEntropy),
        "wce" (WeightedCrossEntropy) or "ce" (CrossEntropy)
    beta: float
        hyperparameter passed to the class-balanced variants
    is_sequential: bool
        whether predictions carry a sequence dimension

    Raises
    ------
    NotImplementedError
        if weight_type is none of the names above
    """
    def __init__(self, weight_type: str, beta : float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.weight_type = weight_type
        self.beta = beta
        if weight_type == "seq_cbce":
            assert is_sequential == True
            self.loss_func = SeqCBCrossEntropy(beta=beta)
        elif weight_type == "cbce":
            self.loss_func = CBCrossEntropy(beta=beta, is_sequential=is_sequential)
        elif weight_type == "wce":
            self.loss_func = WeightedCrossEntropy(is_sequential=is_sequential)
        elif weight_type == "ce":
            self.loss_func = CrossEntropy(is_sequential=is_sequential)
        else:
            # The original bare expression never raised, which left
            # self.loss_func undefined until the first forward() call.
            raise NotImplementedError(
                f"weight_type {weight_type} is not supported! select from [seq_cbce, cbce, wce, ce]"
            )

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        """Return the loss of the selected variant for (preds, labels, pad_mask)."""
        return self.loss_func(preds, labels, pad_mask)
29
+
30
+
31
class SeqCBCrossEntropy(nn.Module):
    """Class-balanced cross entropy whose class weights are computed per sequence position."""
    def __init__(self, beta : float = 0.99):
        super().__init__()
        self.beta = beta

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor):
        """
        Sequential Class-Balanced Cross Entropy Loss (Our proposal)

        Samples are grouped by their unpadded sequence length; within a group,
        class counts are taken per step and each step's weighted NLL loss is
        added, scaled by the group's share of the batch.

        Parameters
        -----------
        preds: torch.Tensor [batch_size, max_seq_length, num_classes]
        labels: torch.Tensor [batch_size, max_seq_length]
        pad_mask: torch.Tensor [batch_size, max_seq_length]

        Returns
        -------
        loss: torch.Tensor [1]
        """
        lengths = pad_mask.sum(-1)  # [batch_size]
        total_samples = preds.size(0)
        num_classes = preds.size(-1)
        loss = 0
        for length in torch.unique(lengths):  # [num_unique_seq_length]
            in_group = (lengths == length)  # [batch_size]
            group_preds = preds[in_group]
            group_labels = labels[in_group]
            group_share = group_labels.size(0) / total_samples
            class_counts = batched_bincount(group_labels.T, 1, num_classes)  # [seq_length x num_classes]
            step_weights = (1 - self.beta) / (1 - self.beta**class_counts + 1e-8)
            for step in range(length.item()):
                step_loss = F.nll_loss(group_preds[:, step], group_labels[:, step], weight=step_weights[step])
                loss += group_share * step_loss
        return loss
67
+
68
class CBCrossEntropy(nn.Module):
    """
    Class-balanced cross entropy: each class is weighted by
    (1 - beta) / (1 - beta^n), n being the class count in the batch.
    """
    def __init__(self, beta : float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.beta = beta
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        """
        Parameters
        -----------
        preds: torch.Tensor [batch_size, max_seq_length, num_classes] if self.is_sequential else [batch_size, num_classes]
            log-probabilities
        labels: torch.Tensor [batch_size, max_seq_length] if self.is_sequential else [batch_size] or [batch_size, 1]
        pad_mask: torch.Tensor [batch_size, max_seq_length]
            True at real steps; required (and only used) when self.is_sequential

        Returns
        -------
        loss: torch.Tensor (scalar)
        """
        num_classes = preds.size(-1)
        if self.is_sequential:
            valid = pad_mask.reshape(-1).bool()
            flat_preds = preds.reshape(-1, num_classes)[valid]
            flat_labels = labels.reshape(-1)[valid]
        else:
            flat_preds = preds
            flat_labels = labels.reshape(-1)
        # minlength keeps one weight per class even when the highest class ids
        # do not occur in this batch (nll_loss needs exactly num_classes weights)
        class_counts = torch.bincount(flat_labels, minlength=num_classes)
        weight = ((1 - self.beta) / (1 - self.beta**class_counts + 1e-8)).to(flat_preds.dtype)
        return F.nll_loss(flat_preds, flat_labels, weight=weight)
89
+
90
class WeightedCrossEntropy(nn.Module):
    """
    Cross entropy with inverse-frequency class weights:
    weight_c = norm(counts) / count_c, norm being the min or max over the
    classes that occur in the batch.
    """
    def __init__(self, is_sequential: bool = True, norm: str = "min"):
        """
        Parameters
        ----------
        is_sequential: bool
            whether predictions carry a sequence dimension
        norm: str
            "min" or "max"; which class count is used as the numerator

        Raises
        ------
        ValueError
            if norm is neither "min" nor "max"
        """
        super().__init__()
        self.is_sequential = is_sequential
        if norm == "min":
            self.norm = torch.min
        elif norm == "max":
            self.norm = torch.max
        else:
            # previously self.norm stayed undefined and forward() failed later
            raise ValueError(f"Invalid norm {norm}. select from [min, max]")

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        """
        Parameters
        -----------
        preds: torch.Tensor [batch_size, max_seq_length, num_classes] if self.is_sequential else [batch_size, num_classes]
            log-probabilities
        labels: torch.Tensor [batch_size, max_seq_length] if self.is_sequential else [batch_size] or [batch_size, 1]
        pad_mask: torch.Tensor [batch_size, max_seq_length]
            True at real steps; required (and only used) when self.is_sequential

        Returns
        -------
        loss: torch.Tensor (scalar)
        """
        num_classes = preds.size(-1)
        if self.is_sequential:
            valid = pad_mask.reshape(-1).bool()
            flat_preds = preds.reshape(-1, num_classes)[valid]
            flat_labels = labels.reshape(-1)[valid]
        else:
            flat_preds = preds
            flat_labels = labels.reshape(-1)
        # minlength keeps one weight per class even if some classes are absent
        class_counts = torch.bincount(flat_labels, minlength=num_classes)
        # normalize by the observed classes only; a zero count would otherwise
        # make the "min" numerator 0 and turn the loss into 0/0
        numerator = self.norm(class_counts[class_counts > 0])
        weight = (numerator / (class_counts + 1e-8)).to(flat_preds.dtype)
        return F.nll_loss(flat_preds, flat_labels, weight=weight)
113
+
114
class CrossEntropy(nn.Module):
    """Unweighted NLL loss for sequential (padded) or plain predictions."""
    def __init__(self, is_sequential: bool = True):
        super().__init__()
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        """
        Parameters
        -----------
        preds: torch.Tensor
            log-probabilities; the last dimension indexes classes
        labels: torch.Tensor
            class ids (a trailing singleton dimension is squeezed in the non-sequential case)
        pad_mask: torch.Tensor
            selects the steps that contribute to the loss; only used when self.is_sequential

        Returns
        -------
        loss: torch.Tensor (scalar)
        """
        if not self.is_sequential:
            return F.nll_loss(preds, labels.squeeze(-1))

        # flatten batch and sequence dimensions, then drop the padded steps
        keep = pad_mask.view(-1)
        flat_preds = preds.view(-1, preds.size(-1))
        flat_labels = labels.view(-1)
        return F.nll_loss(flat_preds[keep], flat_labels[keep])
models/prompts/generate_explanation.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import PromptTemplate
2
+ from langchain.schema import StrOutputParser
3
+ from langchain_core.runnables.base import Runnable
4
+ from models.prompts.template_json_base import TemplateJsonBase
5
+
6
+ GENERATE_EXPLANATION = """\
7
+ You are RouteExplainer, an explanation system for justifying a specific edge (i.e., actual edge) in a route automatically generated by a VRP solver.
8
+ Here, you address a scenario where a tourist (user) wonders why the actual edge was selected at the step in the tourist route and why another edge was not selected instead at that step.
9
+ As an expert tour guide, you will justify why the actual edge was selected in the route if it outperforms another edge. That helps to convince the tourist of the actual edge or to make the tourist's decision to change to another edge from the actual edge while accepting some disadvantages.
10
+ Please carefully read the contents below and follow the instructions faithfully.
11
+
12
+ [Terminology]
13
+ The following terms are used here.
14
+ - Node: A destination.
15
+ - Edge: A directed edge representing the movement from one node to another.
16
+ - Edge intention: The underlying purpose of the edge. An edge intention here is either “prioritizing route length (route_len)” or “prioritizing time windows (time_window)”.
17
+ - Step: The visited order in a route.
18
+ - Actual edge: A user-specified edge in the optimal route generated by a VRP solver. You will justify this edge in this task.
19
+ - Counterfactual (CF) edge: A user-specified edge that was not selected at the step of the actual edge in the optimal route but could have been. This is a different edge from the actual edge. The user wonders why the CF edge was not selected at the step instead of the actual edge.
20
+ - Actual route: The optimal route generated by a VRP solver.
21
+ - CF route: An alternative route where the CF edge is selected at the step instead of the actual edge. The subsequent edges to the CF edge in the CF route are the best-effort ones.
22
+
23
+ [Example]
24
+ Please refer to the following input-output example when generating a counterfactual explanation.
25
+ ***** START EXAMPLE *****
26
+ [input]
27
+ Question:
28
+ - The question asks about replacing the edge from node2 to node3 with the edge from node2 to node5.
29
+ Actual route:
30
+ - route: node1 > node2 > (actual edge) > node3 > node4 > node5 > node6 > node7 > node1
31
+ - short-term effect (immediate travel time): 20 minutes
32
+ - long-term effect (total travel time): 100 minutes
33
+ - missed nodes: none
34
+ - edge-intention ratio after the actual edge: time_window 75%, route_len 25%
35
+ CF route:
36
+ - route: node1 > node2 > (CF edge) > node5 > node6 > node7 > node1
37
+ - short-term effect (immediate travel time): 10 minutes
38
+ - long-term effect (total travel time): 77.8 minutes
39
+ - missed nodes: node3, node4
40
+ - edge-intention ratio after the CF edge: time_window 100%, route_len 0%
41
+ Difference between two routes:
42
+ - short-term effect: The actual route increases it by 10 minutes
43
+ - long-term effect: The actual route increases it by 22.2 minutes
44
+ - missed nodes: The actual route visits 2 more nodes
45
+ - difference of edge-intention ratio after the actual and CF edges: time_window -25%, route_len +25%
46
+ Planed destination information:
47
+ - node1: start/end point
48
+ - node2: none
49
+ - node3: take lunch
50
+ - node4: attend a tour
51
+ - node5: most favorite destination
52
+ - node6: take dinner
53
+ - node7: none
54
+
55
+ [Explanation]
56
+ Here are the pros and cons of the actual and CF edges.
57
+ #### Actual edge:
58
+ - Pros:
59
+ - It allows you to visit all your destinations within time windows. That is essential for maximizing your tour experience.
60
+ - Cons:
61
+ - Immediate travel time will increase by 10 minutes.
62
+ - The total travel time will increase by 22.2 minutes, but it is natural because the actual route visits two more nodes than the CF route.
63
+ - Remarks:
64
+ - The route balances both prioritizing travel time and time windows.
65
+ #### CF edge:
66
+ - Pros:
67
+ - Immediate travel time will decrease by 10 minutes.
68
+ - The total travel time will decrease by 22.2 minutes. However, note that this reduction in time is the result of not visiting two nodes.
69
+ - Cons:
70
+ - You will miss node3 and node4. You plan to take lunch and attend a tour, so the loss could significantly degrade your tour experience.
71
+ - Remarks:
72
+ - You will miss node3 and node4 even if you are constantly pressed for time windows in the subsequent movement
73
+ #### Summary:
74
+ - Given the pros and cons and the fact that adhering to time constraints is essential, the actual edge is objectively more optimal.
75
+ - However, you might prefer the CF edge, despite its cons, depending on your preferences.
76
+ ***** END EXAMPLE *****
77
+
78
+ [Instruction]
79
+ Now, please generate a counterfactual explanation for the [input] below.
80
+ You MUST keep the following rules:
81
+ - Summarize the pros and cons, including short-term effects, long-term effects, missed nodes, and edge-intention ratio.
82
+ - Enrich explanations by leveraging destination information.
83
+ - Carefully consider causality regarding travel time reduction. If the number of missed nodes is equal, one edge may reduce travel time. However, if a route with missed nodes is quicker, it is due to skipping nodes.
84
+ - A high route_len ratio emphasizes speed over schedule adherence, while a high time_window ratio prioritizes sticking to a schedule, sacrificing travel efficiency for timely arrivals.
85
+ - Disucuss edge-intention ratio in "Remarks". Do NOT do it in "Pros" or "Cons".
86
+ - Travel time efficiency is solely determined by the total travel time.
87
+ - Never say that all planed destinations are visited if there is even one missed node. If some nodes are missed, you must specify which node are missed.
88
+ - If the CF edge outperforms the actual edge, you do NOT have to force a justification for the actual edge.
89
+ - Please associate user's intention that "{intent}" with your summary. If the intention is blank, it means no intention was provided.
90
+
91
+ [input]
92
+ {comparison_results}
93
+
94
+ [Explanation]
95
+ """
96
+
97
+ # - Routes are assessed based on the following priorities: fewer missed nodes are better > shorter total travel time is better.
98
+ # - In "Summary", Clearly and specifically explain the differences between the actual and CF edges to help the tourist convince the actual edge or make a decision to change to another edge from the actual edge while accepting some cons.
99
+ # - Ensure consistency in comparisons: the pros of the actual edge should be the cons of the CF edge and vice versa.
100
+
101
class Template4GenerateExplanation(TemplateJsonBase):
    """Prompt template that turns comparison results and a user intent into a free-text explanation."""
    # the LLM output is parsed as a plain string
    parser: Runnable = StrOutputParser()
    template: str = GENERATE_EXPLANATION
    # the template is filled from the "comparison_results" and "intent" inputs
    prompt: Runnable = PromptTemplate(
        template=template,
        input_variables=["comparison_results", "intent"],
    )

    def _get_output_key(self) -> str:
        """Return the output key of this template (an empty string)."""
        return ""
models/prompts/identify_question.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import PromptTemplate
2
+ from langchain_core.output_parsers import JsonOutputParser
3
+ from langchain_core.pydantic_v1 import BaseModel, Field
4
+ from langchain_core.runnables.base import Runnable
5
+ from models.prompts.template_json_base import TemplateJsonBase
6
+
7
# Few-shot prompt for extracting cf_step / cf_visit from a why-not question.
# Literal braces are doubled because the text is used as a PromptTemplate.
# The example outputs must stay valid JSON (straight quotes, no trailing
# commas), because the model is asked to answer in exactly that format.
IDENTIFY_QUESTION = """\
Given a route and a question that asks what would happen if we replaced a specific edge in the route with another edge, please extract the step number of the replaced edge (i.e., cf_step) and the node id of the destination of the new edge (i.e., cf_visit) from the question, which is written in natural language.
Please use the following examples as a reference when you answer:
***** START EXAMPLE *****
[route info]
Nodes(node id, name): (1, node1), (2, node2), (3, node3), (4, node4), (5, node5)
Route: node1 > (step1) > node5 > (step2) > node3 > (step3) > node2 > (step4) > node4 > (step5) > node1
[question]
Why node3, and why not node2?
[outputs]
```json
{{
    "success": true,
    "summary": "The question asks about replacing the edge from node5 to node3 with the edge from node5 to node2.",
    "intent": "",
    "process": "The edge from node5 to node3 is at step2 because of 'node5 > (step2) > node3'. The node id of the destination of the new edge is 2 (node2). Thus, the final answers are cf_step=2 and cf_visit=2.",
    "cf_step": 2,
    "cf_visit": 2
}}
```

[route info]
Nodes(node id, name): (1, node1), (2, node2), (3, node3), (4, node4), (5, node5)
Route: node1 > (step1) > node5 > (step2) > node3 > (step3) > node2 > (step4) > node4 > (step5) > node1
[question]
What if we visited node4 instead of node2? We would personally like to visit node4 first.
[outputs]
```json
{{
    "success": true,
    "summary": "The question asks about replacing the edge from node3 to node2 with the edge from node3 to node4.",
    "intent": "The user would personally like to visit node4 first",
    "process": "The edge from node3 to node2 is at step3 because of 'node3 > (step3) > node2'. The node id of the destination of the new edge is 4 (node4). Thus, the final answers are cf_step=3 and cf_visit=4.",
    "cf_step": 3,
    "cf_visit": 4
}}
```
***** END EXAMPLE *****

Given the following route and question, please extract the step number of the replaced edge (i.e., cf_step) and the node id of the destination of the new edge (i.e., cf_visit) from the question.
Please keep the following rules:
- Do not output any sentences outside of JSON format.
- {format_instructions}

[route_info]
{route_info}
[question]
{whynot_question}
[outputs]
"""
57
+
58
# Schema of the JSON answer requested by the IDENTIFY_QUESTION prompt.
# Documented with comments on purpose: a class docstring could end up in the
# schema text generated from this model -- TODO confirm before adding one.
class WhyNotQuestion(BaseModel):
    # each Field description below is the instruction for that key
    success: bool = Field(description="Whether cf_step and cf_visit are successfully extracted (True) or not (False).")
    summary: str = Field(description="Your summary for the given question. If success=False, instead state here what information is missing to extract cf_step/cf_visit and what additional information should be clarified (Additionally, provide an example).")
    intent: str = Field(description="Your summary for user's intent (if provided). If not provided, this is set to ''.")
    process: str = Field(description="The thought (reasoning) process in extracting cf_step and cf_visit. if success=False, this is set to ''.")
    cf_step: int = Field(description="The step number of the replaced edge. if success=False, this is set to -1.")
    cf_visit: int = Field(description="The node id of the destination of the new edge. if success=False, this is set to -1.")
65
+
66
class Template4IdentifyQuestion(TemplateJsonBase):
    """Prompt template that extracts a WhyNotQuestion (cf_step, cf_visit, ...) from a natural-language question."""
    # the LLM output is parsed as JSON described by the WhyNotQuestion model
    parser: Runnable = JsonOutputParser(pydantic_object=WhyNotQuestion)
    template: str = IDENTIFY_QUESTION
    # "format_instructions" is pre-filled from the parser when the class is
    # defined; "whynot_question" and "route_info" are supplied per call
    prompt: Runnable = PromptTemplate(
        template=template,
        input_variables=["whynot_question", "route_info"],
        partial_variables={"format_instructions": parser.get_format_instructions()}
    )

    def _get_output_key(self) -> str:
        """Return the output key of this template (an empty string)."""
        return ""
+ return ""
models/prompts/template_json_base.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Any
3
+ from langchain_core.runnables import RunnableLambda
4
+ from langchain_core.runnables.base import Runnable
5
+
6
class TemplateJsonBase(ABC):
    """
    Base class of prompt templates.

    Subclasses provide the three attributes below and the output key used by
    extract_value().
    """
    parser: Runnable
    template: str
    prompt: Runnable

    @abstractmethod
    def _get_output_key(self) -> str:
        """Return the key that extract_value() reads from the parsed output."""
        raise NotImplementedError

    def get_template(self) -> str:
        """Return the raw template text."""
        return self.template

    def extract_value(self, input: Dict[str, Any]) -> Any:
        """Return the entry of `input` stored under the template's output key."""
        output_key = self._get_output_key()
        return input[output_key]

    def sandwiches(self,
                   llm: Runnable,
                   extract_value: bool = False) -> Runnable:
        """
        Put `llm` between this template's prompt and parser.

        Parameters
        ----------
        llm: Runnable
            the model placed between prompt and parser
        extract_value: bool
            if True, extract_value() is appended as a final step of the chain

        Returns
        -------
        chain: Runnable
            prompt | llm | parser (| extract_value)
        """
        chain = self.prompt | llm | self.parser
        if not extract_value:
            return chain
        return chain | RunnableLambda(self.extract_value)
models/route_explainer.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # templates
2
+ import numpy as np
3
+ import streamlit as st
4
+ from typing import Dict, List
5
+ from models.prompts.identify_question import Template4IdentifyQuestion
6
+ from models.prompts.generate_explanation import Template4GenerateExplanation
7
+ from langchain.callbacks.base import BaseCallbackHandler
8
+ from langchain.schema import AIMessage
9
+ import utils.util_app as util_app
10
+
11
class StreamingChatCallbackHandler(BaseCallbackHandler):
    """LangChain callback that streams LLM tokens into a Streamlit placeholder."""

    def __init__(self):
        super().__init__()
        # Created lazily so that every callback works even if on_llm_start
        # was never delivered to this handler.
        self.container = None
        self.text = ""

    def _get_container(self):
        """Return the Streamlit placeholder, creating it on first use."""
        if self.container is None:
            self.container = st.empty()
        return self.container

    def on_llm_start(self, *args, **kwargs):
        """Start a fresh placeholder and reset the accumulated text."""
        self.container = st.empty()
        self.text = ""

    def on_llm_new_token(self, token: str, *args, **kwargs):
        """Append ``token`` and re-render the text accumulated so far."""
        self.text += token
        self._get_container().markdown(
            body=self.text,
            unsafe_allow_html=False,
        )

    def on_llm_end(self, response, *args, **kwargs):
        """Render the final text of the first generation.

        ``response`` is the LLM result object (it must expose
        ``generations[0][0].text``), not a plain string.
        """
        self._get_container().markdown(
            body=response.generations[0][0].text,
            unsafe_allow_html=False,
        )
31
+
32
class RouteExplainer():
    """Generate a contrastive explanation for a user's why-not question.

    Pipeline (see ``generate_explanation``): extract the counterfactual (CF)
    edge from the question with an LLM, validate it, generate a CF route,
    classify the intention of each edge, and ask the LLM to explain the
    difference between the actual and the CF route.
    """
    template_identify_question = Template4IdentifyQuestion()
    template_generate_explanation = Template4GenerateExplanation()

    def __init__(self,
                 llm,
                 cf_generator,
                 classifier) -> None:
        assert cf_generator.problem == classifier.problem, "Problem type of cf_generator and predictor should coincide!"
        self.coord_dim = 2
        self.problem = cf_generator.problem
        self.cf_generator = cf_generator
        self.classifier = classifier
        self.actual_route = None
        self.cf_route = None
        # templates
        self.question_extractor = self.template_identify_question.sandwiches(llm)
        self.explanation_generator = self.template_generate_explanation.sandwiches(llm)

    #----------------
    # whole pipeline
    #----------------
    def generate_explanation(self,
                             tour_list,
                             whynot_question: str,
                             actual_routes: list,
                             actual_labels: list,
                             node_feats: dict,
                             dist_matrix: np.array) -> str:
        """Run the whole pipeline and return the explanation text.

        Returns "" when the question could not be parsed, and the
        infeasibility reason when the requested CF edge is rejected.
        Intermediate results are stored in ``st.session_state``.
        """
        #--------------------------------
        # define why & why-not questions
        #--------------------------------
        route_info_text = self.get_route_info_text(tour_list, actual_routes)
        inputs = self.question_extractor.invoke({
            "whynot_question": whynot_question,
            "route_info": route_info_text
        })
        # The same text is streamed and stored (the stored copy used to lack the space).
        question_message = inputs["summary"] + " " + inputs["intent"]
        util_app.stream_words(question_message)
        st.session_state.chat_history.append(AIMessage(content=question_message))
        if not inputs["success"]:
            return ""

        #----------------------
        # validate the CF edge
        #----------------------
        # Node ids shown to the LLM are 1-based, hence the "-1".
        is_cf_edge_feasible, reason = self.validate_cf_edge(node_feats,
                                                            dist_matrix,
                                                            actual_routes[0],
                                                            inputs["cf_step"],
                                                            inputs["cf_visit"]-1)
        # exception
        if not is_cf_edge_feasible:
            util_app.stream_words(reason)
            return reason

        #---------------------
        # generate a cf route
        #---------------------
        cf_routes = self.cf_generator(actual_routes,
                                      vehicle_id=0,
                                      cf_step=inputs["cf_step"],
                                      cf_next_node_id=inputs["cf_visit"]-1,
                                      node_feats=node_feats,
                                      dist_matrix=dist_matrix)
        st.session_state.generated_cf_route = True
        st.session_state.close_chat = True
        st.session_state.cf_step = inputs["cf_step"]

        #--------------------------------------
        # classify the intentions of each edge
        #--------------------------------------
        cf_labels = self.classifier(self.classifier.get_inputs(cf_routes,
                                                               0,
                                                               node_feats,
                                                               dist_matrix))
        st.session_state.cf_routes = cf_routes
        st.session_state.cf_labels = cf_labels

        #-------------------------------------
        # generate a contrastive explanation
        #-------------------------------------
        comparison_results = self.get_comparison_results(question_summary=inputs["summary"],
                                                         tour_list=tour_list,
                                                         actual_routes=actual_routes,
                                                         actual_labels=actual_labels,
                                                         cf_routes=cf_routes,
                                                         cf_labels=cf_labels,
                                                         cf_step=inputs["cf_step"])

        explanation = self.explanation_generator.invoke({
            "comparison_results": comparison_results,
            "intent": inputs["intent"]
        }, config={"callbacks": [StreamingChatCallbackHandler()]})

        return explanation

    #-----------------------
    # for extracting inputs
    #-----------------------
    def get_route_info_text(self, tour_list, routes) -> str:
        """Describe the nodes (1-based ids) and the first route as plain text."""
        route_info = ""
        # nodes
        route_info += "Nodes(node id, name): "
        for i, destination in enumerate(tour_list):
            if i != len(tour_list) - 1:
                route_info += f"({i+1}, {destination['name']}), "
            else:
                route_info += f"({i+1}, {destination['name']})\n"

        # routes
        route_info += "Route: "
        for i, node_id in enumerate(routes[0]):
            if i == 0:
                route_info += f"{tour_list[node_id]['name']} "
            else:
                # A stray ")" used to follow the node name here.
                route_info += f"> (step {i}) > {tour_list[node_id]['name']}"
                if i == len(routes[0]) - 1:
                    route_info += "\n"
                else:
                    route_info += " "
        return route_info

    #--------------------------
    # for validating a CF edge
    #--------------------------
    def validate_cf_edge(self,
                         node_feats: Dict[str, np.array],
                         dist_matrix: np.array,
                         route: List[int],
                         cf_step: int,
                         cf_visit: int) -> tuple:
        """Check that the CF edge exists and meets the destination's close time.

        Parameters
        ----------
        route: visiting order of the actual route
        cf_step: step of the replaced edge; the edge starts at ``route[cf_step-1]``
        cf_visit: 0-based node id of the new destination

        Returns
        -------
        (is_feasible: bool, reason: str)
        """
        # cf_step / cf_visit are extracted by an LLM: reject values that would
        # otherwise raise IndexError or silently wrap around (negative index).
        if not 1 <= cf_step <= len(route) - 1:
            return False, f"Oops, your CF edge is invalid because step {cf_step} does not exist in the actual route."
        if not 0 <= cf_visit < len(dist_matrix):
            return False, f"Oops, your CF edge is invalid because node {cf_visit + 1} does not exist."

        # calc current time
        curr_time = node_feats["time_window"][route[0]][0] # start point's open time
        for step in range(1, cf_step):
            curr_node_id = route[step-1]
            next_node_id = route[step]
            curr_time += node_feats["service_time"][curr_node_id] + dist_matrix[curr_node_id][next_node_id]
            curr_time = max(curr_time, node_feats["time_window"][next_node_id][0]) # waiting

        # validate the cf edge
        curr_node_id = route[cf_step-1]
        next_node_id = cf_visit
        next_node_close_time = node_feats["time_window"][next_node_id][1]
        arrival_time = curr_time + node_feats["service_time"][curr_node_id] + dist_matrix[curr_node_id][next_node_id]
        if next_node_close_time < arrival_time:
            exceed_time = (arrival_time - next_node_close_time)
            return False, f"Oops, your CF edge is infeasible because it does not meet the destination's close time by {util_app.add_time_unit(exceed_time)}."
        else:
            return True, "The CF edge is feasible!"

    #-------------------------------
    # for generating an explanation
    #-------------------------------
    def get_comparison_results(self,
                               tour_list,
                               question_summary,
                               actual_routes: List[List[int]],
                               actual_labels: List[List[int]],
                               cf_routes: List[List[int]],
                               cf_labels: List[List[int]],
                               cf_step: int) -> str:
        """Assemble the text block comparing the actual and the CF route."""
        comparison_results = "Question:\n" + question_summary + "\n"
        comparison_results += "Actual route:\n" + \
            self.get_route_info(tour_list, actual_routes[0], actual_labels[0], cf_step-1, "actual") + \
            self.get_representative_values(actual_routes[0], actual_labels[0], cf_step-1, "actual")
        comparison_results += "CF route:\n" + \
            self.get_route_info(tour_list, cf_routes[0], cf_labels[0], cf_step-1, "CF") + \
            self.get_representative_values(cf_routes[0], cf_labels[0], cf_step-1, "CF")
        comparison_results += "Difference between two routes:\n" + self.get_diff(cf_step-1, actual_routes[0], cf_routes[0])
        comparison_results += "Planed desination information:\n" + self.get_node_info()
        return comparison_results

    def get_route_info(self,
                       tour_list,
                       route: List[int],
                       label: List[int],
                       ex_step: int,
                       type: str) -> str:
        """Render ``route`` as "A > (label) > B ..." and mark the explained edge.

        ``ex_step`` is the 0-based index of the explained edge; ``type`` is
        "actual" or anything else (treated as the CF route).
        """
        def get_labelname(label_number):
            return "route_len" if label_number == 0 else "time_window"
        route_info = "- route: "
        for i, node_id in enumerate(route):
            if i == ex_step and i != len(route) - 1:
                if type == "actual":
                    # Was wrapped in braces, i.e. a set rendered as "{'route_len'}".
                    edge_label = get_labelname(label[i])
                else:
                    edge_label = "user_preference"
                route_info += f"{tour_list[node_id]['name']} > ({type} edge: {edge_label}) > "
            elif i != len(route) - 1:
                route_info += f"{tour_list[node_id]['name']} > ({get_labelname(label[i])}) > "
            else:
                route_info += f"{tour_list[node_id]['name']}\n"
        return route_info

    def get_representative_values(self, route, labels, ex_step, type) -> str:
        """Summarise travel times, missed nodes and edge-intention ratios."""
        time_window_ratio = self.get_intention_ratio(1, labels, ex_step) * 100
        route_len_ratio = self.get_intention_ratio(0, labels, ex_step) * 100
        # The trailing newline keeps the next section header on its own line.
        return (
            f"- short-term effect (immediate travel time): {self.get_immediate_state(route, ex_step)//60} minutes\n"
            f"- long-term effect (total travel time): {self.get_route_length(route)//60} minutes\n"
            f"- missed nodes: {self.get_infeasible_node_name(route)}\n"
            f"- edge-intention ratio after the {type} edge: time_window {time_window_ratio: .1f}%, route_len {route_len_ratio: .1f}%\n"
        )

    def get_immediate_state(self, route, ex_step) -> float:
        """Return the dist_matrix entry of the edge leaving ``route[ex_step]``."""
        return st.session_state.dist_matrix[route[ex_step]][route[ex_step+1]]

    def get_route_length(self, route) -> float:
        """Sum the dist_matrix entries along consecutive nodes of ``route``."""
        return sum(
            (st.session_state.dist_matrix[src][dst] for src, dst in zip(route[:-1], route[1:])),
            0.0,
        )

    def get_infeasible_nodes(self, route) -> int:
        # NOTE(review): this value grows with the number of visited nodes, yet
        # get_diff interprets it as a missed-node count -- confirm the sign.
        return len(route) - (len(st.session_state.dist_matrix) - 1)

    def get_infeasible_node_name(self, route) -> str:
        """Return the comma-joined names of the nodes absent from ``route``."""
        if len(route) == len(st.session_state.dist_matrix) - 1:
            return "none"
        else:
            visited = set(route)
            return ",".join([st.session_state.tour_list[node_id]["name"]
                             for node_id in range(len(st.session_state.dist_matrix))
                             if node_id not in visited])

    def get_intention_ratio(self,
                            intention: int,
                            labels: List[int],
                            ex_step: int) -> float:
        """Return the fraction of ``labels[ex_step:]`` equal to ``intention``.

        Returns 0.0 when there is no label from ``ex_step`` on.
        """
        remaining = np.array(labels)[ex_step:]
        if len(remaining) == 0:
            return 0.0
        return np.sum(remaining == intention) / len(remaining)

    def get_diff(self, ex_step, actual_route, cf_route) -> str:
        """Describe the actual-minus-CF differences in travel time and node count."""
        def get_str(effect: float):
            long_effect_str = "The actual route increases it by" if effect > 0 else "The actual route reduces it by"
            long_effect_str += util_app.add_time_unit(abs(effect))
            return long_effect_str

        def get_str2(num_nodes: int, num_missed_nodes):
            if num_nodes < 0:
                num_nodes_str = f"The actual route visits {abs(num_nodes)} more nodes"
            elif num_nodes == 0:
                if num_missed_nodes == 0:
                    num_nodes_str = f"Both routes missed no node,"
                else:
                    num_nodes_str = f"Both routes missed the same number of nodes ({abs(num_missed_nodes)} node(s))"
            else:
                num_nodes_str = f"The actual route visits {abs(num_nodes)} less nodes"
            return num_nodes_str

        # short/long-term effects
        short_effect = self.get_immediate_state(actual_route, ex_step) - self.get_immediate_state(cf_route, ex_step)
        long_effect = self.get_route_length(actual_route) - self.get_route_length(cf_route)
        short_effect_str = get_str(short_effect)
        long_effect_str = get_str(long_effect)

        # missed nodes
        missed_nodes = self.get_infeasible_nodes(actual_route) - self.get_infeasible_nodes(cf_route)
        missed_nodes_str = get_str2(missed_nodes, self.get_infeasible_nodes(actual_route))

        return f"- short-term effect: {short_effect_str}\n - long-term effect: {long_effect_str}\n- missed nodes: {missed_nodes_str}\n"

    def get_node_info(self) -> str:
        """List "destination: remarks" for every row of ``st.session_state.df_tour``."""
        node_info = ""
        for i in range(len(st.session_state.df_tour)):
            node_info += f"- {st.session_state.df_tour['destination'][i]}: {st.session_state.df_tour['remarks'][i]}\n"
        return node_info
models/solvers/concorde/concorde.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import scipy
3
+ import numpy as np
4
+ import os
5
+ import datetime
6
+ import subprocess
7
+ import models.solvers.concorde.concorde_utils as concorde_utils
8
+ import glob
9
+ import random
10
+
11
class ConcordeTSP(nn.Module):
    """Thin wrapper that solves TSP instances with the Concorde executable.

    An instance is written in TSPLIB format under ``io_dir``, the
    ``concorde`` binary found in ``solver_dir`` is run on it, and the
    resulting tour file is parsed.
    """

    def __init__(self, large_value=1e+6, scaling=False, random_seed=1234, solver_dir="models/solvers/concorde/src/TSP", io_dir="concorde_io_files"):
        # nn.Module must be initialised, otherwise module-level operations on
        # this object (or on a parent that holds it) fail.
        super().__init__()
        self.random_seed = random_seed
        self.large_value = large_value
        self.scaling = scaling
        self.solver_dir = solver_dir
        self.io_dir = io_dir
        self.redirector_stdout = concorde_utils.Redirector(fd=concorde_utils.STDOUT)
        self.redirector_stderr = concorde_utils.Redirector(fd=concorde_utils.STDERR)
        os.makedirs(io_dir, exist_ok=True)

    def get_instance_name(self):
        """Return a name built from the pid, a random value and the current time."""
        now = datetime.datetime.now()
        random_value = random.random() # for avoiding duplicated file name
        instance_name = f"{os.getpid()}_{random_value}_{now.strftime('%Y%m%d_%H%M%S%f')}"
        return instance_name

    def write_instance(self, node_feats, fixed_paths=None, instance_name=None):
        """Write the TSPLIB instance file; return (instance path, tour path)."""
        if instance_name is None:
            instance_name = self.get_instance_name()
        instance_fname = f"{self.io_dir}/{instance_name}.tsp"
        tour_fname = f"{self.io_dir}/{instance_name}.sol"
        with open(instance_fname, "w") as f:
            f.write(f"NAME : {instance_name}\n")
            f.write(f"TYPE : TSP\n")
            f.write(f"DIMENSION : {len(node_feats['coords'])}\n")
            self.write_data(f, node_feats, fixed_paths)
            f.write("EOF\n")
        return instance_fname, tour_fname

    def write_data(self, f, node_feats, fixed_paths=None):
        """Write the data section of the instance to the open file ``f``.

        Without ``fixed_paths`` the node coordinates are written (EUC_2D).
        With ``fixed_paths`` an explicit integer distance matrix is written in
        which edges between consecutive fixed nodes cost 0 and every other
        edge touching an interior fixed node costs 1e+8.
        """
        coords = node_feats["coords"]
        if fixed_paths is None:
            f.write("EDGE_WEIGHT_TYPE : EUC_2D\n")
            f.write("NODE_COORD_SECTION\n")
            for i in range(len(coords)):
                # NOTE(review): the 10-character cut would corrupt a value
                # printed in exponent notation -- confirm inputs never are.
                f.write(f" {i + 1} {str(coords[i][0])[:10]} {str(coords[i][1])[:10]}\n")
        else:
            # `import scipy` alone does not guarantee that scipy.spatial is loaded.
            from scipy.spatial.distance import cdist
            f.write("EDGE_WEIGHT_TYPE : EXPLICIT\n")
            f.write("EDGE_WEIGHT_FORMAT : FULL_MATRIX\n")
            f.write("EDGE_WEIGHT_SECTION\n")
            dist = cdist(coords, coords).round().astype(np.int64)
            for i in range(len(fixed_paths)):
                curr_id = fixed_paths[i]
                if i != 0 and i != len(fixed_paths) - 1:
                    # NOTE: concorde TSP seems to use int32, so 1e+9 overflows.
                    # 1e+8 could also overflow when N (tour length) is large.
                    dist[curr_id, :] = 1e+8; dist[:, curr_id] = 1e+8
                if i != 0:
                    prev_id = fixed_paths[i - 1]
                    dist[prev_id, curr_id] = 0; dist[curr_id, prev_id] = 0
                if i != len(fixed_paths) - 1:
                    next_id = fixed_paths[i + 1]
                    dist[curr_id, next_id] = 0; dist[next_id, curr_id] = 0
            f.write("\n".join([
                " ".join(map(str, row))
                for row in dist
            ]))
            # Terminate the last row; otherwise the caller's "EOF" is glued to it.
            f.write("\n")

    def solve(self, node_feats, fixed_paths=None, instance_name=None):
        """Solve one instance and return the tour (``None`` if no tour file was produced)."""
        if self.scaling:
            node_feats = self.preprocess_data(node_feats)
        self.redirector_stdout.start()
        self.redirector_stderr.start()
        try:
            instance_fname, tour_fname = self.write_instance(node_feats, fixed_paths, instance_name)
            # Argument list instead of a shell string: file names are never
            # interpreted by a shell.
            subprocess.run([f"{self.solver_dir}/concorde", "-o", tour_fname, "-x", instance_fname]) # run Concorde
        finally:
            # Always restore stdout/stderr, even if something above failed.
            self.redirector_stderr.stop()
            self.redirector_stdout.stop()
        tours = self.read_tour(tour_fname)
        # remove the instance / tour files (best effort)
        for fname in (instance_fname, tour_fname):
            try:
                os.remove(fname)
            except OSError:
                pass
        # NOTE(review): this removes every *.sol / *.res file in the current
        # working directory, not only the ones of this run.
        fname_list = glob.glob("*.sol")
        fname_list.extend(glob.glob("*.res"))
        for fname in fname_list:
            try:
                os.remove(fname)
            except OSError:
                # do nothing
                pass
        return tours

    def read_tour(self, tour_fname):
        """
        Parameters
        ----------
        tour_fname: str
            path to an output tour

        Returns
        -------
        tour: 2d list [num_vehicles(1) x seq_length]
            ``None`` if the file does not exist or contains no node.
        """
        if not os.path.exists(tour_fname): # fails to solve the instance
            return

        tour = []
        with open(tour_fname, "r") as f:
            for i, line in enumerate(f):
                if i == 0: # the first line is skipped
                    continue
                tour.extend(line.split())
        if not tour:
            return
        tour.append(tour[0]) # close the tour
        return [list(map(int, tour))]

    def preprocess_data(self, node_feats):
        """Return a copy of ``node_feats`` whose coords are scaled by ``large_value`` and cast to int64."""
        # convert float to integer approximately
        return {
            key: (node_feat * self.large_value).astype(np.int64)
                if key == "coords" else
                node_feat
            for key, node_feat in node_feats.items()
        }
models/solvers/concorde/concorde_utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import tempfile
4
+
5
# Utilities for silencing the std out of the concorde solver
# https://github.com/machine-reasoning-ufrgs/TSP-GNN/blob/master/redirector.py
STDOUT = 1
STDERR = 2

class Redirector(object):
    """Temporarily redirect a file descriptor into a temporary file.

    ``start()`` points ``fd`` at a fresh temporary file; ``stop()`` restores
    the original target and returns the bytes captured in between.
    """

    def __init__(self, fd=STDOUT):
        self.fd = fd
        self.started = False

    def start(self):
        """Begin redirecting ``fd`` (no-op if already started)."""
        if self.started:
            return
        self.tmpfd, self.tmpfn = tempfile.mkstemp()

        # Remember the current target, then let `fd` refer to the temp file.
        self.oldhandle = os.dup(self.fd)
        os.dup2(self.tmpfd, self.fd)
        os.close(self.tmpfd)

        self.started = True

    def flush(self):
        """Flush the Python-level stream that corresponds to ``fd``, if any."""
        if self.fd == STDOUT:
            sys.stdout.flush()
        elif self.fd == STDERR:
            sys.stderr.flush()

    def stop(self):
        """Restore ``fd``; return the captured bytes, or ``None`` if not started."""
        if not self.started:
            return None
        self.flush()
        os.dup2(self.oldhandle, self.fd)
        os.close(self.oldhandle)
        # tmpfd itself was already closed in start(); this handle is only
        # used to read back what was captured.
        with open(self.tmpfn, 'rb') as captured:
            output = captured.read()
        os.unlink(self.tmpfn)

        self.started = False
        return output
45
+
46
class RedirectorOneFile(object):
    """Redirect a file descriptor into one temporary file reused across start/stop cycles.

    The temporary file is created once (``initialize``), every ``stop()``
    returns the bytes written since the previous read, and ``close()``
    releases the file.
    """

    def __init__(self, fd=STDOUT):
        self.fd = fd
        self.started = False
        self.inited = False

        self.initialize()

    def initialize(self):
        """Create the temporary file and its read handle (no-op if already done)."""
        if not self.inited:
            self.tmpfd, self.tmpfn = tempfile.mkstemp()
            self.pos = 0
            self.tmpr = open(self.tmpfn, 'rb')
            self.inited = True

    def start(self):
        """Begin redirecting ``fd`` into the temporary file (no-op if already started)."""
        if not self.started:
            # Re-create the temporary file if close() was called before.
            self.initialize()
            self.oldhandle = os.dup(self.fd)
            os.dup2(self.tmpfd, self.fd)
            self.started = True

    def flush(self):
        """Flush the Python-level stream that corresponds to ``fd``, if any."""
        if self.fd == STDOUT:
            sys.stdout.flush()
        elif self.fd == STDERR:
            sys.stderr.flush()

    def stop(self):
        """Restore ``fd``; return the newly captured bytes, or ``None`` if not started."""
        if self.started:
            self.flush()
            os.dup2(self.oldhandle, self.fd)
            os.close(self.oldhandle)
            output = self.tmpr.read()
            self.pos = self.tmpr.tell()
            self.started = False
            return output
        else:
            return None

    def close(self):
        """Release the temporary file.

        Returns the bytes not yet returned by ``stop()``, or ``None`` if the
        object is not initialised. An active redirection is stopped first.
        """
        if not self.inited:
            return None
        if self.started:
            # Restore fd first so that it does not keep pointing at a deleted file.
            output = self.stop()
        else:
            self.flush()
            output = self.tmpr.read()
        self.tmpr.close()
        # tmpr and tmpfd are independent descriptors: close both.
        os.close(self.tmpfd)
        os.unlink(self.tmpfn)
        self.inited = False
        return output
models/solvers/general_solver.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from models.solvers.ortools.ortools import ORTools
3
+ from models.solvers.lkh.lkh import LKH
4
+ from models.solvers.concorde.concorde import ConcordeTSP
5
+
6
class GeneralSolver(nn.Module):
    """Facade that selects one solver backend (OR-Tools, LKH or Concorde) for a problem."""

    # Problems supported by each solver backend.
    _SUPPORTED_PROBLEMS = {
        "ortools": ["tsp", "tsptw", "pctsp", "pctsptw", "cvrp", "cvrptw"],
        "lkh": ["tsp", "tsptw", "cvrp", "cvrptw"],
        "concorde": ["tsp"]
    }

    def __init__(self, problem, solver_type, large_value=1e+6, scaling=True):
        """
        Parameters
        ----------
        problem: str
            problem name (must be supported by ``solver_type``)
        solver_type: str
            "ortools", "lkh" or "concorde"
        large_value, scaling:
            forwarded to the backend constructor

        Raises
        ------
        AssertionError
            if the solver type is unknown or does not support the problem
        """
        super().__init__()
        self._validate(problem, solver_type)
        self.problem = problem
        self.large_value = large_value
        self.scaling = scaling
        self.solver_type = solver_type
        self.solver = self.get_solver(problem, solver_type)

    def _validate(self, problem, solver_type):
        """Raise AssertionError unless ``solver_type`` exists and supports ``problem``."""
        # Raised explicitly (not via `assert`) so the check survives `python -O`
        # while callers still see the same exception type as before.
        supported_problem = self._SUPPORTED_PROBLEMS
        if solver_type not in supported_problem.keys():
            raise AssertionError(f"Invalid solver type: {solver_type}. Please select from {supported_problem.keys()}")
        if problem not in supported_problem[solver_type]:
            raise AssertionError(f"{solver_type} does not support {problem}.")

    def change_solver(self, problem, solver_type):
        """Switch to another problem / backend; no-op if nothing changes.

        The new combination is validated and the new solver is built before
        any attribute is updated, so a failure leaves this object unchanged.
        """
        if self.solver_type == solver_type and self.problem == problem:
            return
        self._validate(problem, solver_type)
        solver = self.get_solver(problem, solver_type)
        self.problem = problem
        self.solver_type = solver_type
        self.solver = solver

    def get_solver(self, problem, solver_type):
        """Instantiate the backend solver for ``problem``."""
        if solver_type == "ortools":
            return ORTools(problem, self.large_value, self.scaling)
        elif solver_type == "lkh":
            return LKH(problem, self.large_value, self.scaling)
        elif solver_type == "concorde":
            if problem != "tsp":
                raise AssertionError("Concorde solver supports only TSP.")
            return ConcordeTSP(self.large_value, self.scaling)
        else:
            raise AssertionError(f"Invalid solver type: {solver_type}")

    def solve(self, node_feats, fixed_paths=None, dist_matrix=None, instance_name=None):
        """Solve one instance; ``dist_matrix`` is forwarded to the OR-Tools backend only."""
        if isinstance(self.solver, ORTools):
            return self.solver.solve(node_feats, fixed_paths, dist_matrix, instance_name)
        else:
            return self.solver.solve(node_feats, fixed_paths, instance_name)
models/solvers/lkh/lkh.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from models.solvers.lkh.lkh_tsp import LKHTSP
3
+ from models.solvers.lkh.lkh_tsptw import LKHTSPTW
4
+ from models.solvers.lkh.lkh_cvrp import LKHCVRP
5
+ from models.solvers.lkh.lkh_cvrptw import LKHCVRPTW
6
+
7
class LKH(nn.Module):
    """Dispatch to the LKH wrapper that matches the problem type."""

    def __init__(self, problem, large_value=1e+6, scaling=False, max_trials=10, seed=1234, lkh_dir="models/solvers/lkh/src", io_dir="lkh_io_files"):
        """
        Parameters
        ----------
        problem: str
            one of "tsp", "tsptw", "cvrp", "cvrptw"
        large_value, scaling, max_trials, seed, lkh_dir, io_dir:
            forwarded unchanged to the problem-specific wrapper

        Raises
        ------
        ValueError
            if ``problem`` is not supported (previously ``self.lkh`` was
            silently left undefined and ``solve`` failed later)
        """
        super().__init__()
        self.problem = problem
        self.probelm = problem # misspelled name kept for backward compatibility
        if problem == "tsp":
            self.lkh = LKHTSP(large_value, scaling, max_trials, seed, lkh_dir, io_dir)
        elif problem == "tsptw":
            self.lkh = LKHTSPTW(large_value, scaling, max_trials, seed, lkh_dir, io_dir)
        elif problem == "cvrp":
            self.lkh = LKHCVRP(large_value, scaling, max_trials, seed, lkh_dir, io_dir)
        elif problem == "cvrptw":
            self.lkh = LKHCVRPTW(large_value, scaling, max_trials, seed, lkh_dir, io_dir)
        else:
            raise ValueError(f"LKH does not support {problem}.")

    def solve(self, node_feats, fixed_paths=None, instance_name=None):
        """Delegate to the problem-specific wrapper."""
        return self.lkh.solve(node_feats, fixed_paths, instance_name)
models/solvers/lkh/lkh_base.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import datetime
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from subprocess import check_call
6
+
7
NODE_ID_OFFSET = 1 # node ids start from 1 in TSPLIB files

class LKHBase(nn.Module):
    """Common file I/O and process handling for the LKH wrappers.

    Subclasses implement ``write_data`` to emit the problem-specific part of
    the instance file.
    """

    def __init__(self, problem, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh", io_dir="lkh_io_files"):
        super().__init__()
        self.coord_dim = 2
        self.problem = problem
        self.large_value = large_value
        self.scaling = scaling
        self.max_trials = max_trials
        self.seed = seed

        # I/O file settings
        self.lkh_dir = lkh_dir
        self.io_dir = io_dir
        self.instance_path = f"{io_dir}/{self.problem}/instance"
        self.param_path = f"{io_dir}/{self.problem}/param"
        self.tour_path = f"{io_dir}/{self.problem}/tour"
        self.log_path = f"{io_dir}/{self.problem}/log"
        os.makedirs(self.instance_path, exist_ok=True)
        os.makedirs(self.param_path, exist_ok=True)
        os.makedirs(self.tour_path, exist_ok=True)
        os.makedirs(self.log_path, exist_ok=True)

    def solve(self, node_feats, fixed_paths=None, instance_name=None):
        """Write the instance and parameter files, run LKH and return the tours.

        Returns ``None`` when LKH produced no tour file. ``check_call`` raises
        ``subprocess.CalledProcessError`` if LKH exits with a non-zero status.
        """
        # Resolve the name once so that all files of this run share it.
        if instance_name is None:
            instance_name = self.get_instance_name()
        instance_fname = self.write_instance(node_feats, fixed_paths, instance_name)
        param_fname, tour_fname, log_fname = self.write_para(instance_fname, instance_name)
        with open(log_fname, "w") as f:
            check_call([f"{self.lkh_dir}/LKH", param_fname], stdout=f) # run LKH
        tours = self.read_tour(node_feats, tour_fname)
        # clean intermediate files (best effort, each one independently)
        for fname in (instance_fname, param_fname, tour_fname, log_fname):
            try:
                os.remove(fname)
            except OSError:
                pass
        return tours

    def get_instance_name(self):
        """Return a name built from the pid and the current time."""
        now = datetime.datetime.now()
        instance_name = f"{os.getpid()}-{now.strftime('%Y%m%d_%H%M%S%f')}"
        return instance_name

    def write_instance(self, node_feats, fixed_paths=None, instance_name=None):
        """Write the instance file and return its path.

        ``fixed_paths`` (list or array of 0-based node ids) is written as a
        FIXED_EDGES_SECTION when it holds at least two nodes; the argument
        itself is not modified.
        """
        if instance_name is None:
            instance_name = self.get_instance_name()
        instance_fname = f"{self.instance_path}/{instance_name}.{self.problem}"
        with open(instance_fname, "w") as f:
            f.write(f"NAME : {instance_name}\n")
            f.write(f"TYPE : {self.problem.upper()}\n")
            f.write(f"DIMENSION : {len(node_feats['coords'])}\n")
            self.write_data(node_feats, f)
            if fixed_paths is not None and len(fixed_paths) > 1:
                # FIXED_EDGE_SECTION works well in TSP, but it cannot fix edges in TSPTW
                # EDGE_DATA_SECTION can fix edges in both TSP and TSPTW, but the obtained tour is very poor
                f.write("FIXED_EDGES_SECTION\n")
                # np.asarray(...) + offset builds a new array, so plain lists
                # are accepted too and the caller's object is left untouched.
                fixed_ids = np.asarray(fixed_paths) + NODE_ID_OFFSET # offset node id (node id starts from 1 in TSPLIB)
                for i in range(len(fixed_ids) - 1):
                    f.write(f"{fixed_ids[i]} {fixed_ids[i+1]}\n")
            f.write("EOF\n")
        return instance_fname

    def write_data(self, node_feats, f):
        """Write the problem-specific sections to the open file ``f`` (subclass hook)."""
        raise NotImplementedError

    def get_avail_edges(self, node_feats, fixed_paths):
        """Return the fixed edges followed by every pair of nodes not strictly inside ``fixed_paths``."""
        num_nodes = len(node_feats["coords"])
        avail_edges = []
        # add fixed edges
        for i in range(len(fixed_paths) - 1):
            avail_edges.append([fixed_paths[i], fixed_paths[i + 1]])

        # add the rest of the available edges
        visited = np.array([0] * num_nodes)
        for fixed_id in fixed_paths:
            visited[fixed_id] = 1
        # both ends of the fixed path can still be connected to other nodes
        visited[fixed_paths[0]] = 0
        visited[fixed_paths[-1]] = 0
        mask = visited < 1
        node_id = np.arange(num_nodes)
        feasible_node_id = node_id[mask]
        for j in range(len(feasible_node_id) - 1):
            for i in range(j + 1, len(feasible_node_id)):
                avail_edges.append([feasible_node_id[j], feasible_node_id[i]])
        return np.array(avail_edges)

    def write_para(self, instance_fname, instance_name=None):
        """Write the LKH parameter file; return (param path, tour path, log path)."""
        if instance_name is None:
            instance_name = self.get_instance_name()
        param_fname = f"{self.param_path}/{instance_name}.param"
        tour_fname = f"{self.tour_path}/{instance_name}.tour"
        log_fname = f"{self.log_path}/{instance_name}.log"
        with open(param_fname, "w") as f:
            f.write(f"PROBLEM_FILE = {instance_fname}\n")
            f.write(f"MAX_TRIALS = {self.max_trials}\n")
            f.write("MOVE_TYPE = 5\nPATCHING_C = 3\nPATCHING_A = 2\nRUNS = 1\n")
            f.write(f"SEED = {self.seed}\n")
            f.write(f"OUTPUT_TOUR_FILE = {tour_fname}\n")
        return param_fname, tour_fname, log_fname

    def read_tour(self, node_feats, tour_fname):
        """
        Parameters
        ----------
        node_feats: dict
            only ``len(node_feats["coords"])`` (the number of nodes) is used
        tour_fname: str
            path to a file where the tour is written
        Returns
        -------
        tour: 2d list [num_vehicles x seq_length]
            a set of 0-based node ids indicating visit order;
            ``None`` if the file does not exist, ``[]`` if it holds no tour
        """
        if not os.path.exists(tour_fname):
            return # found no feasible solution

        with open(tour_fname, "r") as f:
            tour = []
            is_tour_section = False
            for line in f:
                line = line.strip()
                if line == "TOUR_SECTION":
                    is_tour_section = True
                    continue
                if is_tour_section:
                    if line != "-1":
                        tour.append(int(line) - NODE_ID_OFFSET)
                    else:
                        if tour: # an empty tour section has nothing to close
                            tour.append(tour[0])
                        break
        # convert 1d -> 2d list
        num_nodes = len(node_feats["coords"])
        tour = np.array(tour)
        # NOTE: node_id >= num_nodes indicates the depot node.
        # That is because LKH uses dummy nodes of which locations are the same as the depot and demands = -capacity?
        # I'm not sure where the behavior is documented, but the author of NeuroLKH reads output files like that.
        # please refer to https://github.com/liangxinedu/NeuroLKH/blob/main/CVRPTWdata_generate.py#L132
        tour[tour >= num_nodes] = 0
        # drop every element equal to its predecessor (collapses repeated depot visits)
        # NOTE(review): because of the prepended 1, a tour whose first node
        # id is 1 would lose its first element -- confirm tours start at 0.
        tour = tour[np.diff(np.concatenate(([1], tour))).nonzero()]
        # split at the depot (node 0): one sub-tour per vehicle
        loc0 = (tour == 0).nonzero()[0]
        num_vehicles = len(loc0) - 1
        tours = []
        for vehicle_id in range(num_vehicles):
            vehicle_tour = tour[loc0[vehicle_id]:loc0[vehicle_id+1]+1].tolist()
            tours.append(vehicle_tour)
        return tours
models/solvers/lkh/lkh_cvrp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ from models.solvers.lkh.lkh_base import LKHBase
4
+
5
class LKHCVRP(LKHBase):
    """LKH wrapper for the capacitated vehicle routing problem (CVRP)."""

    def __init__(self, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh/", io_dir="lkh_io_files"):
        super().__init__("cvrp", large_value, scaling, max_trials, seed, lkh_dir, io_dir)

    def write_data(self, node_feats, f):
        """
        Write the CVRP sections of the instance to the open file ``f``.

        Parameters
        ----------
        node_feats: dict of np.array
            coords: np.array [num_nodes x coord_dim]
            demand: np.array [num_nodes x 1]
            capacity: np.array [1]
        f: writable text file
        """
        coords = node_feats["coords"]
        demand = node_feats["demand"]
        capacity = node_feats["capacity"][0]
        num_nodes = len(coords)
        if self.scaling:
            coords = coords * self.large_value
        # NOTE: In CVRP, LKH can automatically obtain optimal vehicle size.
        #       However it cannot in CVRPTW (please check lkh_cvrptw.py).
        # EDGE_WEIGHT_SECTION
        f.write("EDGE_WEIGHT_TYPE : EUC_2D\n")
        # CAPACITY
        f.write("CAPACITY : " + str(capacity) + "\n")
        # NODE_COORD_SECTION (ids are 1-based, each value is cut to 10 characters)
        f.write("NODE_COORD_SECTION\n")
        for node_no in range(1, num_nodes + 1):
            x_str = str(coords[node_no - 1][0])[:10]
            y_str = str(coords[node_no - 1][1])[:10]
            f.write(f" {node_no} {x_str} {y_str}\n")
        # DEMAND_SECTION
        f.write("DEMAND_SECTION\n")
        for node_no in range(1, num_nodes + 1):
            f.write(f" {node_no} {str(demand[node_no - 1])}\n")
        # DEPOT SECTION
        f.write("DEPOT_SECTION\n")
        f.write("1\n")
models/solvers/lkh/lkh_cvrptw.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ from models.solvers.lkh.lkh_base import LKHBase
4
+
5
class LKHCVRPTW(LKHBase):
    """LKH wrapper for the capacitated vehicle routing problem with time windows (CVRPTW)."""

    def __init__(self, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh/", io_dir="lkh_io_files"):
        problem = "cvrptw"
        super().__init__(problem, large_value, scaling, max_trials, seed, lkh_dir, io_dir)

    def write_data(self, node_feats, f):
        """
        Write the CVRPTW sections of the instance to the open file ``f``.

        Parameters
        ----------
        node_feats: dict of np.array
            coords: np.array [num_nodes x coord_dim]
            demand: np.array [num_nodes x 1]
            capacity: np.array [1]
            time_window: np.array [num_nodes x 2(start, end)]
        f: writable text file
        """
        coords = node_feats["coords"]
        demand = node_feats["demand"]
        capacity = node_feats["capacity"][0]
        time_window = node_feats["time_window"]
        num_nodes = len(coords)
        if self.scaling:
            coords = coords * self.large_value
            time_window = time_window * self.large_value
        # VEHICLES
        # As the number of unused vehicles is also included in the penalty by default,
        # we have to modify Penalty_CVRTW.c in the LKH SRC directory.
        # Comment out the following part, which corresponds to the penalty of unused vehicles:
        # 42 if (MTSPMinSize >= 1 && Size < MTSPMinSize)
        # 43     P += MTSPMinSize - Size;
        # 44 if (Size > MTSPMaxSize)
        # 45     P += Size - MTSPMaxSize;
        # After the modification, we can automatically obtain the optimal vehicle size
        # by setting a large vehicle size (e.g. >20) here
        f.write("VEHICLES : 20\n")
        # CAPACITY
        f.write("CAPACITY : " + str(capacity) + "\n")
        # EDGE_WEIGHT_SECTION
        f.write("EDGE_WEIGHT_TYPE : EUC_2D\n")
        # NODE_COORD_SECTION
        f.write("NODE_COORD_SECTION\n")
        for i in range(num_nodes):
            f.write(f" {i + 1} {str(coords[i][0])[:10]} {str(coords[i][1])[:10]}\n")
        # DEMAND_SECTION
        f.write("DEMAND_SECTION\n")
        for i in range(num_nodes):
            f.write(f" {i + 1} {str(demand[i])}\n")
        # TIME_WINDOW_SECTION
        f.write("TIME_WINDOW_SECTION\n")
        f.write("\n".join([
            "{}\t{}\t{}".format(i + 1, l, u)
            for i, (l, u) in enumerate(time_window)
        ]))
        # Terminate the last time-window row; without this newline the
        # "DEPOT_SECTION" keyword is glued to it.
        f.write("\n")
        # DEPOT SECTION
        f.write("DEPOT_SECTION\n")
        f.write("1\n")
models/solvers/lkh/lkh_tsp.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.solvers.lkh.lkh_base import LKHBase
2
+
3
class LKHTSP(LKHBase):
    """LKHBase subclass configured for the "tsp" problem."""

    def __init__(self, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh/", io_dir="lkh_io_files"):
        # All settings are forwarded unchanged; only the problem name is fixed here.
        super().__init__("tsp", large_value, scaling, max_trials, seed, lkh_dir, io_dir)

    def write_data(self, node_feats, f):
        """
        Write the TSP-specific sections of an LKH problem file to ``f``.

        Parameters
        ----------
        node_feats: dict
            only the "coords" entry is read; each item is indexed as (x, y)
        f: text stream
            destination of the problem data (only ``f.write`` is used)
        """
        points = node_feats["coords"]
        if self.scaling:
            points = points * self.large_value
        f.write("EDGE_WEIGHT_TYPE : EUC_2D\n")
        f.write("NODE_COORD_SECTION\n")
        # Node ids are 1-based; each coordinate is cut to at most 10 characters.
        for node_id, point in enumerate(points, start=1):
            x_text = str(point[0])[:10]
            y_text = str(point[1])[:10]
            f.write(f" {node_id} {x_text} {y_text}\n")
models/solvers/lkh/lkh_tsptw.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ from models.solvers.lkh.lkh_base import LKHBase
4
+
5
class LKHTSPTW(LKHBase):
    """LKHBase subclass configured for the "tsptw" problem."""

    def __init__(self, large_value=1e+6, scaling=False, max_trials=1000, seed=1234, lkh_dir="models/solvers/lkh/", io_dir="lkh_io_files"):
        problem = "tsptw"
        super().__init__(problem, large_value, scaling, max_trials, seed, lkh_dir, io_dir)

    def write_data(self, node_feats, f):
        """
        Write the TSPTW-specific sections of an LKH problem file to ``f``.

        The pairwise Euclidean distances between nodes are rounded to
        integers and written as an explicit full matrix, followed by the
        time windows (cast to int64) and the depot (node 1).

        Parameters
        ----------
        node_feats: dict of np.array
            "coords" and "time_window" are read; each time-window row is
            unpacked as (start, end)
        f: text stream
            destination of the problem data (only ``f.write`` is used)
        """
        # Import the submodule explicitly: a bare `import scipy` does not
        # guarantee that `scipy.spatial.distance` is loaded on every SciPy
        # version.
        from scipy.spatial.distance import cdist

        coords = node_feats["coords"]
        if self.scaling:
            coords = coords * self.large_value
        # NOTE(review): time windows are not multiplied by large_value when
        # scaling is enabled, although the distances are - confirm that the
        # caller provides them in the scaled unit.
        time_window = node_feats["time_window"].astype(np.int64)
        dist = cdist(coords, coords).round().astype(np.int64)
        f.write("EDGE_WEIGHT_TYPE : EXPLICIT\n")
        f.write("EDGE_WEIGHT_FORMAT : FULL_MATRIX\n")
        f.write("EDGE_WEIGHT_SECTION\n")
        f.write("\n".join([
            " ".join(map(str, row))
            for row in dist
        ]))
        f.write("\n")
        f.write("TIME_WINDOW_SECTION\n")
        # Node ids are 1-based, columns are tab-separated.
        f.write("\n".join([
            "{}\t{}\t{}".format(i + 1, l, u)
            for i, (l, u) in enumerate(time_window)
        ]))
        f.write("\n")
        f.write("DEPOT_SECTION\n")
        f.write("1\n")
models/solvers/ortools/ortools.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from models.solvers.ortools.ortools_tsp import ORToolsTSP
3
+ from models.solvers.ortools.ortools_tsptw import ORToolsTSPTW
4
+ from models.solvers.ortools.ortools_pctsp import ORToolsPCTSP
5
+ from models.solvers.ortools.ortools_pctsptw import ORToolsPCTSPTW
6
+ from models.solvers.ortools.ortools_cvrp import ORToolsCVRP
7
+ from models.solvers.ortools.ortools_cvrptw import ORToolsCVRPTW
8
+
9
class ORTools(nn.Module):
    """Front end that forwards ``solve`` calls to the OR-Tools solver chosen by problem name."""

    def __init__(self, problem, large_value=1e+6, scaling=False):
        super().__init__()
        self.coord_dim = 2
        self.problem = problem
        self.large_value = large_value
        self.scaling = scaling
        # Must come after large_value/scaling are set: get_ortools reads them.
        self.ortools = self.get_ortools(problem)

    def get_ortools(self, problem):
        """
        Build the solver object that matches the problem name.

        Parameters
        ----------
        problem: str
            one of "tsp", "tsptw", "pctsp", "pctsptw", "cvrp", "cvrptw"

        Returns
        -------
        ortools: solver instance for the specified problem

        Raises
        ------
        NotImplementedError
            if the problem name matches none of the supported names
        """
        supported = (
            ("tsp", ORToolsTSP),
            ("tsptw", ORToolsTSPTW),
            ("pctsp", ORToolsPCTSP),
            ("pctsptw", ORToolsPCTSPTW),
            ("cvrp", ORToolsCVRP),
            ("cvrptw", ORToolsCVRPTW),
        )
        for name, solver_cls in supported:
            if problem == name:
                return solver_cls(self.large_value, self.scaling)
        raise NotImplementedError

    def solve(self, node_feats, fixed_paths=None, dist_martix=None, instance_name=None):
        """
        Forward the arguments unchanged to the selected solver's ``solve``.

        Parameters
        ----------
        node_feats: node features of the instance
        fixed_paths: np.array [cf_step]
        dist_martix: passed through as-is (the parameter name keeps its
            original spelling for caller compatibility)
        instance_name: passed through as-is

        Returns
        -------
        tour: np.array [seq_length]
        """
        solver = self.ortools
        return solver.solve(node_feats, fixed_paths, dist_martix, instance_name)