|
import torch |
|
import functools |
|
|
|
from .utils import repeat |
|
|
|
|
|
class VarMeta(object):
    """Deferred constructor for a variable class.

    Stores a class together with the keyword arguments known at declaration
    time. Calling the instance adds the run-time arguments taken from the
    problem and instantiates the stored class.
    """

    def __init__(self, clazz, **kwargs):
        """Remember ``clazz`` and ``kwargs``.

        Every keyword argument is also exposed as an attribute of the same
        name on this object.
        """
        self.clazz = clazz
        self._kwargs = kwargs
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __call__(self, problem, batch_size, sample_num):
        """Instantiate the stored class.

        Args:
            problem: object providing ``feats``, ``worker_num`` and
                ``task_num`` attributes.
            batch_size: forwarded unchanged as ``batch_size``.
            sample_num: forwarded unchanged as ``sample_num``.

        Returns:
            Whatever ``clazz(**kwargs)`` returns, where ``kwargs`` is a copy
            of the stored keyword arguments extended with the run-time ones.
        """
        # Work on a copy so repeated calls never mutate the stored arguments.
        call_args = dict(self._kwargs)
        call_args.update(
            problem=problem.feats,
            batch_size=batch_size,
            sample_num=sample_num,
            worker_num=problem.worker_num,
            task_num=problem.task_num,
        )
        return self.clazz(**call_args)
|
|
|
|
|
def attribute_variable(name, attribute=None):
    """Declare an ``AttributeVariable``.

    Returns a ``VarMeta`` that will construct an ``AttributeVariable`` with
    the given ``name`` and ``attribute`` once it is called.
    """
    spec = VarMeta(AttributeVariable, name=name, attribute=attribute)
    return spec
|
|
|
|
|
class AttributeVariable:
    """Variable whose value is one entry of ``problem``, used as-is."""

    def __init__(self, name, attribute, problem, batch_size, sample_num, worker_num, task_num):
        """Look up ``problem[attribute]`` and store it as ``value``.

        When ``attribute`` is ``None`` the variable ``name`` is used as the
        lookup key. The remaining arguments are accepted but not used.
        """
        key = name if attribute is None else attribute

        self.name = name
        self.value = problem[key]
|
|
|
|
|
def feature_variable(name, feature=None):
    """Declare a ``FeatureVariable``.

    Returns a ``VarMeta`` that will construct a ``FeatureVariable`` with the
    given ``name`` and ``feature`` once it is called.
    """
    spec = VarMeta(FeatureVariable, name=name, feature=feature)
    return spec
