lshzhm's picture
Upload 141 files
1991049 verified
from .critic_objectives import*
from torchvision import transforms
import torch.nn.functional as F
import torch.optim as optim
import math
import copy
############################
# Simple Augmentations #
############################
def permute(x):
# shuffle the sequence order
idx = torch.randperm(x.shape[0])
return x[idx]
def noise(x):
noise = torch.randn(x.shape) * 0.1
return x + noise.to(x.device)
def drop(x):
# drop 20% of the sequences
drop_num = x.shape[0] // 5
x_aug = torch.clone(x)
drop_idxs = np.random.choice(x.shape[0], drop_num, replace=False)
x_aug[drop_idxs] = 0.0
return x_aug
def mixup(x, alpha=1.0):
indices = torch.randperm(x.shape[0])
lam = np.random.beta(alpha, alpha)
aug_x = x * lam + x[indices] * (1 - lam)
return aug_x
def identity(x):
return x
def augment(x_batch):
v1 = x_batch
v2 = torch.clone(v1)
transforms = [permute, noise, drop, identity]
for i in range(x_batch.shape[0]):
t_idxs = np.random.choice(4, 2, replace=False)
t1 = transforms[t_idxs[0]]
t2 = transforms[t_idxs[1]]
v1[i] = t1(v1[i])
v2[i] = t2(v2[i])
return v1, v2
# return one augmented instance
def augment_single(x_batch):
v1 = x_batch
v2 = torch.clone(v1)
transforms = [permute, noise, drop, identity]
for i in range(x_batch.shape[0]):
t_idxs = np.random.choice(4, 1, replace=False)
t = transforms[t_idxs[0]]
v2[i] = t(v2[i])
return v2
def augment_embed_single(x_batch):
v1 = x_batch
v2 = torch.clone(v1)
transforms = [noise, mixup, identity]
t_idxs = np.random.choice(3, 1, replace=False)
t = transforms[t_idxs[0]]
v2 = t(v2)
return v2
def augment_mimic(x_batch):
if x_batch.dim() == 2:
return augment_embed_single(x_batch)
else:
return augment_single(x_batch)
##############
# Models #
##############
def mlp_head(dim_in, feat_dim):
return nn.Sequential(
nn.Linear(dim_in, dim_in),
nn.ReLU(inplace=True),
nn.Linear(dim_in, feat_dim)
)
class SupConModel(nn.Module):
"""backbone + projection head"""
def __init__(self, temperature, encoders, dim_ins, feat_dims, use_label=False, head='mlp'):
super(SupConModel, self).__init__()
self.use_label = use_label
self.encoders = nn.ModuleList(encoders)
if head == 'linear':
self.head1 = nn.Linear(dim_ins[0], feat_dims[0])
self.head2 = nn.Linear(dim_ins[1], feat_dims[1])
elif head == 'mlp':
self.head1 = nn.Sequential(
nn.Linear(dim_ins[0], dim_ins[0]),
nn.ReLU(inplace=True),
nn.Linear(dim_ins[0], feat_dims[0])
)
self.head2 = nn.Sequential(
nn.Linear(dim_ins[1], dim_ins[1]),
nn.ReLU(inplace=True),
nn.Linear(dim_ins[1], feat_dims[1])
)
else:
raise NotImplementedError(
'head not supported: {}'.format(head))
self.critic = SupConLoss(temperature=temperature)
def forward(self, x1, x2, y):
feat1 = self.encoders[0](x1)
feat1 = self.head1(feat1)
feat2 = self.encoders[1](x2)
feat2 = self.head2(feat2)
feat = torch.cat([feat1.unsqueeze(1), feat2.unsqueeze(1)], dim=1)
loss = self.critic(feat, y) if self.use_label else self.critic(feat)
return loss
def get_embedding(self, x1, x2):
return self.encoders[0](x1), self.encoders[1](x2)
class FactorCLSUP(nn.Module):
def __init__(self, encoders, feat_dims, y_ohe_dim, temperature=1, activation='relu', lr=1e-4, ratio=1):
super(FactorCLSUP, self).__init__()
self.critic_hidden_dim = 512
self.critic_layers = 1
self.critic_activation = 'relu'
self.lr = lr
self.ratio = ratio
self.y_ohe_dim = y_ohe_dim
self.temperature = temperature
self.feat_dims = feat_dims
# encoder backbones
####self.backbones = nn.ModuleList(encoders)
# linear projection heads
####self.linears_infonce_x1x2 = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
self.linears_club_x1x2_cond = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
####self.linears_infonce_x1y = mlp_head(self.feat_dims[0], self.feat_dims[0])
####self.linears_infonce_x2y = mlp_head(self.feat_dims[1], self.feat_dims[1])
####self.linears_infonce_x1x2_cond = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
####self.linears_club_x1x2 = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
# critics
####self.infonce_x1x2 = InfoNCECritic(self.feat_dims[0], self.feat_dims[1], self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
self.club_x1x2_cond = CLUBInfoNCECritic(self.feat_dims[0] + self.y_ohe_dim, self.feat_dims[1] + self.y_ohe_dim,
self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
####self.infonce_x1y = InfoNCECritic(self.feat_dims[0], 1, self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
####self.infonce_x2y = InfoNCECritic(self.feat_dims[1], 1, self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
####self.infonce_x1x2_cond = InfoNCECritic(self.feat_dims[0] + self.y_ohe_dim, self.feat_dims[1] + self.y_ohe_dim,
#### self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
####self.club_x1x2 = CLUBInfoNCECritic(self.feat_dims[0], self.feat_dims[1], self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
def ohe(self, y):
N = y.shape[0]
y_ohe = torch.zeros((N, self.y_ohe_dim))
y_ohe[torch.arange(N).long(), y.T[0].long()] = 1
return y_ohe
def forward(self, x1, x2, y):
# Get embeddings
####x1_embed = self.backbones[0](x1)
####x2_embed = self.backbones[1](x2)
x1_embed, x2_embed = x1, x2
x1_embed = F.normalize(x1_embed, dim=-1)
x2_embed = F.normalize(x2_embed, dim=-1)
# Get ohe label
y_ohe = self.ohe(y).cuda()
# Compute losses
####uncond_losses = [self.infonce_x1x2(self.linears_infonce_x1x2[0](x1_embed), self.linears_infonce_x1x2[1](x2_embed)),
#### self.club_x1x2(self.linears_club_x1x2[0](x1_embed), self.linears_club_x1x2[1](x2_embed)),
#### self.infonce_x1y(self.linears_infonce_x1y(x1_embed), y),
#### self.infonce_x2y(self.linears_infonce_x2y(x2_embed), y)
####]
####
####cond_losses = [self.infonce_x1x2_cond(torch.cat([self.linears_infonce_x1x2_cond[0](x1_embed), y_ohe], dim=1),
#### torch.cat([self.linears_infonce_x1x2_cond[1](x2_embed), y_ohe], dim=1)),
#### self.club_x1x2_cond(torch.cat([self.linears_club_x1x2_cond[0](x1_embed), y_ohe], dim=1),
#### torch.cat([self.linears_club_x1x2_cond[1](x2_embed), y_ohe], dim=1)),
####]
cond_losses = [self.club_x1x2_cond(torch.cat([self.linears_club_x1x2_cond[0](x1_embed), y_ohe], dim=1),
torch.cat([self.linears_club_x1x2_cond[1](x2_embed), y_ohe], dim=1)),
]
####return sum(uncond_losses) + sum(cond_losses)
return sum(cond_losses)
def learning_loss(self, x1, x2, y):
# Get embeddings
####x1_embed = self.backbones[0](x1)
####x2_embed = self.backbones[1](x2)
x1_embed, x2_embed = x1, x2
x1_embed = F.normalize(x1_embed, dim=-1)
x2_embed = F.normalize(x2_embed, dim=-1)
y_ohe = self.ohe(y).cuda()
# Calculate InfoNCE loss for CLUB-NCE
learning_losses = [self.club_x1x2.learning_loss(self.linears_club_x1x2[0](x1_embed), self.linears_club_x1x2[1](x2_embed)),
self.club_x1x2_cond.learning_loss(torch.cat([self.linears_club_x1x2_cond[0](x1_embed), y_ohe], dim=1),
torch.cat([self.linears_club_x1x2_cond[1](x2_embed), y_ohe], dim=1))
]
return sum(learning_losses)
def get_embedding(self, x1, x2):
x1_embed = self.backbones[0](x1)
x2_embed = self.backbones[1](x2)
x1_reps = [self.linears_infonce_x1x2[0](x1_embed),
self.linears_club_x1x2[0](x1_embed),
self.linears_infonce_x1y(x1_embed),
self.linears_infonce_x1x2_cond[0](x1_embed),
self.linears_club_x1x2_cond[0](x1_embed)]
x2_reps = [self.linears_infonce_x1x2[1](x2_embed),
self.linears_club_x1x2[1](x2_embed),
self.linears_infonce_x2y(x2_embed),
self.linears_infonce_x1x2_cond[1](x2_embed),
self.linears_club_x1x2_cond[1](x2_embed)]
return torch.cat(x1_reps, dim=1), torch.cat(x2_reps, dim=1)
def get_optims(self):
non_CLUB_params = [self.backbones.parameters(),
self.infonce_x1x2.parameters(),
self.infonce_x1y.parameters(),
self.infonce_x2y.parameters(),
self.infonce_x1x2_cond.parameters(),
self.linears_infonce_x1x2.parameters(),
self.linears_infonce_x1y.parameters(),
self.linears_infonce_x2y.parameters(),
self.linears_infonce_x1x2_cond.parameters(),
self.linears_club_x1x2_cond.parameters(),
self.linears_club_x1x2.parameters()]
CLUB_params = [self.club_x1x2_cond.parameters(),
self.club_x1x2.parameters()]
non_CLUB_optims = [optim.Adam(param, lr=self.lr) for param in non_CLUB_params]
CLUB_optims = [optim.Adam(param, lr=self.lr) for param in CLUB_params]
return non_CLUB_optims, CLUB_optims
class FactorCLSSL(nn.Module):
def __init__(self, encoders, feat_dims, y_ohe_dim, temperature=1, activation='relu', lr=1e-4, ratio=1):
super(FactorCLSSL, self).__init__()
self.critic_hidden_dim = 512
self.critic_layers = 1
self.critic_activation = 'relu'
self.lr = lr
self.ratio = ratio
self.y_ohe_dim = y_ohe_dim
self.temperature = temperature
# encoder backbones
self.feat_dims = feat_dims
####self.backbones = nn.ModuleList(encoders)
# linear projection heads
self.linears_infonce_x1x2 = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
self.linears_club_x1x2_cond = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
self.linears_infonce_x1y = mlp_head(self.feat_dims[0], self.feat_dims[0])
self.linears_infonce_x2y = mlp_head(self.feat_dims[1], self.feat_dims[1])
self.linears_infonce_x1x2_cond = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
self.linears_club_x1x2 = nn.ModuleList([mlp_head(self.feat_dims[i], self.feat_dims[i]) for i in range(2)])
# critics
self.infonce_x1x2 = InfoNCECritic(self.feat_dims[0], self.feat_dims[1], self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
self.club_x1x2_cond = CLUBInfoNCECritic(self.feat_dims[0]*2, self.feat_dims[1]*2,
self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
self.infonce_x1y = InfoNCECritic(self.feat_dims[0], self.feat_dims[0], self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
self.infonce_x2y = InfoNCECritic(self.feat_dims[1], self.feat_dims[1], self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
self.infonce_x1x2_cond = InfoNCECritic(self.feat_dims[0]*2, self.feat_dims[1]*2,
self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
self.club_x1x2 = CLUBInfoNCECritic(self.feat_dims[0], self.feat_dims[1], self.critic_hidden_dim, self.critic_layers, activation, temperature=temperature)
def ohe(self, y):
N = y.shape[0]
y_ohe = torch.zeros((N, self.y_ohe_dim))
y_ohe[torch.arange(N).long(), y.T[0].long()] = 1
return y_ohe
def forward(self, x1, x2, x1_aug, x2_aug):
# Get embeddings
####x1_embed = self.backbones[0](x1)
####x2_embed = self.backbones[1](x2)
####
####x1_aug_embed = self.backbones[0](x1_aug)
####x2_aug_embed = self.backbones[1](x2_aug)
x1_embed, x2_embed, x1_aug_embed, x2_aug_embed = x1, x2, x1_aug, x2_aug
x1_embed = F.normalize(x1_embed, dim=-1)
x2_embed = F.normalize(x2_embed, dim=-1)
x1_aug_embed = F.normalize(x1_aug_embed, dim=-1)
x2_aug_embed = F.normalize(x2_aug_embed, dim=-1)
#compute losses
uncond_losses = [self.infonce_x1x2(self.linears_infonce_x1x2[0](x1_embed), self.linears_infonce_x1x2[1](x2_embed)),
self.club_x1x2(self.linears_club_x1x2[0](x1_embed), self.linears_club_x1x2[1](x2_embed)),
self.infonce_x1y(self.linears_infonce_x1y(x1_embed), self.linears_infonce_x1y(x1_aug_embed)),
self.infonce_x2y(self.linears_infonce_x2y(x2_embed), self.linears_infonce_x2y(x2_aug_embed))
]
cond_losses = [self.infonce_x1x2_cond(torch.cat([self.linears_infonce_x1x2_cond[0](x1_embed),
self.linears_infonce_x1x2_cond[0](x1_aug_embed)], dim=1),
torch.cat([self.linears_infonce_x1x2_cond[1](x2_embed),
self.linears_infonce_x1x2_cond[1](x2_aug_embed)], dim=1)),
self.club_x1x2_cond(torch.cat([self.linears_club_x1x2_cond[0](x1_embed),
self.linears_club_x1x2_cond[0](x1_aug_embed)], dim=1),
torch.cat([self.linears_club_x1x2_cond[1](x2_embed),
self.linears_club_x1x2_cond[1](x2_aug_embed)], dim=1))
]
return sum(uncond_losses) + sum(cond_losses)
def learning_loss(self, x1, x2, x1_aug, x2_aug):
# Get embeddings
####x1_embed = self.backbones[0](x1)
####x2_embed = self.backbones[1](x2)
####
####x1_aug_embed = self.backbones[0](x1_aug)
####x2_aug_embed = self.backbones[1](x2_aug)
x1_embed, x2_embed, x1_aug_embed, x2_aug_embed = x1, x2, x1_aug, x2_aug
x1_embed = F.normalize(x1_embed, dim=-1)
x2_embed = F.normalize(x2_embed, dim=-1)
x1_aug_embed = F.normalize(x1_aug_embed, dim=-1)
x2_aug_embed = F.normalize(x2_aug_embed, dim=-1)
# Calculate InfoNCE loss for CLUB-NCE
learning_losses = [self.club_x1x2.learning_loss(self.linears_club_x1x2[0](x1_embed), self.linears_club_x1x2[1](x2_embed)),
self.club_x1x2_cond.learning_loss(torch.cat([self.linears_club_x1x2_cond[0](x1_embed),
self.linears_club_x1x2_cond[0](x1_aug_embed)], dim=1),
torch.cat([self.linears_club_x1x2_cond[1](x2_embed),
self.linears_club_x1x2_cond[1](x2_aug_embed)], dim=1))
]
return sum(learning_losses)
def get_embedding(self, x1, x2):
x1_embed = self.backbones[0](x1)
x2_embed = self.backbones[1](x2)
x1_reps = [self.linears_infonce_x1x2[0](x1_embed),
self.linears_club_x1x2[0](x1_embed),
self.linears_infonce_x1y(x1_embed),
self.linears_infonce_x1x2_cond[0](x1_embed),
self.linears_club_x1x2_cond[0](x1_embed)]
x2_reps = [self.linears_infonce_x1x2[1](x2_embed),
self.linears_club_x1x2[1](x2_embed),
self.linears_infonce_x2y(x2_embed),
self.linears_infonce_x1x2_cond[1](x2_embed),
self.linears_club_x1x2_cond[1](x2_embed)]
return torch.cat(x1_reps, dim=1), torch.cat(x2_reps, dim=1)
def get_optims(self):
non_CLUB_params = [self.backbones.parameters(),
self.infonce_x1x2.parameters(),
self.infonce_x1y.parameters(),
self.infonce_x2y.parameters(),
self.infonce_x1x2_cond.parameters(),
self.linears_infonce_x1x2.parameters(),
self.linears_infonce_x1y.parameters(),
self.linears_infonce_x2y.parameters(),
self.linears_infonce_x1x2_cond.parameters(),
self.linears_club_x1x2_cond.parameters(),
self.linears_club_x1x2.parameters()]
CLUB_params = [self.club_x1x2_cond.parameters(),
self.club_x1x2.parameters()]
non_CLUB_optims = [optim.Adam(param, lr=self.lr) for param in non_CLUB_params]
CLUB_optims = [optim.Adam(param, lr=self.lr) for param in CLUB_params]
return non_CLUB_optims, CLUB_optims
########################
# Training Scripts #
########################
# MOSI/MOSEI Training
def mosi_label(y_batch):
res = copy.deepcopy(y_batch)
res[y_batch >= 0] = 1
res[y_batch < 0] = 0
return res
def train_supcon_mosi(model, train_loader, optimizer, modalities=[0,2], num_epoch=100):
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0][modalities[0]].float().cuda()
x2_batch = data_batch[0][modalities[1]].float().cuda()
y_batch = mosi_label(data_batch[3]).float().cuda()
loss = model(x1_batch, x2_batch, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return
def train_sup_mosi(model, train_loader, modalities=[0,2], num_epoch=50, num_club_iter=1):
non_CLUB_optims, CLUB_optims = model.get_optims()
losses = []
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0][modalities[0]].float().cuda()
x2_batch = data_batch[0][modalities[1]].float().cuda()
y_batch = mosi_label(data_batch[3]).float().cuda()
loss = model(x1_batch, x2_batch, y_batch)
losses.append(loss.detach().cpu().numpy())
for optimizer in non_CLUB_optims:
optimizer.zero_grad()
loss.backward()
for optimizer in non_CLUB_optims:
optimizer.step()
for _ in range(num_club_iter):
learning_loss = model.learning_loss(x1_batch, x2_batch, y_batch)
for optimizer in CLUB_optims:
optimizer.zero_grad()
learning_loss.backward()
for optimizer in CLUB_optims:
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return
def train_ssl_mosi(model, train_loader, modalities=[0,2], num_epoch=50, num_club_iter=1):
non_CLUB_optims, CLUB_optims = model.get_optims()
losses = []
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0][modalities[0]].float().cuda()
x2_batch = data_batch[0][modalities[1]].float().cuda()
x1_aug = augment_single(x1_batch)
x2_aug = augment_single(x2_batch)
loss = model(x1_batch, x2_batch, x1_aug, x2_aug)
losses.append(loss.detach().cpu().numpy())
for optimizer in non_CLUB_optims:
optimizer.zero_grad()
loss.backward()
for optimizer in non_CLUB_optims:
optimizer.step()
for _ in range(num_club_iter):
learning_loss = model.learning_loss(x1_batch, x2_batch, x1_aug, x2_aug)
for optimizer in CLUB_optims:
optimizer.zero_grad()
learning_loss.backward()
for optimizer in CLUB_optims:
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return
# Sarcasm/Humor Training
def sarcasm_label(y_batch):
res = copy.deepcopy(y_batch)
res[y_batch == -1] = 0
return res
def train_supcon_sarcasm(model, train_loader, optimizer, modalities=[0,2], num_epoch=100):
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0][modalities[0]].float().cuda()
x2_batch = data_batch[0][modalities[1]].float().cuda()
y_batch = sarcasm_label(data_batch[3]).float().cuda()
loss = model(x1_batch, x2_batch, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return
def train_sup_sarcasm(model, train_loader, modalities=[0,2], num_epoch=50, num_club_iter=1):
non_CLUB_optims, CLUB_optims = model.get_optims()
losses = []
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0][modalities[0]].float().cuda()
x2_batch = data_batch[0][modalities[1]].float().cuda()
y_batch = sarcasm_label(data_batch[3]).float().cuda()
#loss, losses, ts = model(x_batch, y_batch)
loss = model(x1_batch, x2_batch, y_batch)
losses.append(loss.detach().cpu().numpy())
for optimizer in non_CLUB_optims:
optimizer.zero_grad()
loss.backward()
for optimizer in non_CLUB_optims:
optimizer.step()
for _ in range(num_club_iter):
learning_loss = model.learning_loss(x1_batch, x2_batch, y_batch)
for optimizer in CLUB_optims:
optimizer.zero_grad()
learning_loss.backward()
for optimizer in CLUB_optims:
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return
def train_ssl_sarcasm(model, train_loader, modalities=[0,2], num_epoch=50, num_club_iter=1):
non_CLUB_optims, CLUB_optims = model.get_optims()
losses = []
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0][modalities[0]].float().cuda()
x2_batch = data_batch[0][modalities[1]].float().cuda()
x1_aug = augment_single(x1_batch)
x2_aug = augment_single(x2_batch)
loss = model(x1_batch, x2_batch, x1_aug, x2_aug)
losses.append(loss.detach().cpu().numpy())
for optimizer in non_CLUB_optims:
optimizer.zero_grad()
loss.backward()
for optimizer in non_CLUB_optims:
optimizer.step()
for _ in range(num_club_iter):
learning_loss = model.learning_loss(x1_batch, x2_batch, x1_aug, x2_aug)
for optimizer in CLUB_optims:
optimizer.zero_grad()
learning_loss.backward()
for optimizer in CLUB_optims:
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return
# MIMIC Training
def train_supcon_mimic(model, train_loader, optimizer, num_epoch=100):
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0].float().cuda()
x2_batch = data_batch[1].float().cuda()
y_batch = data_batch[2].unsqueeze(0).T.float().cuda()
loss = model(x1_batch, x2_batch, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return
def train_sup_mimic(model, train_loader, num_epoch=50, num_club_iter=1):
non_CLUB_optims, CLUB_optims = model.get_optims()
losses = []
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0].float().cuda()
x2_batch = data_batch[1].float().cuda()
y_batch = data_batch[2].unsqueeze(0).T.float().cuda()
loss = model(x1_batch, x2_batch, y_batch)
losses.append(loss.detach().cpu().numpy())
for optimizer in non_CLUB_optims:
optimizer.zero_grad()
loss.backward()
for optimizer in non_CLUB_optims:
optimizer.step()
for _ in range(num_club_iter):
learning_loss = model.learning_loss(x1_batch, x2_batch, y_batch)
for optimizer in CLUB_optims:
optimizer.zero_grad()
learning_loss.backward()
for optimizer in CLUB_optims:
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return
def train_ssl_mimic(model, train_loader, num_epoch=50, num_club_iter=1):
non_CLUB_optims, CLUB_optims = model.get_optims()
losses = []
for _iter in range(num_epoch):
for i_batch, data_batch in enumerate(train_loader):
x1_batch = data_batch[0].float().cuda()
x2_batch = data_batch[1].float().cuda()
x1_aug = augment_mimic(x1_batch)
x2_aug = augment_mimic(x2_batch)
loss = model(x1_batch, x2_batch, x1_aug, x2_aug)
losses.append(loss.detach().cpu().numpy())
for optimizer in non_CLUB_optims:
optimizer.zero_grad()
loss.backward()
for optimizer in non_CLUB_optims:
optimizer.step()
for _ in range(num_club_iter):
learning_loss = model.learning_loss(x1_batch, x2_batch, x1_aug, x2_aug)
for optimizer in CLUB_optims:
optimizer.zero_grad()
learning_loss.backward()
for optimizer in CLUB_optims:
optimizer.step()
if i_batch%100 == 0:
print('iter: ', _iter, ' i_batch: ', i_batch, ' loss: ', loss.item())
return