import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from Model.CLIP.cn_clip.clip import load_from_name
import cn_clip.clip as clip
from PIL import Image
from fur_rl.models.StageOne import Stage1
import copy

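# This module collects the Q-network / policy-head variants ("Retriever*") used
# for CLIP-based interactive retrieval, together with the DQN and actor-critic
# trainers that drive them. The shape notes in comments below are inferred from
# the layer definitions (512-d CLIP embeddings, 50- or 20-way candidate sets)
# and should be read as assumptions about the surrounding project, not
# guarantees.
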
class Retriever(nn.Module):
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        super(Retriever, self).__init__()
        self.sentence_clip_model2 = sentence_clip_model2
        self.sentence_clip_preprocess2 = sentence_clip_preprocess2
        self.tokenize = clip.tokenize
        self.Stage1 = Stage1(hid_dim=512)
        self.conv1 = nn.Conv1d(in_channels=52, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=2)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=2)
        self.conv4 = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.fc1 = nn.Linear(in_features=516, out_features=50, bias=True)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.device = device

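    # Q-head layout (inferred from the layer shapes): the 52 Conv1d input
    # channels appear to be 1 fused state feature + 50 candidate features + 1
    # hidden state, each of length 512. conv1 (kernel 5, padding 2) preserves
    # the length, while conv2 and conv3 (kernel 3, padding 2) each add 2
    # (512 -> 514 -> 516), which is why fc1 expects 516 inputs. Note that
    # get_env_feature() returns the hidden state from *before* the Stage1
    # update, and forward() first overwrites Stage1.hx with the caller-supplied
    # hidden state.
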
    def img_embed(self, img):
        img_embed = self.sentence_clip_model2.encode_image(img)
        return img_embed

    def txt_embed(self, txt_token):
        txt_embed = self.sentence_clip_model2.encode_text(txt_token)
        return txt_embed

    def get_env_feature(self, img, txt):
        hx = self.Stage1.hx
        x = self.Stage1(img, txt)
        return x, hx

    def init_hid(self, batch_size=1):
        self.Stage1.init_hid(batch_size=batch_size)

    def detach_hid(self):
        self.Stage1.detach_hid()

    def get_Q(self, img, txt, action):
        env_feature, hx = self.get_env_feature(img, txt)
        env_feature = env_feature.unsqueeze(dim=1)
        hx = hx.unsqueeze(dim=1)
        x = torch.cat((env_feature, action, hx), dim=1)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        return x

    def forward(self, img, txt, action, hx):
        self.Stage1.hx = hx
        Q = self.get_Q(img, txt, action)
        return Q


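# Variant of Retriever that takes raw images and text in forward() and runs
# the CLIP encoders itself, scoring a 20-way candidate set (22 input channels
# = 1 state feature + 20 candidate embeddings + 1 hidden state).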
class Retriever_v2(nn.Module):
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        super(Retriever_v2, self).__init__()
        self.sentence_clip_model2 = sentence_clip_model2
        self.sentence_clip_preprocess2 = sentence_clip_preprocess2
        self.tokenize = clip.tokenize
        self.Stage1 = Stage1(hid_dim=512)
        self.conv1 = nn.Conv1d(in_channels=22, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=2)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=2)
        self.conv4 = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.fc1 = nn.Linear(in_features=516, out_features=20, bias=True)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.device = device

    def img_embed(self, img):
        img_embed = self.sentence_clip_model2.encode_image(img)
        return img_embed

    def txt_embed(self, txt_token):
        txt_embed = self.sentence_clip_model2.encode_text(txt_token)
        return txt_embed

    def get_env_feature(self, img, txt):
        hx = self.Stage1.hx
        x = self.Stage1(img, txt)
        return x, hx

    def init_hid(self, batch_size=1):
        self.Stage1.init_hid(batch_size=batch_size)

    def detach_hid(self):
        self.Stage1.detach_hid()

    def get_Q(self, img, txt, action):
        env_feature, hx = self.get_env_feature(img, txt)
        env_feature = env_feature.unsqueeze(dim=1)
        hx = hx.unsqueeze(dim=1)
        x = torch.cat((env_feature, action, hx), dim=1)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        return x

    def forward(self, img, txt, action, hx):
        self.Stage1.hx = hx
        img_embed = self.img_embed(img)
        txt_embed = self.txt_embed(self.tokenize(txt).to(self.device))
        c = []
        for i in range(action.shape[1]):
            action_i = self.img_embed(action[:, i])
            c.append(action_i)
        c = torch.stack(c, dim=1)
        Q = self.get_Q(img_embed, txt_embed, c)
        return Q


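# Reward-conditioned variant with batch norm; the softmax output makes this
# the policy ("actor") head in the actor-critic trainers below. Embeddings are
# L2-normalised, and Stage1 is called with a reward argument here.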
class Retriever_v3(nn.Module):
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        super(Retriever_v3, self).__init__()
        self.sentence_clip_model2 = sentence_clip_model2
        self.sentence_clip_preprocess2 = sentence_clip_preprocess2
        self.tokenize = clip.tokenize
        self.Stage1 = Stage1(hid_dim=512)
        self.conv1 = nn.Conv1d(in_channels=52, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=2)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=2)
        self.conv4 = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(32)
        self.bn3 = nn.BatchNorm1d(16)
        self.bn4 = nn.BatchNorm1d(1)
        self.fc1 = nn.Linear(in_features=516, out_features=50, bias=True)
        self.bn5 = nn.BatchNorm1d(50)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.device = device

    def img_embed(self, img):
        img_embed = self.sentence_clip_model2.encode_image(img)
        img_embed = img_embed / img_embed.norm(dim=-1, keepdim=True)
        return img_embed

    def txt_embed(self, txt_token):
        txt_embed = self.sentence_clip_model2.encode_text(txt_token)
        txt_embed = txt_embed / txt_embed.norm(dim=-1, keepdim=True)
        return txt_embed

    def get_env_feature(self, img, txt, reward):
        hx = self.Stage1.hx
        x = self.Stage1(img, txt, reward)
        return x, hx

    def init_hid(self, batch_size=1):
        self.Stage1.init_hid(batch_size=batch_size)

    def detach_hid(self):
        self.Stage1.detach_hid()

    def get_Q(self, img, txt, action, reward):
        env_feature, hx = self.get_env_feature(img, txt, reward)
        env_feature = env_feature.unsqueeze(dim=1)
        hx = hx.unsqueeze(dim=1)
        x = torch.cat((env_feature, action, hx), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = self.bn5(x)
        x = self.softmax(x)
        return x

    def forward(self, img, txt, action, hx, r):
        self.Stage1.hx = hx
        Q = self.get_Q(img, txt, action, r)
        return Q


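# Like Retriever but with a softmax output and without the hidden state in the
# concatenation (51 input channels = 1 state feature + 50 candidates; hx is
# still computed but unused in get_Q).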
class Retriever_v4(nn.Module):
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        super(Retriever_v4, self).__init__()
        self.sentence_clip_model2 = sentence_clip_model2
        self.sentence_clip_preprocess2 = sentence_clip_preprocess2
        self.tokenize = clip.tokenize
        self.Stage1 = Stage1(hid_dim=512)
        self.conv1 = nn.Conv1d(in_channels=51, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=2)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=2)
        self.conv4 = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.fc1 = nn.Linear(in_features=516, out_features=50, bias=True)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.device = device

    def img_embed(self, img):
        img_embed = self.sentence_clip_model2.encode_image(img)
        return img_embed

    def txt_embed(self, txt_token):
        txt_embed = self.sentence_clip_model2.encode_text(txt_token)
        return txt_embed

    def get_env_feature(self, img, txt):
        hx = self.Stage1.hx
        x = self.Stage1(img, txt)
        return x, hx

    def init_hid(self, batch_size=1):
        self.Stage1.init_hid(batch_size=batch_size)

    def detach_hid(self):
        self.Stage1.detach_hid()

    def get_Q(self, img, txt, action):
        env_feature, hx = self.get_env_feature(img, txt)
        env_feature = env_feature.unsqueeze(dim=1)
        hx = hx.unsqueeze(dim=1)
        x = torch.cat((env_feature, action), dim=1)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = self.softmax(x)
        return x

    def forward(self, img, txt, action, hx):
        self.Stage1.hx = hx
        Q = self.get_Q(img, txt, action)
        return Q


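# MLP scorer used as the critic in DQN / DQN_v3 / A_C_txt below: it scores a
# single candidate at a time, concatenating the fused state feature with a
# projected candidate embedding along the feature dimension (512 + 512 -> 1024)
# and regressing to one scalar per sample. As written, fc01 through fc04 are
# applied without intermediate nonlinearities, so the stack is effectively
# affine.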
class Retriever_v5(nn.Module):
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        super(Retriever_v5, self).__init__()
        self.sentence_clip_model2 = sentence_clip_model2
        self.sentence_clip_preprocess2 = sentence_clip_preprocess2
        self.tokenize = clip.tokenize
        self.Stage1 = Stage1(hid_dim=512)
        self.fc01 = nn.Linear(in_features=512 * 2, out_features=512, bias=True)
        self.fc02 = nn.Linear(in_features=512, out_features=256, bias=True)
        self.fc03 = nn.Linear(in_features=256, out_features=64, bias=True)
        self.fc04 = nn.Linear(in_features=64, out_features=1, bias=True)
        self.fc2 = nn.Linear(in_features=512, out_features=512, bias=True)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.device = device

    def img_embed(self, img):
        img_embed = self.sentence_clip_model2.encode_image(img)
        img_embed = img_embed / img_embed.norm(dim=-1, keepdim=True)
        return img_embed

    def txt_embed(self, txt_token):
        txt_embed = self.sentence_clip_model2.encode_text(txt_token)
        txt_embed = txt_embed / txt_embed.norm(dim=-1, keepdim=True)
        return txt_embed

    def get_env_feature(self, img, txt, reward):
        hx = self.Stage1.hx
        x = self.Stage1(img, txt, reward)
        x = x / x.norm(dim=-1, keepdim=True)
        return x, hx

    def init_hid(self, batch_size=1):
        self.Stage1.init_hid(batch_size=batch_size)

    def detach_hid(self):
        self.Stage1.detach_hid()

    def get_Q(self, img, txt, action, reward):
        env_feature, hx = self.get_env_feature(img, txt, reward)
        env_feature = env_feature.unsqueeze(dim=1)
        hx = hx.unsqueeze(dim=1)
        action = action.squeeze(dim=1)
        action = self.fc2(action)
        action = action.unsqueeze(dim=1)
        x = torch.cat((env_feature, action), dim=2)
        x = self.fc01(x)
        x = self.fc02(x)
        x = self.fc03(x)
        x = self.fc04(x)
        x = x.squeeze(dim=1)
        return x

    def forward(self, img, txt, action, hx, r):
        self.Stage1.hx = hx
        Q = self.get_Q(img, txt, action, r)
        return Q


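# Identical to Retriever except that the Q vector is passed through a softmax.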
class Retriever_v6(nn.Module):
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        super(Retriever_v6, self).__init__()
        self.sentence_clip_model2 = sentence_clip_model2
        self.sentence_clip_preprocess2 = sentence_clip_preprocess2
        self.tokenize = clip.tokenize
        self.Stage1 = Stage1(hid_dim=512)
        self.conv1 = nn.Conv1d(in_channels=52, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=2)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=2)
        self.conv4 = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.fc1 = nn.Linear(in_features=516, out_features=50, bias=True)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.device = device

    def img_embed(self, img):
        img_embed = self.sentence_clip_model2.encode_image(img)
        return img_embed

    def txt_embed(self, txt_token):
        txt_embed = self.sentence_clip_model2.encode_text(txt_token)
        return txt_embed

    def get_env_feature(self, img, txt):
        hx = self.Stage1.hx
        x = self.Stage1(img, txt)
        return x, hx

    def init_hid(self, batch_size=1):
        self.Stage1.init_hid(batch_size=batch_size)

    def detach_hid(self):
        self.Stage1.detach_hid()

    def get_Q(self, img, txt, action):
        env_feature, hx = self.get_env_feature(img, txt)
        env_feature = env_feature.unsqueeze(dim=1)
        hx = hx.unsqueeze(dim=1)
        x = torch.cat((env_feature, action, hx), dim=1)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = self.softmax(x)
        return x

    def forward(self, img, txt, action, hx):
        self.Stage1.hx = hx
        Q = self.get_Q(img, txt, action)
        return Q


MEMORY_CAPACITY = 100
N_STATES = 4
GAMMA = 0.95


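# Vanilla DQN trainer built on the Retriever_v5 single-candidate critic.
# Transitions carrying the sentinel reward -10000 are skipped, rewards are
# clipped at 100, and next-state values are zeroed once the reward reaches 95
# (treated as terminal). Note that this variant never copies eval_net into
# target_net, so the target network stays at its initial weights.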
class DQN():
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        self.eval_net = Retriever_v5(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.target_net = Retriever_v5(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.num_train = 0

        self.memory = []
        self.memory_counter = 0
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=0.01)
        self.loss_func = nn.MSELoss()

    def store_transition(self, g, f, c, hx, a, r, g_, f_, c_, hx_, batch_size):
        for i in range(batch_size):
            if r[i] == -10000:
                continue
            if r[i] > 100:
                r[i] = 100
            self.memory.append((g[i], f[i], c[i], hx[i], a[i], r[i], g_[i], f_[i], c_[i], hx_[i]))
            if len(self.memory) > MEMORY_CAPACITY:
                self.memory.pop(0)
            self.memory_counter += 1

    def learn(self, batch_size, device=None):
        if self.memory_counter > MEMORY_CAPACITY:
            sample_index = np.random.choice(MEMORY_CAPACITY, batch_size)
        else:
            sample_index = np.random.choice(self.memory_counter, batch_size)

        b_g, b_f, b_c, b_hx, b_a, b_r, b_g_, b_f_, b_c_, b_hx_ = self.memory[sample_index[0]]
        b_g = b_g.unsqueeze(dim=0)
        b_f = b_f.unsqueeze(dim=0)
        b_c = b_c.unsqueeze(dim=0)
        b_hx = b_hx.unsqueeze(dim=0)
        b_a = b_a.unsqueeze(dim=0)
        b_r = b_r.unsqueeze(dim=0)
        b_g_ = b_g_.unsqueeze(dim=0)
        b_f_ = b_f_.unsqueeze(dim=0)
        b_c_ = b_c_.unsqueeze(dim=0)
        b_hx_ = b_hx_.unsqueeze(dim=0)
        for i in range(1, len(sample_index)):
            g, f, c, hx, a, r, g_, f_, c_, hx_ = self.memory[sample_index[i]]
            b_g = torch.cat((b_g, g.unsqueeze(dim=0)), dim=0)
            b_f = torch.cat((b_f, f.unsqueeze(dim=0)), dim=0)
            b_c = torch.cat((b_c, c.unsqueeze(dim=0)), dim=0)
            b_hx = torch.cat((b_hx, hx.unsqueeze(dim=0)), dim=0)
            b_a = torch.cat((b_a, a.unsqueeze(dim=0)), dim=0)
            b_r = torch.cat((b_r, r.unsqueeze(dim=0)), dim=0)
            b_g_ = torch.cat((b_g_, g_.unsqueeze(dim=0)), dim=0)
            b_f_ = torch.cat((b_f_, f_.unsqueeze(dim=0)), dim=0)
            b_c_ = torch.cat((b_c_, c_.unsqueeze(dim=0)), dim=0)
            b_hx_ = torch.cat((b_hx_, hx_.unsqueeze(dim=0)), dim=0)

        b_a = b_a.unsqueeze(dim=1)
        b_r = b_r.unsqueeze(dim=1)

        # Pick out the embedding of the action actually taken in each sample.
        b_c1 = torch.zeros(b_c.shape[0], 1, b_c.shape[2]).to(device)
        for i in range(b_a.shape[0]):
            b_c1[i][0] = b_c[i][b_a[i][0]].to(device)
        # Retriever_v5 takes the reward as its fifth forward argument.
        q_eval = self.eval_net(b_g, b_f.detach(), b_c1, b_hx.detach(), b_r)
        q_next = torch.zeros((batch_size, len(b_c_[0]))).to(device)
        for i in range(len(b_c_[0])):
            q_temp = self.target_net(b_g_, b_f_, b_c_[:, i, :].unsqueeze(1), b_hx_, b_r).detach()
            q_next[:, i] = q_temp.squeeze(dim=1)
        for j in range(len(b_c_)):
            if b_r[j] >= 95:
                q_next[j, :] = torch.zeros((1, len(b_c_[0]))).to(device)
        q_target = b_r + GAMMA * q_next.max(1)[0].view(batch_size, 1)

        loss = self.loss_func(q_eval.float(), q_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss


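# DQN variant that batches text features as Python lists (b_fl / b_fl_) and
# syncs the target network every 5000 updates. Its gather-over-actions and
# row-wise-max usage expects a per-action Q vector from a 4-argument forward,
# i.e. the Retriever head rather than the single-candidate Retriever_v5 scorer.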
class DQN_v1():
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        # learn() below gathers per-action Q values and takes a row-wise max,
        # which matches the Retriever head (50-way Q vector, 4-argument
        # forward) rather than the single-candidate Retriever_v5 scorer.
        self.eval_net = Retriever(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.target_net = Retriever(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.num_train = 0

        self.memory = []
        self.memory_counter = 0
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=0.01)
        self.loss_func = nn.MSELoss()

    def store_transition(self, g, f, c, hx, a, r, g_, f_, c_, hx_, batch_size):
        for i in range(batch_size):
            if r[i] == -10000:
                continue
            self.memory.append((g[i], f[i], c[i], hx[i], a[i], r[i], g_[i], f_[i], c_[i], hx_[i]))
            if len(self.memory) > MEMORY_CAPACITY:
                self.memory.pop(0)
            self.memory_counter += 1

    def learn(self, batch_size, device=None):
        if self.memory_counter > MEMORY_CAPACITY:
            sample_index = np.random.choice(MEMORY_CAPACITY, batch_size)
        else:
            sample_index = np.random.choice(self.memory_counter, batch_size)

        b_g, b_f, b_c, b_hx, b_a, b_r, b_g_, b_f_, b_c_, b_hx_ = self.memory[sample_index[0]]
        b_g = b_g.unsqueeze(dim=0)
        b_fl = [b_f]
        b_c = b_c.unsqueeze(dim=0)
        b_hx = b_hx.unsqueeze(dim=0)
        b_a = b_a.unsqueeze(dim=0)
        b_r = b_r.unsqueeze(dim=0)
        b_g_ = b_g_.unsqueeze(dim=0)
        b_fl_ = [b_f_]
        b_c_ = b_c_.unsqueeze(dim=0)
        b_hx_ = b_hx_.unsqueeze(dim=0)
        for i in range(1, len(sample_index)):
            g, f, c, hx, a, r, g_, f_, c_, hx_ = self.memory[sample_index[i]]
            b_g = torch.cat((b_g, g.unsqueeze(dim=0)), dim=0)
            b_fl.append(f)
            b_c = torch.cat((b_c, c.unsqueeze(dim=0)), dim=0)
            b_hx = torch.cat((b_hx, hx.unsqueeze(dim=0)), dim=0)
            b_a = torch.cat((b_a, a.unsqueeze(dim=0)), dim=0)
            b_r = torch.cat((b_r, r.unsqueeze(dim=0)), dim=0)
            b_g_ = torch.cat((b_g_, g_.unsqueeze(dim=0)), dim=0)
            b_fl_.append(f_)
            b_c_ = torch.cat((b_c_, c_.unsqueeze(dim=0)), dim=0)
            b_hx_ = torch.cat((b_hx_, hx_.unsqueeze(dim=0)), dim=0)

        b_a = b_a.unsqueeze(dim=1)
        b_r = b_r.unsqueeze(dim=1)

        if self.num_train > 5000:
            self.target_net.load_state_dict(self.eval_net.state_dict())
            self.num_train = 0
        self.num_train += 1

        q_eval = self.eval_net(b_g, b_fl, b_c, b_hx.detach()).gather(1, b_a)
        with torch.no_grad():
            q_next = self.target_net(b_g_, b_fl_, b_c_, b_hx_).detach()
        q_target = b_r + GAMMA * q_next.max(1)[0].view(batch_size, 1)

        loss = self.loss_func(q_eval.float(), q_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss


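# One-step actor-critic: a Retriever critic trained with a DQN-style TD target
# (target net synced every 50 updates) and a Retriever_v3 policy head trained
# with a TD-error-weighted loss. Retriever_v3 outputs softmax probabilities,
# so `log_probs` here holds probabilities despite its name.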
class DQN_v2():
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        self.eval_net = Retriever(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.target_net = Retriever(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.actor_net = Retriever_v3(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.num_train = 0

        self.memory = []
        self.memory_counter = 0
        self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=0.001)
        self.critic_optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=0.001)
        self.loss_func = nn.MSELoss()

    def store_transition(self, g, f, c, hx, a, r, g_, f_, c_, hx_, batch_size):
        for i in range(batch_size):
            if r[i] == -10000:
                continue
            self.memory.append((g[i], f[i], c[i], hx[i], a[i], r[i], g_[i], f_[i], c_[i], hx_[i]))
            if len(self.memory) > MEMORY_CAPACITY:
                self.memory.pop(0)
            self.memory_counter += 1

    def learn(self, batch_size):
        if self.memory_counter > MEMORY_CAPACITY:
            sample_index = np.random.choice(MEMORY_CAPACITY, batch_size)
        else:
            sample_index = np.random.choice(self.memory_counter, batch_size)

        b_g, b_f, b_c, b_hx, b_a, b_r, b_g_, b_f_, b_c_, b_hx_ = self.memory[sample_index[0]]
        b_g = b_g.unsqueeze(dim=0)
        b_f = b_f.unsqueeze(dim=0)
        b_c = b_c.unsqueeze(dim=0)
        b_hx = b_hx.unsqueeze(dim=0)
        b_a = b_a.unsqueeze(dim=0)
        b_r = b_r.unsqueeze(dim=0)
        b_g_ = b_g_.unsqueeze(dim=0)
        b_f_ = b_f_.unsqueeze(dim=0)
        b_c_ = b_c_.unsqueeze(dim=0)
        b_hx_ = b_hx_.unsqueeze(dim=0)
        for i in range(1, len(sample_index)):
            g, f, c, hx, a, r, g_, f_, c_, hx_ = self.memory[sample_index[i]]
            b_g = torch.cat((b_g, g.unsqueeze(dim=0)), dim=0)
            b_f = torch.cat((b_f, f.unsqueeze(dim=0)), dim=0)
            b_c = torch.cat((b_c, c.unsqueeze(dim=0)), dim=0)
            b_hx = torch.cat((b_hx, hx.unsqueeze(dim=0)), dim=0)
            b_a = torch.cat((b_a, a.unsqueeze(dim=0)), dim=0)
            b_r = torch.cat((b_r, r.unsqueeze(dim=0)), dim=0)
            b_g_ = torch.cat((b_g_, g_.unsqueeze(dim=0)), dim=0)
            b_f_ = torch.cat((b_f_, f_.unsqueeze(dim=0)), dim=0)
            b_c_ = torch.cat((b_c_, c_.unsqueeze(dim=0)), dim=0)
            b_hx_ = torch.cat((b_hx_, hx_.unsqueeze(dim=0)), dim=0)

        b_a = b_a.unsqueeze(dim=1)
        b_r = b_r.unsqueeze(dim=1)

        if self.num_train > 50:
            self.target_net.load_state_dict(self.eval_net.state_dict())
            self.num_train = 0
        self.num_train += 1

        q_eval = self.eval_net(b_g, b_f.detach(), b_c, b_hx.detach()).gather(1, b_a)
        q_next = self.target_net(b_g_, b_f_, b_c_, b_hx_).detach()
        q_target = b_r + GAMMA * q_next.max(1)[0].view(batch_size, 1)

        critic_loss = torch.mean(F.mse_loss(q_eval.float(), q_target.float()))

        # Retriever_v3 is reward-conditioned, so the actor also receives b_r.
        log_probs = self.actor_net(b_g, b_f.detach(), b_c, b_hx.detach(), b_r)
        delta = q_target - q_eval

        actor_loss = torch.mean(-log_probs * delta.detach()).float()

        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()

        return actor_loss, critic_loss


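# Actor-critic trainer with a Retriever_v5 critic and Retriever_v3 actor over
# deep-copied transitions. Notable details: the target network is refreshed
# from eval_net on every learn() call after the first; next-state values and
# TD errors are overridden for large-magnitude rewards (treated as terminal);
# and the chosen action's probability is nudged below 1.0 before the log,
# apparently to keep the -log(p) policy loss from vanishing.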
class DQN_v3():
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device, MULTI_GPU=False, device_ids=None):
        self.eval_net = Retriever_v5(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.target_net = Retriever_v5(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.actor_net = Retriever_v3(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.num_train = 0
        self.device = device
        self.MULTI_GPU = MULTI_GPU
        self.device_ids = device_ids

        self.memory = []
        self.memory_counter = 0
        self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=0.0001)
        self.critic_optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=0.001)
        self.loss_func = nn.MSELoss()

    def store_transition(self, g, f, c, hx, a, r, g_, f_, c_, hx_, batch_size, net_mem):
        for i in range(batch_size):
            if r[i] <= -2000:
                continue
            g_tmp = copy.deepcopy(g[i])
            f_tmp = copy.deepcopy(f[i])
            c_tmp = copy.deepcopy(c[i])
            hx_tmp = copy.deepcopy(hx[i])
            a_tmp = copy.deepcopy(a[i])
            reward_temp = copy.deepcopy(r[i])
            g_tmp_ = copy.deepcopy(g_[i])
            f_tmp_ = copy.deepcopy(f_[i])
            c_tmp_ = copy.deepcopy(c_[i])
            hx_tmp_ = copy.deepcopy(hx_[i])
            self.memory.append((g_tmp, f_tmp, c_tmp, hx_tmp, a_tmp, reward_temp, g_tmp_, f_tmp_, c_tmp_, hx_tmp_))
            if len(self.memory) > net_mem:
                self.memory.pop(0)
        self.memory_counter = len(self.memory)

    def learn(self, batch_size, device=None):
        if self.memory_counter < batch_size:
            return 0, 0
        sample_index = np.random.choice(len(self.memory), batch_size, replace=False)
        g, f, c, hx, a, r, g_, f_, c_, hx_ = zip(*[self.memory[i] for i in sample_index])
        b_g = torch.stack(g)
        b_f = torch.stack(f)
        b_c = torch.stack(c)
        b_hx = torch.stack(hx)
        b_a = torch.stack(a).unsqueeze(1)
        b_r = torch.stack(r)
        b_g_ = torch.stack(g_)
        b_f_ = torch.stack(f_)
        b_c_ = torch.stack(c_)
        b_hx_ = torch.stack(hx_)

        if self.num_train > 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
            self.num_train = 0
        self.num_train += 1

        b_c1 = torch.zeros(b_c.shape[0], 1, b_c.shape[2]).to(device)
        for i in range(b_a.shape[0]):
            b_c1[i][0] = b_c[i][b_a[i][0]].to(device)
        q_eval = self.eval_net(b_g, b_f.detach(), b_c1, b_hx.detach(), b_r)
        q_next = torch.zeros((batch_size, len(b_c_[0]))).to(device)
        for i in range(len(b_c_[0])):
            q_temp = self.target_net(b_g_, b_f_, b_c_[:, i, :].unsqueeze(1), b_hx_, b_r).detach()
            q_next[:, i] = q_temp.squeeze(dim=1)
        for j in range(len(b_c_)):
            if b_r[j] >= 900 or b_r[j] <= -900:
                q_next[j, :] = torch.zeros((1, len(b_c_[0]))).to(device)
        q_target = b_r + 0.995 * q_next.max(1)[0].view(batch_size, 1)
        print("q_eval, q_target", q_eval[0], q_target[0], b_r[0], q_eval[0] - q_target[0])
        critic_loss = self.loss_func(q_eval.float(), q_target.float())

        log_probs = self.actor_net(b_g, b_f.detach(), b_c, b_hx.detach(), b_r)
        log_probs_temp = torch.zeros((batch_size, 1)).to(device)
        for i in range(b_a.shape[0]):
            log_probs_temp[i][0] = log_probs[i][b_a[i][0]].to(device)
            if log_probs_temp[i][0] > 0.9:
                log_probs_temp[i][0] -= 0.1

        delta = q_target - q_eval
        for j in range(len(b_c_)):
            if b_r[j] >= 100 or b_r[j] <= -900:
                delta[j] = b_r[j].to(device)
        delta = delta.detach()
        b_r = b_r.detach()

        actor_loss = torch.mean(-torch.log(log_probs_temp) * delta).float()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return actor_loss, critic_loss


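# Variant of the actor-critic trainer that also stores the tokenized query
# text with each transition; b_txt_token is assembled in learn() but not yet
# consumed by either loss, presumably to support text-conditioned extensions.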
class A_C_txt():
    def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
        self.eval_net = Retriever_v5(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.target_net = Retriever_v5(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.actor_net = Retriever_v3(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
        self.num_train = 0

        self.memory = []
        self.memory_counter = 0
        self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=0.0001)
        self.critic_optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=0.0001)
        self.loss_func = nn.MSELoss()

    def store_transition(self, g, f, c, hx, a, r, g_, f_, c_, hx_, batch_size, txt_token):
        for i in range(batch_size):
            if r[i] <= -1000:
                continue
            self.memory.append((g[i], f[i], c[i], hx[i], a[i], r[i], g_[i], f_[i], c_[i], hx_[i], txt_token[i]))
            if len(self.memory) > MEMORY_CAPACITY:
                self.memory.pop(0)
            self.memory_counter += 1

    def learn(self, batch_size, device=None):
        if self.memory_counter > MEMORY_CAPACITY:
            sample_index = np.random.choice(MEMORY_CAPACITY, batch_size)
        else:
            sample_index = np.random.choice(self.memory_counter, batch_size)

        b_g, b_f, b_c, b_hx, b_a, b_r, b_g_, b_f_, b_c_, b_hx_, txt_token = self.memory[sample_index[0]]
        b_g = b_g.unsqueeze(dim=0)
        b_f = b_f.unsqueeze(dim=0)
        b_c = b_c.unsqueeze(dim=0)
        b_hx = b_hx.unsqueeze(dim=0)
        b_a = b_a.unsqueeze(dim=0)
        b_r = b_r.unsqueeze(dim=0)
        b_g_ = b_g_.unsqueeze(dim=0)
        b_f_ = b_f_.unsqueeze(dim=0)
        b_c_ = b_c_.unsqueeze(dim=0)
        b_hx_ = b_hx_.unsqueeze(dim=0)
        b_txt_token = txt_token.unsqueeze(dim=0)
        for i in range(1, len(sample_index)):
            # Memory entries are 11-tuples here (they carry the text tokens),
            # so the unpacking must include the token as well.
            g, f, c, hx, a, r, g_, f_, c_, hx_, tok = self.memory[sample_index[i]]
            b_g = torch.cat((b_g, g.unsqueeze(dim=0)), dim=0)
            b_f = torch.cat((b_f, f.unsqueeze(dim=0)), dim=0)
            b_c = torch.cat((b_c, c.unsqueeze(dim=0)), dim=0)
            b_hx = torch.cat((b_hx, hx.unsqueeze(dim=0)), dim=0)
            b_a = torch.cat((b_a, a.unsqueeze(dim=0)), dim=0)
            b_r = torch.cat((b_r, r.unsqueeze(dim=0)), dim=0)
            b_g_ = torch.cat((b_g_, g_.unsqueeze(dim=0)), dim=0)
            b_f_ = torch.cat((b_f_, f_.unsqueeze(dim=0)), dim=0)
            b_c_ = torch.cat((b_c_, c_.unsqueeze(dim=0)), dim=0)
            b_hx_ = torch.cat((b_hx_, hx_.unsqueeze(dim=0)), dim=0)
            b_txt_token = torch.cat((b_txt_token, tok.unsqueeze(dim=0)), dim=0)

        # Unlike DQN above, b_r is not unsqueezed here; the reward entries are
        # assumed to already carry a trailing singleton dimension.
        b_a = b_a.unsqueeze(dim=1)

        if self.num_train > 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
            self.num_train = 0
        self.num_train += 1

        b_c1 = torch.zeros(b_c.shape[0], 1, b_c.shape[2]).to(device)
        for i in range(b_a.shape[0]):
            b_c1[i][0] = b_c[i][b_a[i][0]].to(device)
        q_eval = self.eval_net(b_g, b_f.detach(), b_c1, b_hx.detach(), b_r)
        q_next = torch.zeros((batch_size, len(b_c_[0]))).to(device)
        for i in range(len(b_c_[0])):
            q_temp = self.target_net(b_g_, b_f_, b_c_[:, i, :].unsqueeze(1), b_hx_, b_r).detach()
            q_next[:, i] = q_temp.squeeze(dim=1)
        for j in range(len(b_c_)):
            if b_r[j] >= 1000:
                q_next[j, :] = torch.zeros((1, len(b_c_[0]))).to(device)
        q_target = b_r + GAMMA * q_next.max(1)[0].view(batch_size, 1)

        critic_loss = torch.mean(F.mse_loss(q_eval.float(), q_target.float()))

        log_probs = self.actor_net(b_g, b_f.detach(), b_c, b_hx.detach(), b_r)
        delta = q_target - q_eval

        actor_loss = torch.mean(-log_probs * delta.detach()).float()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return actor_loss, critic_loss
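
# Minimal smoke-test sketch. Assumptions are flagged inline: the cn_clip
# load_from_name() API returning (model, preprocess), the "ViT-B-16"
# checkpoint name, and the batch size below are guesses based on the layer
# definitions above, not project-verified values.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # load_from_name is assumed to follow the cn_clip API: it returns the
    # CLIP model and its image preprocessing transform.
    model, preprocess = load_from_name("ViT-B-16", device=device, download_root="./")
    agent = DQN_v3(model, preprocess, device)
    # Initialise the recurrent Stage1 hidden state before any rollout.
    agent.eval_net.init_hid(batch_size=2)
    agent.actor_net.init_hid(batch_size=2)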