|
|
|
|
|
class FeatureVariable:
    """Variable holding a problem feature passed through ``repeat``."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        """Fetch ``problem[feature]`` and store its repeated form.

        ``feature`` defaults to ``name``. The key must be ``'id'`` or start
        with ``worker_`` or ``task_``. ``feature`` keeps the raw entry and
        ``value`` keeps ``repeat(feature, sample_num)``.
        """
        key = name if feature is None else feature

        assert key == 'id' or key.startswith(("worker_", "task_"))

        self.name = name
        self.feature = problem[key]
        # `repeat` comes from .utils; presumably it tiles the feature
        # sample_num times along the first axis -- confirm there.
        self.value = repeat(self.feature, sample_num)
|
|
|
|
|
def task_variable(name, feature=None):
    """Declare a ``TaskVariable``.

    Returns a ``VarMeta`` that will construct a ``TaskVariable`` with the
    given ``name`` and ``feature`` once it is called.
    """
    spec = VarMeta(TaskVariable, name=name, feature=feature)
    return spec
|
|
|
|
|
class TaskVariable:
    """Per-batch-row copy of a task feature for the most recent task."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        """Allocate a zero ``value`` buffer for the selected task feature.

        ``feature`` defaults to ``name`` and must start with ``task_``. The
        buffer has the feature's shape with the first axis replaced by
        ``batch_size`` and the second axis removed.
        """
        key = name if feature is None else feature

        assert key.startswith("task_")

        self.name = name
        self.feature = problem[key]

        dims = list(self.feature.shape)
        dims[0] = batch_size
        dims.pop(1)  # the second axis is the one indexed by t_index in step_task
        self.value = self.feature.new_zeros(dims)

    def step_task(self, b_index, p_index, t_index):
        """Write ``feature[p_index, t_index]`` into rows ``b_index`` of ``value``."""
        picked = self.feature[p_index, t_index]
        self.value[b_index] = picked
|
|
|
|
|
def worker_variable(name, feature=None):
    """Declare a ``WorkerVariable``.

    Returns a ``VarMeta`` that will construct a ``WorkerVariable`` with the
    given ``name`` and ``feature`` once it is called.
    """
    spec = VarMeta(WorkerVariable, name=name, feature=feature)
    return spec
|
|
|
|
|
class WorkerVariable:
    """Per-batch-row copy of a worker feature for the current worker."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        """Allocate a zero ``value`` buffer for the selected worker feature.

        ``feature`` defaults to ``name`` and must start with ``worker_``. The
        buffer has the feature's shape with the first axis replaced by
        ``batch_size`` and the second axis removed.
        """
        key = name if feature is None else feature

        assert key.startswith("worker_")

        self.name = name
        self.feature = problem[key]

        dims = list(self.feature.shape)
        dims[0] = batch_size
        dims.pop(1)  # the second axis is the one indexed by w_index below
        self.value = self.feature.new_zeros(dims)

    def step_worker_start(self, b_index, p_index, w_index):
        """Write ``feature[p_index, w_index]`` into rows ``b_index`` of ``value``."""
        picked = self.feature[p_index, w_index]
        self.value[b_index] = picked
|
|
|
|
|
def worker_task_variable(name, feature=None):
    """Declare a ``WorkerTaskVariable``.

    Returns a ``VarMeta`` that will construct a ``WorkerTaskVariable`` with
    the given ``name`` and ``feature`` once it is called.
    """
    spec = VarMeta(WorkerTaskVariable, name=name, feature=feature)
    return spec
|
|
|
|
|
class WorkerTaskVariable:
    """Feature that depends on both the current worker and the current task.

    The raw feature is indexed as ``feature[p_index, w_index]`` when a worker
    starts and the result is then indexed by ``t_index`` when a task is
    visited, so its first three axes act as (problem, worker, task).
    """

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        """Allocate the per-worker cache and the ``value`` buffer.

        ``feature`` defaults to ``name`` and must start with
        ``worker_task_``.

        * ``_feature`` has the feature's shape with the first axis replaced
          by ``batch_size`` and the worker axis removed:
          ``(batch, task, ...)``.
        * ``value`` additionally has the task axis removed: ``(batch, ...)``.
        """
        if feature is None:
            feature = name

        assert feature.startswith("worker_task_")

        self.name = name
        self.feature = problem[feature]

        dims = list(self.feature.size())
        dims[0] = batch_size

        # Drop the worker axis -> (batch, task, ...).
        del dims[1]
        self._feature = self.feature.new_zeros(dims)

        # Drop the task axis, which is now at position 1 (not 2), so that
        # ``value[b_index]`` matches the shape of ``_feature[b_index, t_index]``.
        del dims[1]
        self.value = self.feature.new_zeros(dims)

    def step_worker_start(self, b_index, p_index, w_index):
        """Cache the selected worker's slice ``feature[p_index, w_index]``."""
        self._feature[b_index] = self.feature[p_index, w_index]

    def step_task(self, b_index, p_index, t_index):
        """Copy the cached worker slice at ``t_index`` into ``value``."""
        self.value[b_index] = self._feature[b_index, t_index]
|
|
|
|
|
def worker_task_group(name, feature=None):
    """Declare a ``WorkerTaskGroup``.

    Returns a ``VarMeta`` that will construct a ``WorkerTaskGroup`` with the
    given ``name`` and ``feature`` once it is called.
    """
    spec = VarMeta(WorkerTaskGroup, name=name, feature=feature)
    return spec
|
|
|
|
|
class WorkerTaskGroup:
    """Per-group counter of tasks visited since the last worker start."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        """Allocate a ``(batch_size, max_group + 1)`` zero counter.

        ``feature`` defaults to ``name`` and must start with ``task_``. The
        feature is converted with ``.long()`` and every entry must be
        non-negative; entries are used as column indices into ``value``.
        """
        key = name if feature is None else feature

        assert key.startswith("task_")

        self.name = name
        self.feature = problem[key].long()

        # One column per group id, 0 .. max id inclusive.
        group_count = self.feature.max() + 1
        assert torch.all(self.feature >= 0)

        self.value = self.feature.new_zeros(batch_size, group_count)

    def step_worker_start(self, b_index, p_index, w_index):
        """Reset the counters of rows ``b_index``."""
        self.value[b_index] = 0

    def step_task(self, b_index, p_index, t_index):
        """Increment, for each row, the counter of the visited task's group."""
        group_of_task = self.feature[p_index, t_index]
        self.value[b_index, group_of_task] += 1
