# Hugging Face Spaces page header captured with this file (status: Sleeping, Sleeping).
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.distributions.categorical import Categorical | |
import torch_ac | |
from utils.other import init_params | |
class MMMemoryMultiHeadedACModel(nn.Module, torch_ac.RecurrentACModel):
    """Recurrent actor-critic with a primitive-action head and several utterance heads.

    The observation image (and, if enabled, a dialogue token sequence) is
    embedded, passed through an LSTM cell, and the resulting hidden state is fed to:

    * ``actor``  - logits over the primitive actions plus one extra "talk" action
      (the last index, ``self.talk_action``);
    * ``talker`` - one head per utterance component, sized by ``action_space.nvec[1:]``;
    * ``critic`` - a scalar state-value estimate.

    Memory is mandatory (``use_memory=True``). Raw text input is rejected in
    favour of ``use_dialogue``.
    """

    def __init__(self, obs_space, action_space, use_memory=False, use_text=False, use_dialogue=False):
        """Build the network.

        Args:
            obs_space: mapping with an ``"image"`` entry whose first two items
                are used as the image's spatial size. When dialogue is enabled it
                must also have a ``"text"`` entry, which is used as the vocabulary size.
            action_space: multi-dimensional action space exposing ``nvec``.
                ``nvec[0]`` is the number of primitive actions and ``nvec[1:]``
                are the sizes of the utterance components (presumably a gym
                ``MultiDiscrete`` - TODO confirm).
            use_memory: must be True; this model only exists in recurrent form.
            use_text: must be False; use ``use_dialogue`` instead.
            use_dialogue: embed ``obs.dialogue`` with a GRU and concatenate the
                result to the image embedding.

        Raises:
            ValueError: if memory is disabled, if text is enabled, or if the
                action space is scalar (``shape == ()``).
        """
        super().__init__()

        # Decide which components are enabled
        self.use_text = use_text
        self.use_dialogue = use_dialogue
        self.use_memory = use_memory

        if not self.use_memory:
            raise ValueError("You should not be using this model. Use MultiHeadedACModel instead")

        if self.use_text:
            raise ValueError("You should not use text but dialogue.")

        # multi dim
        if action_space.shape == ():
            raise ValueError("The action space is not multi modal. Use ACModel instead.")

        # The actor gets one extra output (the last index) that means "talk".
        self.n_primitive_actions = action_space.nvec[0] + 1  # for talk
        self.talk_action = int(self.n_primitive_actions) - 1
        self.n_utterance_actions = action_space.nvec[1:]

        # Define image embedding (expects 3 input channels)
        self.image_conv = nn.Sequential(
            nn.Conv2d(3, 16, (2, 2)),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU(),
        )
        n = obs_space["image"][0]
        m = obs_space["image"][1]
        # Spatial size after conv(2x2) -> maxpool(2) -> conv(2x2) -> conv(2x2),
        # times 64 output channels.
        self.image_embedding_size = ((n - 1) // 2 - 2) * ((m - 1) // 2 - 2) * 64

        if self.use_text or self.use_dialogue:
            self.word_embedding_size = 32
            self.word_embedding = nn.Embedding(obs_space["text"], self.word_embedding_size)

        # Define text embedding
        if self.use_text:
            self.text_embedding_size = 128
            self.text_rnn = nn.GRU(self.word_embedding_size, self.text_embedding_size, batch_first=True)

        # Define dialogue embedding
        if self.use_dialogue:
            self.dialogue_embedding_size = 128
            self.dialogue_rnn = nn.GRU(self.word_embedding_size, self.dialogue_embedding_size, batch_first=True)

        # Total embedding width fed to the LSTM and the heads
        self.embedding_size = self.image_embedding_size
        if self.use_text:
            self.embedding_size += self.text_embedding_size
        if self.use_dialogue:
            self.embedding_size += self.dialogue_embedding_size

        if self.use_memory:
            self.memory_rnn = nn.LSTMCell(self.embedding_size, self.embedding_size)

        # Define actor's model
        self.actor = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, int(self.n_primitive_actions)),
        )

        # One utterance head per utterance component of the action space
        self.talker = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.embedding_size, 64),
                nn.Tanh(),
                nn.Linear(64, int(n_utterance)),
            )
            for n_utterance in self.n_utterance_actions
        ])

        # Define critic's model
        self.critic = nn.Sequential(
            nn.Linear(self.embedding_size, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )

        # Initialize parameters correctly
        self.apply(init_params)

    @property
    def memory_size(self):
        """Total recurrent-state width: LSTM hidden state followed by cell state."""
        # Must be a property: torch_ac reads it as an int, and this class uses it
        # in arithmetic.
        return 2 * self.semi_memory_size

    @property
    def semi_memory_size(self):
        """Width of one half of the recurrent state (equal to ``embedding_size``)."""
        # Must be a property: forward() uses it directly as a slice bound.
        return self.embedding_size

    def forward(self, obs, memory):
        """Run one recurrent step.

        Args:
            obs: batch object with an ``image`` tensor (presumably
                (batch, H, W, 3) - TODO confirm). When dialogue is enabled it
                must also have a ``dialogue`` token-id tensor.
            memory: (batch, memory_size) tensor holding the LSTM hidden state
                concatenated with the cell state.

        Returns:
            A tuple ``(dist, value, memory)``:
            * ``dist`` - list of Categorical distributions: the primitive-action
              distribution first, then one per utterance head;
            * ``value`` - (batch,) critic output;
            * ``memory`` - the updated (batch, memory_size) recurrent state.
        """
        # Move the channel axis: (B, d1, d2, C) -> (B, C, d1, d2)
        x = obs.image.transpose(1, 3).transpose(2, 3)
        x = self.image_conv(x)
        batch_size = x.shape[0]
        x = x.reshape(batch_size, -1)

        embedding = x

        if self.use_text:
            embed_text = self._get_embed_text(obs.text)
            embedding = torch.cat((embedding, embed_text), dim=1)

        if self.use_dialogue:
            embed_dial = self._get_embed_dialogue(obs.dialogue)
            embedding = torch.cat((embedding, embed_dial), dim=1)

        if self.use_memory:
            # memory = [h | c]; split it, step the LSTM cell, then re-pack it.
            hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:])
            hidden = self.memory_rnn(embedding, hidden)
            embedding = hidden[0]
            memory = torch.cat(hidden, dim=1)

        x = self.actor(embedding)
        primitive_actions_dist = Categorical(logits=F.log_softmax(x, dim=1))

        x = self.critic(embedding)
        value = x.squeeze(1)

        utterance_actions_dists = [
            Categorical(logits=F.log_softmax(talker_head(embedding), dim=1))
            for talker_head in self.talker
        ]

        dist = [primitive_actions_dist] + utterance_actions_dists

        return dist, value, memory

    def sample_action(self, dist):
        """Sample every distribution in ``dist`` and stack the samples into (batch, len(dist))."""
        return torch.stack([d.sample() for d in dist], dim=1)

    def calculate_log_probs(self, dist, action):
        """Return (batch, len(dist)) log-probs; column ``i`` scores ``action[:, i]`` under ``dist[i]``."""
        return torch.stack([d.log_prob(action[:, i]) for i, d in enumerate(dist)], dim=1)

    def calculate_action_masks(self, action):
        """Build a boolean mask marking which action columns are in effect.

        The primitive column is always on. Every utterance column is on only
        where the primitive action is the talk action.

        Args:
            action: (batch, 1 + n_utterance_heads) tensor of action indices.

        Returns:
            A bool tensor with the same shape as ``action``.
        """
        talk_mask = action[:, 0] == self.talk_action
        # One "always on" column for the primitive action, then one talk-gated
        # column per utterance head (no longer hard-coded to two heads).
        columns = [torch.ones_like(talk_mask)] + [talk_mask] * len(self.n_utterance_actions)
        mask = torch.stack(columns, dim=1).detach()
        assert action.shape == mask.shape
        return mask

    def construct_final_action(self, action):
        """NaN out the action components that are not in effect.

        A row whose primitive action is not "talk" keeps only its primitive
        column. A "talk" row keeps only its utterance columns. Discarded
        entries become NaN.

        Args:
            action: (batch, 1 + n_utterance_heads) array of action indices.

        Returns:
            A float array with the same shape as ``action``.
        """
        is_primitive = np.asarray(action[:, 0] != self.talk_action, dtype=bool)
        n_utterance = len(self.n_utterance_actions)
        keep_primitive = np.array([1.0] + [np.nan] * n_utterance)
        keep_utterance = np.array([np.nan] + [1.0] * n_utterance)
        # Vectorized row selection; this also yields a correctly shaped
        # (0, k) mask for an empty batch.
        nan_mask = np.where(is_primitive[:, None], keep_primitive, keep_utterance)
        return nan_mask * action

    def _get_embed_text(self, text):
        """Embed token ids with the text GRU and return the last layer's final hidden state."""
        _, hidden = self.text_rnn(self.word_embedding(text))
        return hidden[-1]

    def _get_embed_dialogue(self, dial):
        """Embed token ids with the dialogue GRU and return the last layer's final hidden state."""
        _, hidden = self.dialogue_rnn(self.word_embedding(dial))
        return hidden[-1]