import time
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import SdA
from config import *  # provides `config` used by SdA below


def train1(train_pairs, model, criterion, optimizer, epoch, args):
    """One full-batch pretraining pass; returns the hidden codes and elapsed time."""
    if epoch % 10 == 0:
        logging.info("=======> Train epoch: {}/{}".format(epoch, args.epochs))
    model.train()
    time0 = time.time()
    loss_value = 0

    x0 = torch.from_numpy(train_pairs[0]).float().to(args.gpu)
    x1 = torch.from_numpy(train_pairs[1]).float().to(args.gpu)

    try:
        h0, h1, d0, d1 = model(x0, x1)
    except Exception:
        # Re-raise instead of continuing: h0/h1/d0/d1 would be undefined below.
        logging.exception("error raised in forward pass at epoch %d", epoch)
        raise

    # Reconstruction loss on both views plus the model's own regularizer.
    loss = criterion(x0, d0)
    loss += criterion(x1, d1)
    loss += model.regularization_loss()
    loss_value += loss.item()

    # Epoch 0 is a forward-only pass (pretrain calls it under torch.no_grad()).
    if epoch != 0:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    epoch_time = time.time() - time0

    return h0, h1, epoch_time

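
# --- Model contract (assumption) ---------------------------------------------
# Sketch of the SdA interface that train1 relies on, inferred from the call
# sites above rather than taken from model.py: a siamese stacked autoencoder
# whose forward pass returns both hidden codes and both reconstructions, plus
# a regularization_loss() term. Treat every name here as hypothetical.
#
#   class SdA(nn.Module):
#       def forward(self, x0, x1):
#           h0, h1 = self.encode(x0), self.encode(x1)   # hidden codes
#           d0, d1 = self.decode(h0), self.decode(h1)   # reconstructions
#           return h0, h1, d0, d1
#
#       def regularization_loss(self):
#           ...  # e.g. a weight-penalty term added to the reconstruction loss
# -------------------------------------------------------------------------------
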

def pretrain(train_pairs, args):
    model = SdA(config).to(args.gpu)
    criterion = nn.MSELoss().to(args.gpu)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learn_rate)

    for i in range(0, args.epochs + 1):
        if i == 0:
            # Gradient-free warm-up pass at epoch 0.
            with torch.no_grad():
                h0, h1, epoch_time = train1(train_pairs, model, criterion, optimizer, i, args)
        else:
            h0, h1, epoch_time = train1(train_pairs, model, criterion, optimizer, i, args)

    # Codes and timing from the final epoch.
    return h0, h1, epoch_time

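
# --- Usage sketch (assumption) -------------------------------------------------
# A minimal, hypothetical invocation of pretrain(). The `args` attributes
# (gpu, epochs, learn_rate) mirror how they are read above; the shapes and
# values are invented for illustration only.
#
#   import argparse
#
#   args = argparse.Namespace(gpu="cuda:0", epochs=100, learn_rate=1e-3)
#   pairs = (np.random.rand(256, 784).astype(np.float32),
#            np.random.rand(256, 784).astype(np.float32))
#   h0, h1, last_epoch_time = pretrain(pairs, args)
# -------------------------------------------------------------------------------
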

def train2(train_loader, model, criterion, optimizer, epoch, args):
    model.train()
    time0 = time.time()
    loss_value = 0
    for batch_idx, (x0, x1, labels, real_labels) in enumerate(train_loader):
        x0, x1 = x0.to(args.gpu), x1.to(args.gpu)
        labels, real_labels = labels.to(args.gpu), real_labels.to(args.gpu)
        logging.debug("x0 batch shape: %s", np.shape(x0))

        try:
            # Flatten each view to (batch, features) before the forward pass.
            h0, h1 = model(x0.view(x0.size(0), -1), x1.view(x1.size(0), -1))
        except Exception:
            # Re-raise instead of continuing: h0/h1 would be undefined below.
            logging.exception("error raised in forward pass at batch %d", batch_idx)
            raise

        # Contrastive objective on the distance between the two embeddings.
        pair_dist = F.pairwise_distance(h0, h1)
        loss = criterion(pair_dist, labels, args.margin, args.robust, args)

        loss_value += loss.item()
        if epoch != 0:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    epoch_time = time.time() - time0
    return epoch_time
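
# --- Criterion sketch (assumption) ---------------------------------------------
# train2 calls its criterion as criterion(pair_dist, labels, margin, robust,
# args); the actual loss lives elsewhere in the repo. The function below is a
# minimal margin-based contrastive loss with that signature, written only to
# document the expected contract; the `robust` clamping behaviour is a guess,
# and `args` is accepted for signature parity but unused in this sketch.
def example_contrastive_loss(pair_dist, labels, margin, robust, args):
    labels = labels.float()
    # Optionally clamp distances so outlier pairs cannot dominate the loss.
    dist = torch.clamp(pair_dist, max=2 * margin) if robust else pair_dist
    # labels == 1: similar pair, pulled together; labels == 0: dissimilar pair,
    # pushed apart until it is at least `margin` away.
    pos = labels * dist.pow(2)
    neg = (1 - labels) * F.relu(margin - dist).pow(2)
    return (pos + neg).mean()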