Facepalm0 commited on
Commit
a821f69
·
verified ·
1 Parent(s): 16e53c1

Upload search_old.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. search_old.py +215 -0
search_old.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import itertools
3
+ import random
4
+ from make_env import GridWorldEnv
5
+ from concurrent.futures import ProcessPoolExecutor
6
+
7
+ class Algorithm_Agent():
8
+ def __init__(self, num_categories, grid_size, grid, loc):
9
+ self.num_categories = num_categories
10
+ self.grid_size = grid_size
11
+ self.grid = grid
12
+ self.current_loc = [loc[0], loc[1]]
13
+ self.path, self.path_category = self.arrange_points()
14
+ # print('Path generated.')
15
+ self.actions = self.plan_action()
16
+ # print('Actions generated.')
17
+
18
+ def calculate_length(self, path, category_path, elim_path):
19
+ lengths = np.sum(np.abs(np.array(path[:-1]) - np.array(path[1:])), axis=1)
20
+ motion_length = np.sum(lengths) # motion path
21
+ cum_lengths = np.cumsum(lengths)[::-1] / 14.4 # cumulative length
22
+ load_length = np.sum(cum_lengths) - 4 * np.sum(np.array(cum_lengths) * np.array(elim_path[:-1]))
23
+
24
+ return motion_length + load_length
25
+
26
+
27
+ def get_elim_path(self, category_path):
28
+ elim_path = [0] * len(category_path)
29
+ for i in range(len(category_path)):
30
+ if i > 0:
31
+ previous_caterogy_path = category_path[:i]
32
+ # 统计previous_caterogy_path中,与category_path[i]同一类别的元素的个数
33
+ same_category_count = previous_caterogy_path.count(category_path[i])
34
+ if (same_category_count + 1) % 4 == 0 and same_category_count != 0:
35
+ elim_path[i] = 1
36
+ return elim_path
37
+
38
+
39
+ def find_shortest_path(self, points):
40
+ min_path = None
41
+ min_length = float('inf')
42
+ for perm in itertools.permutations(points):
43
+ perm = np.array(perm) # 转换为numpy数组
44
+ # 简化计算方式
45
+ diffs = np.abs(perm[1:] - perm[:-1])
46
+ length = np.sum(diffs)
47
+ if length < min_length:
48
+ min_length = length
49
+ min_path = perm.tolist()
50
+ return min_path, min_length
51
+
52
+ def insert_point(self, path, category_path, elim_path, point, category):
53
+ min_length = float('inf')
54
+ best_position = None
55
+ for i in range(len(path) + 1):
56
+ new_path, new_category_path = path.copy(), category_path.copy()
57
+ new_path.insert(i, point)
58
+ new_category_path.insert(i, category)
59
+ new_elim_path = self.get_elim_path(new_category_path)
60
+ if len(new_path) > 12:
61
+ a=1
62
+ length = self.calculate_length(new_path, new_category_path, new_elim_path)
63
+ if length < min_length:
64
+ min_length = length
65
+ best_position = i
66
+ return best_position
67
+
68
+ def try_single_optimization(self, args):
69
+ """
70
+ 将函数改造为接收单个参数的形式,便于进程池调用
71
+ """
72
+ path, category_path = args
73
+ path = path.copy()
74
+ category_path = category_path.copy()
75
+
76
+ # 随机选择一个点
77
+ index = random.randint(0, len(path) - 1)
78
+ point = path.pop(index)
79
+ category = category_path.pop(index)
80
+
81
+ # 尝试重新插入
82
+ elim_path = self.get_elim_path(category_path)
83
+ position = self.insert_point(path, category_path, elim_path, point, category)
84
+
85
+ # 插入到最优位置
86
+ path.insert(position, point)
87
+ category_path.insert(position, category)
88
+
89
+ return (path, category_path,
90
+ self.calculate_length(path, category_path, self.get_elim_path(category_path)))
91
+
92
+ def optimize_path_parallel(self, initial_path, initial_category_path, num_iterations=1000):
93
+ """
94
+ 新增的并行优化函数
95
+ """
96
+ chunk_size = 125
97
+ num_processes = num_iterations // chunk_size
98
+
99
+ # 准备参数
100
+ args_list = [(initial_path.copy(), initial_category_path.copy())
101
+ for _ in range(num_iterations)]
102
+
103
+ best_path, best_category_path = initial_path.copy(), initial_category_path.copy()
104
+ best_length = float('inf')
105
+
106
+ # 使用进程池
107
+ with ProcessPoolExecutor(max_workers=num_processes) as executor:
108
+ # 并行执行优化
109
+ results = list(executor.map(self.try_single_optimization,
110
+ args_list,
111
+ chunksize=chunk_size))
112
+
113
+ # 找出最佳结果
114
+ for path, category_path, length in results:
115
+ if length < best_length:
116
+ best_length = length
117
+ best_path = path
118
+ best_category_path = category_path
119
+
120
+ return best_path, best_category_path
121
+
122
+ def arrange_points(self):
123
+ points_by_category = {i: [] for i in random.sample(range(self.num_categories), self.num_categories)} # Group points by category
124
+ for x in range(self.grid_size[0]):
125
+ for y in range(self.grid_size[1]):
126
+ category = self.grid[x, y]
127
+ if category != -1:
128
+ points_by_category[category].append([x, y]) # Store the position of the item
129
+
130
+ path = [] # Initialize the path
131
+ category_path = []
132
+ for category, points in points_by_category.items(): # Process each category
133
+ while points: # Process all points in the category
134
+ if len(points) >= 4: # If there are at least 4 points, find the shortest path for the first 4 points
135
+ subset = points[:4]
136
+ points = points[4:]
137
+ else:
138
+ subset = points
139
+ points = []
140
+
141
+ if len(path) == 0: # If the path has only the loc, find the shortest path for the subset
142
+ path, _ = self.find_shortest_path(subset)
143
+ category_path = [category] * len(path)
144
+ else:
145
+ for point in subset:
146
+ elim_path = self.get_elim_path(category_path)
147
+ position = self.insert_point(path, category_path, elim_path, point, category)
148
+ path.insert(position, point)
149
+ category_path.insert(position, category)
150
+
151
+ # print(f'category: {category}, category_path: {category_path}\n')
152
+ # # 排列好第一轮后,再次调整顺序
153
+ # # 从序列中随机剔除一个元素,然后插入到其他位置,使得路径长度最短
154
+ # for i in range(1000):
155
+ # index = random.randint(0, len(path) - 1)
156
+ # point = path.pop(index)
157
+ # category = category_path.pop(index)
158
+ # elim_path = self.get_elim_path(category_path)
159
+ # position = self.insert_point(path, category_path, elim_path, point, category)
160
+ # path.insert(position, point)
161
+ # category_path.insert(position, category)
162
+
163
+ # 使用并行优化替换原来的循环
164
+ path, category_path = self.optimize_path_parallel(path, category_path)
165
+
166
+ return path, category_path
167
+
168
+ def plan_action(self):
169
+ actions = []
170
+ for i in range(len(self.path)):
171
+ while self.current_loc[0] != self.path[i][0] or self.current_loc[1] != self.path[i][1]:
172
+ if self.current_loc[0] < self.path[i][0]:
173
+ actions.append(0)
174
+ self.current_loc = [self.current_loc[0] + 1, self.current_loc[1]]
175
+ elif self.current_loc[1] < self.path[i][1]:
176
+ actions.append(1)
177
+ self.current_loc = [self.current_loc[0], self.current_loc[1] + 1]
178
+ elif self.current_loc[0] > self.path[i][0]:
179
+ actions.append(2)
180
+ self.current_loc = [self.current_loc[0] - 1, self.current_loc[1]]
181
+ else:
182
+ actions.append(3)
183
+ self.current_loc = [self.current_loc[0], self.current_loc[1] - 1]
184
+ actions.append(4)
185
+ # print(f'actions: {actions}\n')
186
+ return actions
187
+
188
+ def search(grid, loc, pred_grid, pred_loc, num_iterations=30):
189
+ env = GridWorldEnv()
190
+ optim_actions, optim_reward = None, 0
191
+ for i in range(num_iterations):
192
+ env.reset()
193
+ env.grid, env.loc = grid.copy(), loc.copy()
194
+ agent = Algorithm_Agent(env.num_categories, env.grid_size, pred_grid, pred_loc)
195
+ cumulated_reward = 0
196
+ for action in agent.actions:
197
+ obs, reward, done, truncated, info = env.step(action)
198
+ cumulated_reward += reward
199
+ if cumulated_reward > optim_reward:
200
+ optim_actions, optim_reward = agent.actions, cumulated_reward
201
+ print(f'{i}:', cumulated_reward)
202
+ print(f'Optim reward: {optim_reward}')
203
+ return optim_actions
204
+
205
+
206
+ if __name__ == "__main__":
207
+ for _ in range(20):
208
+ test_env = GridWorldEnv()
209
+ test_env.reset()
210
+ grid, loc = test_env.grid.copy(), test_env.loc.copy()
211
+ pred_grid, pred_loc = test_env.grid.copy(), test_env.loc.copy()
212
+ loc_1, loc_2, loc_3, loc_4, loc_5 = random.sample(range(12), 2), random.sample(range(12), 2), random.sample(range(12), 2), random.sample(range(12), 2), random.sample(range(12), 2)
213
+ a, b, c, d, e = pred_grid[loc_1[0], loc_1[1]], pred_grid[loc_2[0], loc_2[1]], pred_grid[loc_3[0], loc_3[1]], pred_grid[loc_4[0], loc_4[1]], pred_grid[loc_5[0], loc_5[1]]
214
+ pred_grid[loc_1[0], loc_1[1]], pred_grid[loc_2[0], loc_2[1]], pred_grid[loc_3[0], loc_3[1]], pred_grid[loc_4[0], loc_4[1]], pred_grid[loc_5[0], loc_5[1]] = b, e, a, c, d
215
+ search(grid, loc, pred_grid, pred_loc) # 使用5格混淆的grid进行搜索