|
|
|
|
|
def worker_task_item(name, item_id, item_num):
    """Declare a ``WorkerTaskItem``.

    Returns a ``VarMeta`` that will construct a ``WorkerTaskItem`` with the
    given ``name``, ``item_id`` and ``item_num`` once it is called.
    """
    spec = VarMeta(WorkerTaskItem, name=name, item_id=item_id, item_num=item_num)
    return spec
|
|
|
|
|
class WorkerTaskItem:
    """Running per-item totals accumulated from the tasks visited so far."""

    def __init__(self, name, item_id, item_num, problem, batch_size, sample_num, worker_num, task_num):
        """Load the item id / amount features and allocate the totals.

        Both ``item_id`` and ``item_num`` are problem keys that must start
        with ``task_``. Each is passed through ``repeat(..., sample_num)``;
        the ids are additionally converted with ``.long()`` and must be
        non-negative. ``value`` is a zero tensor of shape
        ``(item_id.size(0), item_id.max() + 1)`` with ``item_num``'s dtype.
        """
        assert item_id.startswith('task_')
        assert item_num.startswith('task_')

        self.name = name
        self.item_id = repeat(problem[item_id], sample_num).long()
        self.item_num = repeat(problem[item_num], sample_num)

        assert torch.all(self.item_id >= 0)

        # One row per first-axis entry of item_id, one column per item id.
        shape = [self.item_id.size(0), self.item_id.max() + 1]
        self.value = self.item_num.new_zeros(shape)

    def step_worker_start(self, b_index, p_index, w_index):
        """Reset the totals of rows ``b_index``."""
        self.value[b_index] = 0

    def step_task(self, b_index, p_index, t_index):
        """Add the visited task's item amounts to the matching item columns."""
        ids = self.item_id[b_index, t_index]
        amounts = self.item_num[b_index, t_index]
        # b_index gets a trailing axis so it broadcasts against ids.
        self.value[b_index[:, None], ids] += amounts

    def make_feat(self):
        """Return a per-task feature derived from the current totals.

        For every task and item slot the accumulated total of that slot's
        item id is gathered and clamped to [0, 1]; it is subtracted from the
        slot's amount (also clamped to [0, 1]); the difference is clamped to
        [0, 1] and summed over the slot axis.
        """
        task_count = self.item_id.size(1)
        totals = self.value[:, None, :].expand(-1, task_count, -1)

        held = totals.gather(2, self.item_id).clamp(0, 1)
        lacking = self.item_num.clamp(0, 1) - held
        return lacking.clamp(0, 1).sum(2)
|
|
|
|
|
def task_demand_now(name, feature=None, only_this=False):
    """Declare a ``TaskDemandNow``.

    Returns a ``VarMeta`` that will construct a ``TaskDemandNow`` with the
    given ``name``, ``feature`` and ``only_this`` once it is called.
    """
    spec = VarMeta(TaskDemandNow, name=name, feature=feature, only_this=only_this)
    return spec
|
|
|
|
|
class TaskDemandNow:
    """Remaining integer demand per task, decremented as tasks are served."""

    def __init__(self, name, feature, only_this, problem, batch_size, sample_num, worker_num, task_num):
        """Load the demand feature and set up ``value``.

        ``feature`` defaults to ``name`` and must start with ``task_``. The
        feature is passed through ``repeat(..., sample_num)`` and kept in
        ``_value``; it must have an integer dtype and be non-negative.

        With ``only_this`` set, ``value`` is a zero vector with one entry per
        row of ``_value``; otherwise ``value`` is ``_value`` itself.
        """
        key = name if feature is None else feature

        assert key.startswith("task_")

        self.name = name
        self.only_this = only_this
        self._value = repeat(problem[key], sample_num)

        assert self._value.dtype in \
               (torch.int8, torch.int16, torch.int32, torch.int64)
        assert torch.all(self._value >= 0)

        if not only_this:
            self.value = self._value
        else:
            row_count = self._value.size(0)
            self.value = self._value.new_zeros(row_count)

    def step_task(self, b_index, p_index, t_index, done):
        """Subtract ``done`` from the visited task's demand and refresh ``value``.

        When ``done`` is ``None`` nothing is subtracted. With ``only_this``
        set, rows ``b_index`` of ``value`` receive the demand of the visited
        task; otherwise ``value`` is rebound to the full demand tensor.
        """
        if done is not None:
            self._value[b_index, t_index] -= done

        if not self.only_this:
            self.value = self._value
            return

        self.value[b_index] = self._value[b_index, t_index]
