import torch
import json
import math
from collections import OrderedDict
from .const import *
from .utils import to_list
from .norm import Norm1D, Norm2D
from .variable import AttributeVariable, WorkerTaskSequence


class PyEnv(object):
    def __init__(self, problem, batch_size, sample_num, nn_args):
        super(PyEnv, self).__init__()
        self._problem = problem
        self._batch_size = batch_size
        self._sample_num = sample_num
        self._debug = -1  # >= 0 selects a batch row to trace verbosely
        # Action-space sizes: NW workers, NWW = worker start + end slots, NT tasks.
        self._NW = problem.worker_num
        self._NWW = problem.worker_num * 2
        self._NT = problem.task_num
        self._NWWT = self._NWW + self._NT
        self._feats_dict = nn_args['feature_dict']
        self._vars_dim = nn_args['variable_dim']
        self._vars_dict = {}
        self._vars = [var(problem, batch_size, sample_num) for var in problem.variables]
        for variable in self._vars:
            save_variable_version(variable)
            assert variable.name not in self._vars_dict, \
                "duplicated variable, name: {}".format(variable.name)
            self._vars_dict[variable.name] = variable
        self._constraint = problem.constraint()
        self._objective = problem.objective()
        # Worker currently active in each batch row (-1 if none).
        self._worker_index = torch.full((self._batch_size,), -1, dtype=torch.int64,
                                        device=problem.device)
        self._batch_index = torch.arange(self._batch_size, dtype=torch.int64,
                                         device=problem.device)
        # Each problem is replicated sample_num times along the batch dimension.
        self._problem_index = torch.div(self._batch_index, sample_num,
                                        rounding_mode='trunc')  # == self._batch_index // sample_num
        self._feasible = torch.ones(self._batch_size, dtype=torch.bool, device=problem.device)
        # Cost and sequence buffers start at 2 * NT steps and grow on demand in _do_step.
        self._cost = torch.zeros(self._batch_size, self._NT * 2, dtype=torch.float32,
                                 device=problem.device)
        self._mask = torch.zeros(self._batch_size, self._NWWT + 1, dtype=torch.bool,
                                 device=problem.device)
        self._worker_task_sequence = torch.full((self._batch_size, self._NT * 2, 3), -1,
                                                dtype=torch.int64, device=problem.device)
        self._step = 0
        self.register_variables(self._constraint)
        self._finished = self._constraint.finished()
        if hasattr(self._constraint, 'mask_worker_start'):
            mask_start = self._constraint.mask_worker_start()
        else:
            mask_start = False
        # Initially only worker-start actions may be available.
        self._mask[:, :self._NW] = mask_start
        self._mask[:, self._NW:] = True
        if self._debug >= 0:
            print("\n$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
            print("new env")
            print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n")

    def time(self):
        return self._step

    def step(self, chosen):
        with torch.no_grad():
            self._do_step(chosen)

    def _do_step(self, chosen):
        if self._debug >= 0:
            print("----------------------------------------------------------------------")
            feasible = self._feasible & ~self._mask[self._batch_index, chosen]
            print("feasible={}".format(feasible[self._debug].tolist()))
        # Decode the flat action index into the four action kinds and dispatch.
        is_start = (chosen >= 0) & (chosen < self._NW)
        if torch.any(is_start):
            b_index = self._batch_index[is_start]
            p_index = self._problem_index[is_start]
            w_index = chosen[is_start]
            self.step_worker_start(b_index, p_index, w_index)
        is_end = (chosen >= self._NW) & (chosen < self._NWW)
        if torch.any(is_end):
            b_index = self._batch_index[is_end]
            p_index = self._problem_index[is_end]
            w_index = chosen[is_end] - self._NW
            self.step_worker_end(b_index, p_index, w_index)
        is_task = (chosen >= self._NWW) & (chosen < self._NWWT)
        if torch.any(is_task):
            b_index = self._batch_index[is_task]
            p_index = self._problem_index[is_task]
            t_index = chosen[is_task] - self._NWW
            step_task_b_index = b_index
            self.step_task(b_index, p_index, t_index)
        else:
            step_task_b_index = None
        is_finish = chosen == self._NWWT
        if torch.any(is_finish):
            b_index = self._batch_index[is_finish]
            self._worker_task_sequence[b_index, self._step, 0] = GRL_FINISH
            self._worker_task_sequence[b_index, self._step, 1] = 0
            self._worker_task_sequence[b_index, self._step, 2] = -1
        self.update_mask(step_task_b_index)
        for var in self._vars:
            check_variable_version(var)
        if self._debug >= 0:
            print("worker_task_sequence[{}]={}".format(
                self._step, self._worker_task_sequence[self._debug, self._step].tolist()))
            for var in self._vars:
                if var.value is None:
                    print("{}={}".format(var.name, None))
                elif isinstance(var, AttributeVariable):
                    print("{}={}".format(var.name, to_list(var.value)))
                else:
                    print("{}={}".format(var.name, to_list(var.value[self._debug])))
        self._step += 1
        # Grow the cost and sequence buffers by NT steps when they fill up.
        if self._step >= self._cost.size(1):
            cost = torch.zeros(self._batch_size, self._step + self._NT,
                               dtype=torch.float32, device=chosen.device)
            cost[:, 0:self._step] = self._cost
            self._cost = cost
            worker_task_sequence = torch.full((self._batch_size, self._step + self._NT, 3), -1,
                                              dtype=torch.int64, device=chosen.device)
            worker_task_sequence[:, 0:self._step, :] = self._worker_task_sequence
            self._worker_task_sequence = worker_task_sequence
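
    # Flat action layout decoded by _do_step:
    #   [0, NW)      -> start worker w
    #   [NW, NWW)    -> end worker (index - NW)
    #   [NWW, NWWT)  -> execute task (index - NWW)
    #   NWWT         -> finish (terminal no-op)
    # The step_* handlers below receive the selected batch rows (b_index),
    # their problem rows (p_index = b_index // sample_num), and the decoded
    # worker or task index.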

    def step_worker_start(self, b_index, p_index, w_index):
        self._worker_task_sequence[b_index, self._step, 0] = GRL_WORKER_START
        self._worker_task_sequence[b_index, self._step, 1] = w_index
        self._worker_task_sequence[b_index, self._step, 2] = -1
        for var in self._vars:
            if hasattr(var, 'step_worker_start'):
                var.step_worker_start(b_index, p_index, w_index)
                save_variable_version(var)
        if hasattr(self._objective, 'step_worker_start'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_worker_start(), b_index)
        self._worker_index[b_index] = w_index
        # With a worker active, only task actions (and finish) remain available.
        self._mask[b_index, :self._NWW] = True
        self._mask[b_index, self._NWW:] = False

    def step_worker_end(self, b_index, p_index, w_index):
        self._worker_task_sequence[b_index, self._step, 0] = GRL_WORKER_END
        self._worker_task_sequence[b_index, self._step, 1] = w_index
        self._worker_task_sequence[b_index, self._step, 2] = -1
        for var in self._vars:
            if hasattr(var, 'step_worker_end'):
                var.step_worker_end(b_index, p_index, w_index)
                save_variable_version(var)
        if hasattr(self._objective, 'step_worker_end'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_worker_end(), b_index)
        self._worker_index[b_index] = -1
        self.register_variables(self._constraint, b_index)
        self._finished[b_index] |= self._constraint.finished()
        if hasattr(self._constraint, 'mask_worker_start'):
            mask_start = self._constraint.mask_worker_start()
        else:
            mask_start = False
        # Back to the "no active worker" state: only worker starts may be chosen.
        self._mask[b_index, :self._NW] = mask_start
        self._mask[b_index, self._NW:] = True

    def step_task(self, b_index, p_index, t_index):
        self._worker_task_sequence[b_index, self._step, 0] = GRL_TASK
        self._worker_task_sequence[b_index, self._step, 1] = t_index
        # First pass: hooks with signature (b_index, p_index, t_index) run once
        # here; hooks that also accept a done flag run now with done=None and
        # again below once the constraint has resolved done.
        for var in self._vars:
            if not hasattr(var, 'step_task'):
                continue
            elif var.step_task.__code__.co_argcount == 4:
                var.step_task(b_index, p_index, t_index)
            else:
                var.step_task(b_index, p_index, t_index, None)
            save_variable_version(var)
        if hasattr(self._constraint, 'do_task'):
            self.register_variables(self._constraint, b_index)
            done = self._constraint.do_task()
            self._worker_task_sequence[b_index, self._step, 2] = done.long()
            # Second pass: re-run done-aware step_task hooks with the actual flags.
            for var in self._vars:
                if not hasattr(var, 'step_task'):
                    continue
                elif var.step_task.__code__.co_argcount == 4:
                    pass
                else:
                    check_variable_version(var)
                    var.step_task(b_index, p_index, t_index, done)
                    save_variable_version(var)
        else:
            done = None
        if hasattr(self._objective, 'step_task'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_task(), b_index)
        if hasattr(self._constraint, 'mask_worker_end'):
            self.register_variables(self._constraint, b_index)
            mask_end = self._constraint.mask_worker_end()
        else:
            mask_end = False
        w_index = self._NW + self._worker_index[b_index]
        self._mask[b_index, w_index] = mask_end
        self._mask[b_index, self._NWW:] = False
        return done
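
    # Objective hooks may return either a bare cost or a (cost, feasible)
    # tuple; update_cost records the cost in the current step's column and
    # ANDs any feasibility flags into self._feasible.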

    def update_cost(self, cost, b_index=None):
        if isinstance(cost, tuple):
            cost, feasible = cost
            if b_index is None:
                self._feasible &= feasible
            else:
                self._feasible[b_index] &= feasible
        if isinstance(cost, torch.Tensor):
            cost = cost.float()
        else:
            assert type(cost) in (int, float), \
                "unexpected cost's type: {}".format(type(cost))
        if b_index is None:
            self._cost[:, self._step] = cost
        else:
            self._cost[b_index, self._step] = cost

    def update_mask(self, step_task_b_index):
        # Finished rows may only pick the finish action (last column).
        self._mask |= self._finished[:, None]
        self._mask[:, -1] = ~self._finished
        self.register_variables(self._constraint)
        self._mask[:, self._NWW:self._NWWT] |= self._constraint.mask_task()
        if step_task_b_index is not None:
            # A worker may only end once all of its remaining tasks are masked.
            b_index = step_task_b_index
            w_index = self._NW + self._worker_index[b_index]
            task_mask = self._mask[b_index, self._NWW:self._NWWT]
            self._mask[b_index, w_index] &= ~torch.all(task_mask, 1)

    def batch_size(self):
        return self._batch_size

    def sample_num(self):
        return self._sample_num

    def mask(self):
        return self._mask.clone()

    def cost(self):
        return self._cost[:, 0:self._step]

    def feasible(self):
        return self._feasible

    def worker_task_sequence(self):
        return self._worker_task_sequence[:, 0:self._step]

    def var(self, name):
        return self._vars_dict[name].value

    def register_variables(self, obj, b_index=None, finished=False):
        # Expose each variable's (optionally row-sliced) value as an attribute
        # on obj so constraint/objective code can read it by name.
        for var in self._vars:
            if var.value is None or b_index is None \
                    or isinstance(var, AttributeVariable):
                value = var.value
            else:
                value = var.value[b_index]
            obj.__dict__[var.name] = value
            if not hasattr(var, 'ext_values'):
                continue
            for k, v in var.ext_values.items():
                k = var.name + '_' + k
                obj.__dict__[k] = v if b_index is None else v[b_index]

    def finished(self):
        return self._finished

    def all_finished(self):
        return torch.all(self.finished())

    def finalize(self):
        self._worker_task_sequence[:, self._step, 0] = GRL_FINISH
        self._worker_task_sequence[:, self._step, 1] = 0
        self._worker_task_sequence[:, self._step, 2] = -1
        for var in self._vars:
            if hasattr(var, 'step_finish'):
                var.step_finish(self.worker_task_sequence())
        if hasattr(self._objective, 'step_finish'):
            self.register_variables(self._objective, finished=True)
            self.update_cost(self._objective.step_finish())
        self._step += 1

    def make_feat(self):
        with torch.no_grad():
            return self.do_make_feat()

    def do_make_feat(self):
        if not self._vars_dim:
            return None
        feature_list = []
        for k, dim in self._vars_dim.items():
            f = self._feats_dict[k]
            var = self._vars_dict[f.name]
            v = var.make_feat()
            if v.dim() == 2:
                v = v[:, :, None]
            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(
                    k, dim, v.size(-1))
            feature_list.append(v.float())
        v = torch.cat(feature_list, 2)
        # Pad zero features for the NWW worker slots and the finish slot so the
        # action axis matches the mask: (batch, NWWT + 1).
        u = v.new_zeros(v.size(0), self._NWW, v.size(2))
        f = v.new_zeros(v.size(0), 1, v.size(2))
        v = torch.cat([u, v, f], 1).permute(0, 2, 1)
        # Zero out masked actions, then normalize with statistics computed over
        # the unmasked action count.
        v[self._mask[:, None, :].expand(v.size())] = 0
        norm = v.new_ones(self._mask.size())
        norm[self._mask] = 0
        norm = norm.sum(1) + 1e-10
        norm = norm[:, None, None]
        avg = v.sum(-1, keepdim=True) / norm
        v = v - avg
        std = v.norm(dim=-1, keepdim=True) / norm + 1e-10
        v = v / std
        return v.contiguous()
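
# PyTorch increments a tensor's internal `_version` counter on every in-place
# modification. The helpers below snapshot that counter right after a variable
# is legitimately updated and later assert it has not changed, catching
# variables that were mutated outside their step_* hooks.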
def save_variable_version(var):
    if isinstance(var.value, torch.Tensor):
        var.__version__ = var.value._version


def check_variable_version(var):
    if isinstance(var.value, torch.Tensor):
        assert var.__version__ == var.value._version, \
            "variable's value is modified, name: {}".format(var.name)
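

# ---------------------------------------------------------------------------
# Minimal rollout sketch (illustrative only, not part of the original module).
# It assumes a `problem` object conforming to the interface PyEnv uses above
# and a hypothetical `policy` callable mapping the feature tensor from
# make_feat() to per-action scores of shape (batch, NWWT + 1).
# ---------------------------------------------------------------------------
def _rollout_sketch(problem, policy, batch_size, sample_num, nn_args):
    env = PyEnv(problem, batch_size, sample_num, nn_args)
    while not env.all_finished():
        feats = env.make_feat()  # (batch, feat_dim, NWWT + 1), or None if no features
        scores = policy(feats)   # hypothetical scoring model
        # Forbid masked actions, then decode greedily (sampling also works).
        scores = scores.masked_fill(env.mask(), float('-inf'))
        env.step(scores.argmax(dim=-1))
    env.finalize()
    # Per-step costs; sum over steps for one scalar objective per batch row.
    return env.cost().sum(dim=1), env.feasible(), env.worker_task_sequence()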