# GreedRL / greedrl / encode.py
# 先坤
# add greedrl
# db26c81
import math
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from .norm import Norm1D, Norm2D
from .dense import Dense
from .utils import repeat
from .feature import *
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional additive score bias.

    Each head owns its own query/key/value projection of shape
    ``(hidden_dim, hidden_dim // heads)``; the concatenated head outputs are
    mixed by a final ``(hidden_dim, hidden_dim)`` matrix.
    """

    def __init__(self, heads, hidden_dim):
        """Create the projection weights.

        Args:
            heads: number of attention heads; must divide ``hidden_dim``.
            hidden_dim: feature width of the input and of the output.
        """
        super().__init__()
        assert hidden_dim % heads == 0

        per_head = hidden_dim // heads
        self.heads = heads
        # Scale applied to the raw query/key dot products.
        self.alpha = 1 / math.sqrt(per_head)

        # Registration order (Q, K, V, O) fixes the order in which the
        # initialisation loop below draws random numbers.
        self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, per_head))
        self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, per_head))
        self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, per_head))
        self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))

        # Uniform init in [-1/sqrt(fan), 1/sqrt(fan)], fan = last dimension.
        for weight in self.parameters():
            bound = 1. / math.sqrt(weight.size(-1))
            nn.init.uniform_(weight, -bound, bound)

    def forward(self, x, edge):
        """Apply self-attention over the item axis.

        Args:
            x: tensor of shape ``(batch, items, hidden_dim)``.
            edge: ``None``, or a tensor that can be viewed as
                ``(heads * batch, items, items)``; it is added to the scaled
                attention scores before the softmax.

        Returns:
            Tensor of shape ``(batch, items, hidden_dim)``.
        """
        batch_size, item_num, hidden_dim = x.size()
        heads_batch = self.heads * batch_size
        flat = x.reshape(-1, hidden_dim)

        # Every projection yields (heads, batch * items, head_dim); queries and
        # keys are folded to (heads * batch, items, head_dim) for the batched
        # matmul, values keep a separate heads axis.
        queries = torch.matmul(flat, self.nn_Q).view(heads_batch, item_num, -1)
        keys_t = torch.matmul(flat, self.nn_K).view(heads_batch, item_num, -1)
        keys_t = keys_t.transpose(1, 2)
        values = torch.matmul(flat, self.nn_V).view(
            self.heads, batch_size, item_num, -1)

        if edge is None:
            # No bias: accumulate in place into a fresh zero buffer.
            scores = queries.new_zeros(heads_batch, item_num, item_num)
            scores = scores.baddbmm_(queries, keys_t, alpha=self.alpha)
        else:
            # Out-of-place so the caller's edge tensor is left untouched.
            bias = edge.view(heads_batch, item_num, item_num)
            scores = bias.baddbmm(queries, keys_t, alpha=self.alpha)

        weights = F.softmax(
            scores.view(self.heads, batch_size, item_num, item_num), dim=-1)

        # (heads, batch, items, head_dim) -> (batch, items, heads * head_dim)
        mixed = torch.matmul(weights, values).permute(1, 2, 0, 3)
        mixed = mixed.reshape(batch_size, item_num, -1)
        return torch.matmul(mixed, self.nn_O)
class Encode(nn.Module):
    """Attention encoder over worker-start, worker-end and task items.

    Worker and task features are embedded to ``encode_hidden_dim``,
    concatenated along the item axis as ``[worker_start, worker_end, task]``,
    and passed through ``encode_layers`` blocks of self-attention plus a
    feed-forward network. Optional edge features are embedded to one channel
    per attention head and handed to every attention layer, which adds them to
    its scores. A learned ``nn_finish`` item is appended after the last layer.
    """

    def __init__(self, nn_args):
        """Build all sub-modules from ``nn_args``.

        Args:
            nn_args: dict-like configuration. Required keys: ``worker_dim``,
                ``task_dim``, ``edge_dim`` (feature name -> width),
                ``embed_dict`` (feature name -> embedding vocabulary size) and
                ``feature_dict`` (feature name -> feature descriptor).
                Optional keys are filled in with ``setdefault`` (so
                ``nn_args`` is mutated): ``encode_layers`` (3),
                ``encode_atten_heads`` (8), ``encode_norm`` ('instance'),
                ``encode_hidden_dim`` (128), ``decode_hidden_dim`` (128),
                ``decode_atten_heads`` (0).
        """
        super().__init__()

        self.nn_args = nn_args
        self.worker_dim = nn_args['worker_dim']
        self.task_dim = nn_args['task_dim']
        self.edge_dim = nn_args['edge_dim']
        self.embed_dict = nn_args['embed_dict']
        self.feature_dict = nn_args['feature_dict']

        layers = nn_args.setdefault('encode_layers', 3)
        heads = nn_args.setdefault('encode_atten_heads', 8)
        norm = nn_args.setdefault('encode_norm', 'instance')
        hidden_dim = nn_args.setdefault('encode_hidden_dim', 128)
        output_dim = nn_args.setdefault('decode_hidden_dim', 128)
        output_heads = nn_args.setdefault('decode_atten_heads', 0)

        self.heads = heads
        self.layers = layers

        # With no features configured, encode_worker/encode_task feed a single
        # constant channel, hence the floor of 1.
        worker_dim = max(1, sum(self.worker_dim.values()))
        task_dim = max(1, sum(self.task_dim.values()))

        self.nn_dense_worker_start = Dense(worker_dim, hidden_dim)
        self.nn_dense_worker_end = Dense(worker_dim, hidden_dim)
        self.nn_dense_task = Dense(task_dim, hidden_dim)
        self.nn_norm_worker_task = Norm1D(hidden_dim, norm, True)

        if len(self.edge_dim) > 0:
            # Edge features are mapped to one channel per attention head.
            edge_dim = sum(self.edge_dim.values())
            self.nn_dense_edge = Dense(edge_dim, heads)
            self.nn_norm_edge = Norm2D(heads, norm, True)

        nn_embed_dict = {}
        for k, v in self.embed_dict.items():
            nn_embed_dict[k] = nn.Embedding(v, hidden_dim)
        self.nn_embed_dict = nn.ModuleDict(nn_embed_dict)

        self.nn_attens = nn.ModuleList()
        self.nn_denses = nn.ModuleList()
        self.nn_norms1 = nn.ModuleList()
        self.nn_norms2 = nn.ModuleList()
        for i in range(layers):
            self.nn_attens.append(MultiHeadAttention(heads, hidden_dim))
            self.nn_denses.append(nn.Sequential(
                Dense(hidden_dim, hidden_dim * 4),
                Dense(hidden_dim * 4, hidden_dim, act='relu'),
            ))
            self.nn_norms1.append(Norm1D(hidden_dim, norm, True))
            self.nn_norms2.append(Norm1D(hidden_dim, norm, True))

        # Learned extra item appended to the encoded sequence in forward().
        self.nn_finish = nn.Parameter(torch.Tensor(1, 1, hidden_dim))

        # Optional projection from the encoder width to the decoder width.
        if output_dim != hidden_dim:
            self.nn_X = nn.Parameter(torch.Tensor(hidden_dim, output_dim))
        else:
            self.nn_X = None

        if output_heads > 0:
            assert output_dim % output_heads == 0
            head_dim = output_dim // output_heads
            # forward() applies these to X *after* the optional nn_X
            # projection, whose last dimension is output_dim, so the weights
            # must consume output_dim features (identical to hidden_dim when
            # no projection is configured).
            # NOTE(review): the head axis uses the encoder's `heads`, not
            # `output_heads`; confirm against the decoder before changing.
            self.nn_K = nn.Parameter(torch.Tensor(heads, output_dim, head_dim))
            self.nn_V = nn.Parameter(torch.Tensor(heads, output_dim, head_dim))
        else:
            self.nn_K = None
            self.nn_V = None

        # Re-initialise every parameter, including those of sub-modules,
        # uniformly in [-1/sqrt(fan), 1/sqrt(fan)], fan = last dimension.
        for param in self.parameters():
            stdv = 1 / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, problem, batch_size, worker_num, task_num, memopt=0):
        """Encode a batch of problems.

        Args:
            problem: mapping from feature name to tensor.
            batch_size: number of problems in the batch.
            worker_num: number of workers per problem.
            task_num: number of tasks per problem.
            memopt: memory-optimisation level; while training, ``6``
                checkpoints each attention call and values above ``6``
                checkpoint each whole encoder layer.

        Returns:
            ``(X, K, V)``: ``X`` is the encoded item sequence with the finish
            item appended (projected by ``nn_X`` when configured); ``K`` and
            ``V`` are per-head projections of ``X`` when
            ``decode_atten_heads > 0``, otherwise empty placeholder tensors.
        """
        worker_start, worker_end = self.encode_worker(problem, batch_size, worker_num)
        task = self.encode_task(problem, batch_size, task_num)
        X = torch.cat([worker_start, worker_end, task], 1)
        X = self.nn_norm_worker_task(X)

        if len(self.edge_dim) > 0:
            edge = self.encode_edge(problem, batch_size, worker_num, task_num)
            edge = self.nn_norm_edge(edge)
            # Move the head channel first and make it contiguous so the
            # attention layers can .view() it as (heads * batch, N, N).
            edge = edge.permute(3, 0, 1, 2).contiguous()
        else:
            edge = None

        # transformer encoding
        for i in range(self.layers):
            X = self.encode_layer(X, edge, i, memopt)

        # presumably `repeat` tiles nn_finish along the batch axis - confirm
        # against greedrl.utils.
        finish = repeat(self.nn_finish, X.size(0))
        X = torch.cat([X, finish], 1)
        if self.nn_X is not None:
            X = torch.matmul(X, self.nn_X)

        if self.nn_K is not None:
            batch_size, item_num, hidden_dim = X.size()
            size = (self.heads, batch_size, item_num, -1)
            X2 = X.reshape(-1, hidden_dim)
            K = torch.matmul(X2, self.nn_K).view(size)
            V = torch.matmul(X2, self.nn_V).view(size)
        else:
            # Empty placeholders keep the return arity fixed.
            K = torch.ones(0)
            V = torch.ones(0)

        return X, K, V

    def encode_layer(self, X, edge, i, memopt):
        """Run encoder layer ``i``, checkpointing it whole if ``memopt > 6``
        while training."""
        run_fn = self.encode_layer_fn(i, memopt)
        if self.training and memopt > 6:
            return checkpoint(run_fn, X, edge)
        else:
            return run_fn(X, edge)

    def encode_layer_fn(self, i, memopt):
        """Return a closure computing encoder layer ``i``.

        The closure applies attention and the feed-forward block, each with a
        residual connection followed by normalisation. While training with
        ``memopt == 6`` only the attention call is checkpointed.
        """
        def run_fn(X, edge):
            if self.training and memopt == 6:
                X = X + checkpoint(self.nn_attens[i], X, edge)
            else:
                X = X + self.nn_attens[i](X, edge)
            X = self.nn_norms1[i](X)
            X = X + self.nn_denses[i](X)
            X = self.nn_norms2[i](X)
            return X
        return run_fn

    def encode_worker(self, problem, batch_size, worker_num):
        """Embed worker features into start and end representations.

        ``GlobalCategory`` features are looked up in the embedding table,
        ``ContinuousFeature`` values are used directly. When no worker feature
        is configured a constant all-ones channel is used instead.

        Returns:
            ``(start, end)``: the same feature tensor passed through
            ``nn_dense_worker_start`` and ``nn_dense_worker_end``.

        Raises:
            Exception: for a feature type that is not supported.
            AssertionError: if a feature's width differs from its configured
                width.
        """
        feature_list = []
        for k, dim in self.worker_dim.items():
            f = self.feature_dict.get(k)
            if isinstance(f, GlobalCategory):
                v = problem[f.name]
                v = self.nn_embed_dict[k](v.long())
            elif isinstance(f, ContinuousFeature):
                v = problem[f.name]
            else:
                raise Exception("unsupported feature type: {}".format(type(f)))

            # Scalar-per-item features get an explicit channel axis.
            if v.dim() == 2:
                v = v[:, :, None]

            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
            feature_list.append(v.float())

        if feature_list:
            x = torch.cat(feature_list, 2)
        else:
            x = self.nn_finish.new_ones(batch_size, worker_num, 1)

        return self.nn_dense_worker_start(x), self.nn_dense_worker_end(x)

    def encode_task(self, problem, batch_size, task_num):
        """Embed task features through ``nn_dense_task``.

        ``SparseLocalFeature`` (read from ``f.value``) and ``LocalFeature``
        tensors must be 3-D; they are clamped to ``[0, 1]`` and summed over
        their last axis into one channel. ``GlobalCategory`` features are
        looked up in the embedding table and ``ContinuousFeature`` values are
        used directly. When no task feature is configured a constant all-ones
        channel is used instead.

        Raises:
            Exception: for a feature type that is not supported.
            AssertionError: on a rank or width mismatch.
        """
        feature_list = []
        for k, dim in self.task_dim.items():
            f = self.feature_dict.get(k)
            if isinstance(f, SparseLocalFeature):
                v = problem[f.value]
                assert v.dim() == 3, \
                    "sparse local feature's dimension must 2, feature:{}".format(k)
                v = v.clamp(0, 1).sum(2, dtype=v.dtype)
            elif isinstance(f, GlobalCategory):
                v = problem[f.name]
                v = self.nn_embed_dict[k](v.long())
            elif isinstance(f, LocalFeature):
                v = problem[f.name]
                assert v.dim() == 3, \
                    "local feature's dimension must 2, feature:{}".format(k)
                v = v.clamp(0, 1).sum(2, dtype=v.dtype)
            elif isinstance(f, ContinuousFeature):
                v = problem[f.name]
            else:
                raise Exception("unsupported feature type: {}".format(type(f)))

            # Scalar-per-item features get an explicit channel axis.
            if v.dim() == 2:
                v = v[:, :, None]

            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
            feature_list.append(v.float())

        if feature_list:
            x = torch.cat(feature_list, 2)
        else:
            x = self.nn_finish.new_ones(batch_size, task_num, 1)

        return self.nn_dense_task(x)

    def encode_edge(self, problem, batch_size, worker_num, task_num):
        """Build pairwise edge features and embed them via ``nn_dense_edge``.

        Every feature is laid out on an ``(NP, N, N[, D])`` grid with
        ``N = 2 * worker_num + task_num``, matching the item order
        ``[worker_start, worker_end, task]`` used in ``forward``:

        * ``LocalCategory``: 1 where two tasks share the same non-negative
          category, written to the task/task corner.
        * ``LocalFeature``: cosine similarity between task vectors, written to
          the task/task corner.
        * ``SparseLocalFeature``: cosine similarity computed with sparse
          matrices from ``f.index`` / ``f.value``; requires a batch of one.
        * ``ContinuousFeature`` named ``*_matrix``: used as given.
        * ``ContinuousFeature`` named ``worker_task_*``: copied symmetrically
          into the worker-start/task and worker-end/task positions.

        Raises:
            Exception: for an unsupported feature type or feature name.
            AssertionError: on a name-prefix, batch-size or width mismatch.
        """
        NP = batch_size
        NW = worker_num
        NT = task_num
        NWW = NW + NW  # worker-start items followed by worker-end items
        feature_list = []
        for k, dim in self.edge_dim.items():
            f = self.feature_dict.get(k)
            if isinstance(f, LocalCategory):
                assert f.name.startswith("task_")
                v = problem[k]
                v1 = v[:, :, None]
                v2 = v[:, None, :]
                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=v.dtype, device=v.device)
                v[:, NWW:, NWW:] = ((v1 == v2) & (v1 >= 0))
            elif isinstance(f, LocalFeature):
                assert f.name.startswith("task_")
                v = problem[k].float()
                dot_product = torch.matmul(v, v.transpose(-1, -2))
                # Small epsilon guards against division by zero.
                v_norm = v.norm(dim=2) + 1e-10
                v1_norm = v_norm[:, :, None]
                v2_norm = v_norm[:, None, :]
                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=v.dtype, device=v.device)
                v[:, NWW:, NWW:] = dot_product / v1_norm / v2_norm
            elif isinstance(f, SparseLocalFeature):
                assert NP == 1
                assert f.index.startswith("task_")
                assert f.value.startswith("task_")

                index = problem[f.index]
                value = problem[f.value].float()

                # Number of sparse columns needed to hold the largest index.
                NV = index.max().item() + 1
                spv = value.reshape(-1).tolist()
                spi = index.reshape(-1).tolist()
                device = value.device

                # Row (task) index paired with every entry of `index`.
                spj = torch.arange(NT, device=device)
                spj = spj[:, None].expand_as(index)
                spj = spj.reshape(-1).tolist()

                # value1 is (NT, NV); value2 is its transpose.
                value1 = torch.sparse_coo_tensor([spj, spi], spv, (NT, NV), device=device)
                value2 = torch.sparse_coo_tensor([spi, spj], spv, (NV, NT), device=device)
                value1 = value1.coalesce()
                value2 = value2.coalesce()

                cosine = torch.sparse.mm(value1, value2).to_dense()
                norm = value.norm(dim=-1).reshape(-1)
                norm1 = norm[:, None].expand(-1, NT)
                norm2 = norm[None, :].expand(NT, -1)
                cosine = cosine / (norm1 * norm2 + 1e-10)

                v = torch.zeros(NP, NWW + NT, NWW + NT,
                                dtype=value.dtype, device=value.device)
                v[:, NWW:, NWW:] = cosine
            elif isinstance(f, ContinuousFeature):
                if f.name.endswith("_matrix"):
                    v = problem[k]
                elif f.name.startswith("worker_task_"):
                    v = problem[k]
                    if v.dim() == 3:
                        new_v = torch.zeros(NP, NWW + NT, NWW + NT,
                                            dtype=v.dtype, device=v.device)
                    else:
                        new_v = torch.zeros(NP, NWW + NT, NWW + NT, v.size(3),
                                            dtype=v.dtype, device=v.device)

                    problem_index = torch.arange(NP, device=v.device)[:, None, None]
                    worker_index = torch.arange(NW, device=v.device)[None, :, None]
                    task_index = torch.arange(NT, device=v.device)[None, None, :] + NW + NW

                    # Worker-start rows/columns, then worker-end (offset NW),
                    # each written in both directions.
                    new_v[problem_index, worker_index, task_index] = v
                    new_v[problem_index, task_index, worker_index] = v
                    new_v[problem_index, worker_index + NW, task_index] = v
                    new_v[problem_index, task_index, worker_index + NW] = v
                    v = new_v
                else:
                    raise Exception("feature: {}".format(f.name))
            else:
                raise Exception("feature: {}, type: {}".format(k, type(f)))

            # Scalar-per-pair features get an explicit channel axis.
            if v.dim() == 3:
                v = v[:, :, :, None]

            assert dim == v.size(-1), \
                "feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
            feature_list.append(v.float())

        x = torch.cat(feature_list, 3)
        return self.nn_dense_edge(x)