|
|
|
|
|
def worker_count_now(name, feature=None):
    """Declare a ``WorkerCountNow``.

    Returns a ``VarMeta`` that will construct a ``WorkerCountNow`` with the
    given ``name`` and ``feature`` once it is called.
    """
    spec = VarMeta(WorkerCountNow, name=name, feature=feature)
    return spec
|
|
|
|
|
class WorkerCountNow:
    """Integer per-worker counter decremented each time a worker starts."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        """Load the counter feature.

        ``feature`` defaults to ``name`` and must start with ``worker_``.
        ``value`` is ``repeat(problem[feature], sample_num)``; it must have
        an integer dtype and be non-negative.
        """
        key = name if feature is None else feature

        assert key.startswith("worker_")

        self.name = name
        self.value = repeat(problem[key], sample_num)

        assert self.value.dtype in \
               (torch.int8, torch.int16, torch.int32, torch.int64)
        assert torch.all(self.value >= 0)

    def step_worker_start(self, b_index, p_index, w_index):
        """Decrement the counter at ``(b_index, w_index)`` by one."""
        self.value[b_index, w_index] -= 1
|
|
|
|
|
def edge_variable(name, feature, last_to_this=False,
                  this_to_task=False, task_to_end=False, last_to_loop=False):
    """Declare an ``EdgeVariable``.

    Returns a ``VarMeta`` that will construct an ``EdgeVariable`` with the
    given ``name``, ``feature`` and mode flags once it is called.
    """
    spec = VarMeta(
        EdgeVariable,
        name=name,
        feature=feature,
        last_to_this=last_to_this,
        this_to_task=this_to_task,
        task_to_end=task_to_end,
        last_to_loop=last_to_loop,
    )
    return spec
|
|
|
|
|
class EdgeVariable:
    """Edge (matrix) feature looked up along the route being built.

    The matrix is indexed as ``feature[problem, from_node, to_node]``. Node
    indices are laid out as used by the step methods below: worker start
    nodes are ``w_index``, worker end nodes are ``w_index + worker_num`` and
    task nodes are ``t_index + 2 * worker_num``.

    Exactly one of the four mode flags must be ``True``:

    * ``last_to_this``: edge from the previously visited node to the node
      visited now.
    * ``this_to_task``: edges from the node visited now to every task node.
    * ``task_to_end``: edges from every task node to the current worker's
      end node.
    * ``last_to_loop``: at worker end, edge from the last visited node to
      the first task visited after the worker started.
    """

    def __init__(self, name, feature, last_to_this, this_to_task, task_to_end, last_to_loop,
                 problem, batch_size, sample_num, worker_num, task_num):
        """Validate the configuration and allocate state.

        ``feature`` defaults to ``name`` and must end with ``_matrix``.
        ``value`` has the matrix's shape with the first axis replaced by
        ``batch_size`` and the two node axes removed; for ``this_to_task``
        and ``task_to_end`` a ``task_num`` axis is inserted at position 1.
        """
        # Resolve the default before validating; the original checked
        # ``feature.endswith`` first, which made this default unreachable
        # and raised AttributeError for ``feature=None``.
        if feature is None:
            feature = name

        assert feature.endswith("_matrix")

        flags = [last_to_this, this_to_task, task_to_end, last_to_loop]
        assert flags.count(True) == 1 and flags.count(False) == 3

        self.name = name
        self.last_to_this = last_to_this
        self.this_to_task = this_to_task
        self.task_to_end = task_to_end
        self.last_to_loop = last_to_loop

        self.worker_num = worker_num
        self.task_num = task_num

        self.feature = problem[feature]

        size = list(self.feature.size())
        size[0] = batch_size
        del size[1:3]  # drop the (from_node, to_node) axes

        if self.this_to_task or self.task_to_end:
            # These modes keep one entry per task.
            size.insert(1, task_num)
        self.value = self.feature.new_zeros(size)

        # Per-batch-row node bookkeeping.
        self.end_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.loop_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.last_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        # Node indices of all tasks, shaped (1, task_num) for broadcasting.
        self.task_index = (torch.arange(task_num) + worker_num * 2)[None, :]

    def step_worker_start(self, b_index, p_index, w_index):
        """Update state when rows ``b_index`` start worker ``w_index``."""
        if self.last_to_this:
            self.value[b_index] = 0
            self.last_index[b_index] = w_index
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, w_index)
        elif self.task_to_end:
            self.end_index[b_index] = w_index + self.worker_num
            self.do_task_to_end(b_index, p_index)
        elif self.last_to_loop:
            self.value[b_index] = 0
            self.last_index[b_index] = w_index

    def step_worker_end(self, b_index, p_index, w_index):
        """Update state when rows ``b_index`` finish worker ``w_index``."""
        this_index = w_index + self.worker_num
        if self.last_to_this:
            self.do_last_to_this(b_index, p_index, this_index)
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, this_index)
        elif self.task_to_end:
            pass
        elif self.last_to_loop:
            self.do_last_to_loop(b_index, p_index)

    def step_task(self, b_index, p_index, t_index):
        """Update state when rows ``b_index`` visit task ``t_index``."""
        this_index = t_index + self.worker_num * 2
        if self.last_to_this:
            self.do_last_to_this(b_index, p_index, this_index)
            self.last_index[b_index] = this_index
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, this_index)
        elif self.task_to_end:
            pass
        elif self.last_to_loop:
            last_index = self.last_index[b_index]
            loop_index = self.loop_index[b_index]
            # If the previous node was a worker start node (index below
            # worker_num), this task becomes the loop node; otherwise the
            # stored loop node is kept.
            self.loop_index[b_index] = torch.where(last_index < self.worker_num, this_index, loop_index)
            self.last_index[b_index] = this_index

    def do_last_to_this(self, b_index, p_index, this_index):
        """Store ``feature[p, last_node, this_node]`` into rows ``b_index``."""
        last_index = self.last_index[b_index]
        self.value[b_index] = self.feature[p_index, last_index, this_index]

    def do_this_to_task(self, b_index, p_index, this_index):
        """Store the edges from ``this_index`` to every task node."""
        p_index2 = p_index[:, None]
        this_index2 = this_index[:, None]
        task_index2 = self.task_index
        self.value[b_index] = self.feature[p_index2, this_index2, task_index2]

    def do_task_to_end(self, b_index, p_index):
        """Store the edges from every task node to the stored end node."""
        p_index2 = p_index[:, None]
        task_index2 = self.task_index
        end_index = self.end_index[b_index]
        end_index2 = end_index[:, None]
        self.value[b_index] = self.feature[p_index2, task_index2, end_index2]

    def do_last_to_loop(self, b_index, p_index):
        """Store ``feature[p, last_node, loop_node]`` into rows ``b_index``."""
        loop_index = self.loop_index[b_index]
        last_index = self.last_index[b_index]
        self.value[b_index] = self.feature[p_index, last_index, loop_index]

    def make_feat(self):
        """Return a copy of ``value``; only valid for the per-task modes."""
        assert self.this_to_task or self.task_to_end, \
            "one of [this_to_task, task_to_end] must be true"
        return self.value.clone()
|
|
|
|
|
def worker_used_resource(name, edge_require=None, task_require=None, task_ready=None, worker_ready=None, task_due=None):
    """Declare a ``WorkerUsedResource``.

    Returns a ``VarMeta`` that will construct a ``WorkerUsedResource`` with
    the given ``name`` and feature keys once it is called.
    """
    spec = VarMeta(
        WorkerUsedResource,
        name=name,
        edge_require=edge_require,
        task_require=task_require,
        task_ready=task_ready,
        worker_ready=worker_ready,
        task_due=task_due,
    )
    return spec
|
|
|
|
|
class WorkerUsedResource:
    """Resource consumed so far by the current worker, per batch row.

    The running total in ``value`` is reset when a worker starts and is
    increased by edge costs between consecutive nodes and by task costs.
    When ``task_ready`` is given, the total is raised to at least that
    value on arrival at a task.
    """

    def __init__(self, name, edge_require, task_require, task_ready, worker_ready, task_due,
                 problem, batch_size, sample_num, worker_num, task_num):
        """Resolve the configured features and allocate ``value``.

        Each ``*_require`` / ``*_ready`` / ``task_due`` argument is either
        ``None`` or a key into ``problem``. At least one of ``edge_require``,
        ``task_require``, ``task_ready`` and ``worker_ready`` is required;
        the first one present (in that order) provides the dtype and shape
        of ``value``.
        """
        assert edge_require is None or edge_require.endswith("_matrix"), "unsupported edge: {}".format(edge_require)
        assert task_require is None or task_require.startswith("task_"), "unsupported task_require: {}".format(
            task_require)
        assert task_ready is None or task_ready.startswith("task_"), "unsupported task_service: {}".format(task_ready)
        assert worker_ready is None or worker_ready.startswith("worker_") and not worker_ready.startswith(
            "worker_task_")
        assert task_due is None or task_due.startswith("task_"), "unsupported task_due: {}".format(task_due)

        self.name = name

        self.worker_num = worker_num
        self.task_num = task_num

        self.edge_require = None
        if edge_require is not None:
            self.edge_require = problem[edge_require]
            # Last visited node per batch row, needed to pick the next edge.
            self.last_index = self.edge_require.new_zeros(batch_size, dtype=torch.int64)

        self.task_require = None
        if task_require is not None:
            self.task_require = problem[task_require]
            self.task_require2 = repeat(self.task_require, sample_num)

        self.task_ready = None if task_ready is None else problem[task_ready]
        self.worker_ready = None if worker_ready is None else problem[worker_ready]
        self.task_due = None if task_due is None else problem[task_due]

        configured = [
            t
            for t in (self.edge_require, self.task_require, self.task_ready, self.worker_ready)
            if t is not None
        ]
        assert configured, "at least one of edge_require, task_require, task_ready, worker_ready is required!"

        template = configured[0]
        shape = list(template.size())
        shape[0] = batch_size
        if self.edge_require is None:
            del shape[1]  # drop the task / worker axis
        else:
            del shape[1:3]  # drop both node axes of the edge matrix

        self.value = template.new_zeros(shape)

    def step_worker_start(self, b_index, p_index, w_index):
        """Reset rows ``b_index`` to zero, or to ``worker_ready`` if given."""
        if self.worker_ready is not None:
            self.value[b_index] = self.worker_ready[p_index, w_index]
        else:
            self.value[b_index] = 0

        if self.edge_require is not None:
            self.last_index[b_index] = w_index

    def step_worker_end(self, b_index, p_index, w_index):
        """Add the edge cost from the last node to the worker's end node."""
        if self.edge_require is None:
            return

        prev_node = self.last_index[b_index]
        end_node = w_index + self.worker_num
        self.value[b_index] += self.edge_require[p_index, prev_node, end_node]
        self.last_index[b_index] = end_node

    def step_task(self, b_index, p_index, t_index, done):
        """Account for visiting task ``t_index``.

        With ``done`` being ``None`` the edge cost to the task is added and
        the total is raised to ``task_ready`` when configured. Otherwise the
        task cost scaled by ``done`` is added.
        """
        if done is not None:
            if self.task_require is not None:
                if self.value.dim() == 2:
                    # Give ``done`` a trailing axis so it scales every column.
                    done = done[:, None]
                self.value[b_index] += self.task_require[p_index, t_index] * done
            return

        if self.edge_require is not None:
            prev_node = self.last_index[b_index]
            task_node = t_index + (self.worker_num * 2)
            self.value[b_index] += self.edge_require[p_index, prev_node, task_node]
            self.last_index[b_index] = task_node

        if self.task_ready is not None:
            self.value[b_index] = torch.max(self.value[b_index], self.task_ready[p_index, t_index])

    def make_feat(self):
        """Return, per task, the clamped sum of used plus required resource.

        Requires a 2-D ``value`` and a configured ``task_require``. The
        current total is added to every task's requirement, clamped to
        [0, 1] and summed over the last axis.
        """
        assert self.value.dim() == 2, \
            "value's dim must be 2, actual: {}".format(self.value.dim())
        assert self.task_require is not None, "task_require is required"

        projected = self.value[:, None, :] + self.task_require2
        return projected.clamp(0, 1).sum(2, dtype=projected.dtype)
|
|
|
|
|
def worker_task_sequence(name):
    """Declare a ``WorkerTaskSequence``.

    Returns a ``VarMeta`` that will construct a ``WorkerTaskSequence`` with
    the given ``name`` once it is called.
    """
    spec = VarMeta(WorkerTaskSequence, name=name)
    return spec
|
|
|
|
|
class WorkerTaskSequence:
    """Holder for the final worker/task sequence."""

    def __init__(self, name, problem, batch_size, sample_num, worker_num, task_num):
        """Store ``name``; ``value`` stays ``None`` until ``step_finish``."""
        self.name, self.value = name, None

    def step_finish(self, worker_task_seq):
        """Record ``worker_task_seq`` as the variable's value."""
        self.value = worker_task_seq
|
|