# DRv2/Model/fur_rl/models/retriever_rl.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from Model.CLIP.cn_clip.clip import load_from_name
import cn_clip.clip as clip
from PIL import Image
from torch.autograd import Variable
from Model.fur_rl.models.StageOne import Stage1
from Model.fur_rl.models.transformer import ModelAttn
import copy
import PIL
# v3: actor model with a softmax policy head
class Retriever_v3(nn.Module):
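    """Actor network for interactive retrieval (v3, softmax policy head).

    get_candidates scores the entire action database from the dialogue
    features and keeps the top-k items; forward then re-ranks those k
    candidates against the fused text/image state with cross-attention,
    returning a softmax policy p and a tanh-bounded value estimate q.
    """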
def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device):
super(Retriever_v3, self).__init__()
self.sentence_clip_model2 = sentence_clip_model2
self.sentence_clip_preprocess2 = sentence_clip_preprocess2
self.tokenize = clip.tokenize
self.Stage1 = Stage1(hid_dim=512)
self.fc02 = nn.Linear(in_features=512, out_features=1, bias=True)
self.fc03 = nn.Linear(in_features=256, out_features=64, bias=True)
self.fc04 = nn.Linear(in_features=64, out_features=1, bias=True)
self.fc12 = nn.Linear(in_features=512, out_features=1, bias=True)
self.fc13 = nn.Linear(in_features=256, out_features=64, bias=True)
self.fc14 = nn.Linear(in_features=64, out_features=1, bias=True)
# self.fc1 = nn.Linear(in_features=516, out_features=1, bias=True)
self.fc2 = nn.Linear(in_features=512, out_features=512, bias=True)
self.bn = nn.BatchNorm1d(num_features=1)
self.Attn0 = ModelAttn()
self.Attn1 = ModelAttn()
self.Attn_his = ModelAttn()
self.norm = nn.LayerNorm(normalized_shape=512)
self.norm50 = nn.LayerNorm(normalized_shape=50)
self.Attn_seq = ModelAttn()
self.fc_seq0 = nn.Linear(in_features=512, out_features=1, bias=True)
self.fc_seq1 = nn.Linear(in_features=512, out_features=1, bias=True)
self.Attn_cross = ModelAttn()
self.fc_q = nn.Linear(in_features=512, out_features=1, bias=True)
self.fc_p = nn.Linear(in_features=512, out_features=1, bias=True)
self.hx_his = None
self.tanh = nn.Tanh()
self.logistic = nn.Sigmoid()
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=1)
self.device = device
def img_embed(self, img):
img_embed = self.sentence_clip_model2.encode_image(img)
img_embed = img_embed / img_embed.norm(dim=-1, keepdim=True)
return img_embed
def txt_embed(self, txt_token):
txt_embed = self.sentence_clip_model2.encode_text(txt_token)
txt_embed = txt_embed / txt_embed.norm(dim=-1, keepdim=True)
return txt_embed
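    # Usage sketch (illustrative only; the caption and shapes below are
    # assumptions, not taken from the training code). Both encoders
    # L2-normalize their outputs, so cosine similarity is a plain dot product:
    #   img_feat = self.img_embed(imgs)                                  # [batch, 512]
    #   txt_feat = self.txt_embed(self.tokenize(["a red sofa"]).to(self.device))
    #   sim = img_feat @ txt_feat.t()                                    # [batch, 1]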
def get_candidates(self, txt, actions_matric, k, ranker):
x = self.Attn_seq(q=txt, k=txt, v=txt) # [batch, 24, 512]
x = self.tanh(self.fc_seq0(x)) # [batch, 24, 1]
x = x.squeeze(dim=2).unsqueeze(dim=1) # [batch, 1, 24]
        # greedy baseline (disabled): uncomment to weight all turns equally
        # x = torch.ones(x.shape, device=self.device)  # [batch, 1, 24]
        x = torch.bmm(x, actions_matric).squeeze(dim=1)  # [batch, 1, 24] @ [batch, 24, 2000] -> [batch, 2000]
# normalize
x = (x - x.mean(dim=-1, keepdim=True)) / x.std(dim=-1, keepdim=True)
p = self.logistic(x)
p = torch.clamp(p, min=1e-10, max=1 - 1e-10) # [batch, 2000]
topk_score, topk_action = torch.topk(p, k=k, dim=-1) # [batch, k] [batch, k]
candidates = torch.zeros((txt.shape[0], k, 512), device=self.device) # [batch, k, 512]
candidates_pre = torch.zeros((txt.shape[0], k, 3, 224, 224), device=self.device) # [batch, k, 3, 224, 224]
candidates_id = torch.zeros((txt.shape[0], k), device=self.device) # [batch, k]
for i in range(k):
candidates[:, i, :] = ranker.data_emb[topk_action[:, i], :] # [batch, 1, 512]
candidates_pre[:, i, :, :, :] = ranker.data_pre[topk_action[:, i], :, :, :] # [batch, 1, 3, 224, 224]
candidates_id[:, i] = ranker.data_id[topk_action[:, i]] # [batch, 1]
return p, topk_score, topk_action, candidates, candidates_id, candidates_pre
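    # Shape trace for get_candidates (the sizes 24 and 2000 come from the
    # comments above and are illustrative, not enforced):
    #   txt               [batch, 24, 512]   dialogue-turn features
    #   per-turn weights  [batch, 1, 24]     tanh(fc_seq0(self-attention(txt)))
    #   database scores   [batch, 2000]      bmm with actions_matric, standardized
    #   p                 [batch, 2000]      sigmoid, clamped away from 0 and 1
    # The top-k rows are then gathered from ranker.data_emb / data_pre / data_id.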
def forward(self, img, txt, actions_matric, ranker, k=50):
        # TODO: replace simple concatenation with a learned image-text fusion
fusion = torch.cat((txt, img), dim=1) # [batch, 25, 512]
p_db, topk_score, topk_action, candidates, candidates_id, candidates_pre = \
self.get_candidates(txt, actions_matric, k=k, ranker=ranker)
x = self.Attn_cross(q=candidates, k=fusion, v=fusion) # [batch, k, 512]
        # TODO: replace the single linear layer with an MLP
p = self.softmax(self.fc_p(x)).squeeze(dim=2) # [batch, k]
p = torch.clamp(p, min=1e-10, max=1 - 1e-10)
        # TODO: replace the single linear layer with an MLP
q = self.tanh(self.fc_q(x)).squeeze(dim=2) # [batch, k]
p_maxIndex = p.argmax(dim=-1) # [batch]
max_candidates = candidates[torch.arange(candidates.shape[0]), p_maxIndex] # [batch, 512]
max_candidates_id = candidates_id[torch.arange(candidates_id.shape[0]), p_maxIndex] # [batch]
return p, q, max_candidates, max_candidates_id, candidates, candidates_id, candidates_pre, p_db
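    # Call sketch (hypothetical tensor names; only the signature is real):
    #   p, q, best, best_id, cands, cand_ids, cand_pre, p_db = \
    #       retriever(img_feats, txt_feats, actions_matric, ranker, k=50)
    #   action = p.argmax(dim=-1)  # greedy pick among the k re-ranked candidates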
MEMORY_CAPACITY = 100
N_STATES = 4
GAMMA = 0.95  # note: the return loops in learn() hard-code discounts of 0.8 and 0.3
class DQN_v3:
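    """Training wrapper around Retriever_v3.

    Maintains two episode-level replay buffers (successful and failed
    dialogues) and updates the actor with a reward-weighted cross-entropy
    over the database distribution p_db; see learn().
    """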
def __init__(self, sentence_clip_model2, sentence_clip_preprocess2, device, MULTI_GPU=False, device_ids=None):
self.actor_net = Retriever_v3(sentence_clip_model2, sentence_clip_preprocess2, device).to(device)
self.num_train = 0
self.device = device
self.MULTI_GPU = MULTI_GPU
self.device_ids = device_ids
self.memory = [] # initialize memory
self.neg_memory = []
self.memory_counter = 0
self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=0.0001)
self.loss_func1 = nn.MSELoss()
self.loss_func2 = nn.MSELoss()
self.batch_mem = {}
self.is_store = None
def store_transition(self, g, f, d, a, r, g_, f_, d_, t, batch_size, net_mem, success_turn, turn):
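        """Append the current turn of every episode in the batch to its buffer.

        (g, f, d) is the state, a the action, r the reward, (g_, f_, d_)
        the next state, and t the target item index. Episodes that succeed
        are pushed to self.memory; episodes still unresolved at turn 9 go
        to self.neg_memory. Both buffers are capped at net_mem episodes.
        """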
if turn == 0:
for i in range(batch_size):
self.batch_mem[i] = []
self.is_store = torch.zeros((batch_size), device=self.device)
for i in range(batch_size):
            if r[i] <= -2000:  # strongly negative reward appears to flag a slot to skip
                continue
else:
g_tmp = copy.deepcopy(g[i])
f_tmp = copy.deepcopy(f[i]) # [len, 512]
d_tmp = copy.deepcopy(d[i])
a_tmp = copy.deepcopy(a[i])
reward_temp = copy.deepcopy(r[i])
g_tmp_ = copy.deepcopy(g_[i])
f_tmp_ = copy.deepcopy(f_[i])
d_tmp_ = copy.deepcopy(d_[i])
t_tmp = copy.deepcopy(t[i])
self.batch_mem[i].append((g_tmp, f_tmp, d_tmp, a_tmp, reward_temp,
g_tmp_, f_tmp_, d_tmp_, t_tmp))
if success_turn[i] > 0 and self.is_store[i] == 0:
# print("len", len(self.memory), len(self.batch_mem[i]))
self.memory.append(self.batch_mem[i])
self.is_store[i] = 1
while len(self.memory) > net_mem:
self.memory.pop(0)
self.memory_counter = len(self.memory)
elif turn == 9 and self.is_store[i] == 0:
self.neg_memory.append(self.batch_mem[i])
while len(self.neg_memory) > net_mem:
self.neg_memory.pop(0)
# print("neg", success_turn[i], reward_temp, len(self.neg_memory))
def learn(self, batch_size, device=None, ranker=None, k=50):
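        """One update step.

        Pops one successful episode and one failed episode (when available),
        builds discounted returns over each, and descends the reward-weighted
        supervised loss; the actor/critic terms are currently zeroed out.
        """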
success_turn = 0
fail_turn = 0
if len(self.memory) == 0 and len(self.neg_memory) == 0:
            return 0, 0
if len(self.memory) > 0:
success_turn = 1
batch_mem = self.memory[0]
g, f, d, a, r, g_, f_, d_, t = zip(*[batch_mem[i] for i in range(len(batch_mem))])
b_g = torch.stack(g)
b_f = torch.stack(f)
b_d = torch.stack(d)
b_a = torch.stack(a).unsqueeze(1)
b_reward = torch.stack(r)
            # discounted return, computed backwards: R_t = r_t + 0.8 * R_{t+1}
            b_r = torch.zeros((len(r), 1), device=self.device)
            for i in range(len(r)):
                if i == 0:
                    b_r[len(r) - i - 1] = r[len(r) - i - 1]
                else:
                    b_r[len(r) - i - 1] = r[len(r) - i - 1] + 0.8 * b_r[len(r) - i]
b_g_ = torch.stack(g_)
b_f_ = torch.stack(f_)
b_d_ = torch.stack(d_)
b_t = torch.stack(t)
self.memory.pop(0)
if len(self.neg_memory) > 0:
batch_mem = self.neg_memory[0]
g, f, d, a, r, g_, f_, d_, t = zip(*[batch_mem[i] for i in range(len(batch_mem))])
b_g_neg = torch.stack(g)
b_f_neg = torch.stack(f)
b_d_neg = torch.stack(d)
b_a_neg = torch.stack(a).unsqueeze(1)
b_reward_neg = torch.stack(r)
            # discounted return for failed episodes, with a smaller discount (0.3)
            b_r_neg = torch.zeros((len(r), 1), device=self.device)
            for i in range(len(r)):
                if i == 0:
                    b_r_neg[len(r) - i - 1] = r[len(r) - i - 1]
                else:
                    b_r_neg[len(r) - i - 1] = r[len(r) - i - 1] + 0.3 * b_r_neg[len(r) - i]
b_g_neg_ = torch.stack(g_)
b_f_neg_ = torch.stack(f_)
b_d_neg_ = torch.stack(d_)
b_t_neg = torch.stack(t)
fail_turn = 1
self.neg_memory.pop(0)
if success_turn > 0 and fail_turn > 0:
b_g = torch.cat((b_g, b_g_neg), 0)
b_f = torch.cat((b_f, b_f_neg), 0)
b_d = torch.cat((b_d, b_d_neg), 0)
b_a = torch.cat((b_a, b_a_neg), 0)
b_r = torch.cat((b_r, b_r_neg), 0)
b_g_ = torch.cat((b_g_, b_g_neg_), 0)
b_f_ = torch.cat((b_f_, b_f_neg_), 0)
b_d_ = torch.cat((b_d_, b_d_neg_), 0)
b_t = torch.cat((b_t, b_t_neg), 0)
b_reward = torch.cat((b_reward, b_reward_neg), 0)
elif success_turn == 0 and fail_turn > 0:
b_g = b_g_neg
b_f = b_f_neg
b_d = b_d_neg
b_a = b_a_neg
b_r = b_r_neg
b_g_ = b_g_neg_
b_f_ = b_f_neg_
b_d_ = b_d_neg_
b_t = b_t_neg
b_reward = b_reward_neg
        # otherwise only the success batch is present; the b_* tensors are already set
## actor loss
        # signature reminder: forward(img, txt, actions_matric, ranker, k=50)
p, q, max_candidates, max_candidates_id, candidates, candidates_id, candidates_pre, p_db = \
self.actor_net(b_g, b_f, b_d, ranker, k)
log_probs = torch.zeros(p.shape[0], 1, device=self.device)
q_eval = torch.zeros(q.shape[0], 1, device=self.device)
        for i in range(p.shape[0]):
            log_probs[i][0] = p[i][b_a[i][0]].to(device)  # probability of the taken action
            q_eval[i][0] = q[i][b_a[i][0]].to(device)  # value of the taken action, range (-1, 1); q is [batch, k]
        q_next = self.actor_net(b_g_, b_f_, b_d_, ranker, k)[1].detach()  # next-state value, no gradient
        q_target = b_r + 0.8 * q_next.max(1)[0].view(b_a.shape[0], 1)  # one-step TD target
delta = q_target - q_eval
# delta = b_r - q_eval_temp
# print('q_target', q_target[0], 'q_eval', q_eval_temp[0], 'delta', delta[0])
# print("log_probs_temp", max_score, "b_r", b_r, "log(p)", torch.log(log_probs_temp))
        # The actor and critic terms are currently disabled; only the
        # supervised weight_loss below contributes to the gradient.
        # actor_loss = -(torch.log(log_probs).t() @ b_r.detach()).float()
        actor_loss = 0
        # critic_loss = torch.mean(delta ** 2).float()
        # critic_loss = torch.mean((b_r/5 - q_eval) ** 2).float()
        critic_loss = 0
        ## supervised loss: reward-weighted cross-entropy on the database
        ## distribution p_db at the target item b_t
        weight_loss = torch.zeros(1, device=self.device)
        for i in range(b_reward.shape[0]):
            s_loss_tmp = -b_reward[i][0] * torch.log(p_db[i][b_t[i][0].long()])
            weight_loss += s_loss_tmp
        weight_loss = weight_loss.float()
# weight_loss = 0
## total loss
loss = 1 * actor_loss + 1 * weight_loss
if loss != 0:
self.actor_optimizer.zero_grad()
print("actor_loss", actor_loss, "critic_loss", critic_loss, "weight_loss", weight_loss)
loss.backward()
self.actor_optimizer.step()
return actor_loss, critic_loss
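# Training-loop sketch (hypothetical: st and the feature tensors g, f, d, r,
# g_, f_, d_, t are assumptions about the surrounding training code, not
# part of this file):
#   dqn = DQN_v3(clip_model, clip_preprocess, device)
#   for turn in range(10):
#       p, q, best, best_id, *rest = dqn.actor_net(g, f, d, ranker, k=50)
#       a = p.argmax(dim=-1)  # or sample from the softmax policy p
#       dqn.store_transition(g, f, d, a, r, g_, f_, d_, t,
#                            batch_size, net_mem=64, success_turn=st, turn=turn)
#   actor_loss, critic_loss = dqn.learn(batch_size, device, ranker, k=50)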