import os
import shutil
import time
import zipfile

import gdown
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import modules.model as model
# Download the CelebA archive if it is not already present
if not os.path.exists('celeba/'):
    url = 'https://drive.google.com/file/d/13vkq4tFCPE8O78KTj84HHM6kBnYkt8gP/view?usp=sharing'
    output = 'download.zip'
    gdown.download(url, output, fuzzy=True)
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall()
    os.remove(output)
    shutil.rmtree('__MACOSX', ignore_errors=True)  # macOS zip artefact, may not exist
# Set device
if torch.backends.mps.is_available():
    device = torch.device('mps')
    device_name = 'Apple Silicon GPU'
elif torch.cuda.is_available():
    device = torch.device('cuda')
    device_name = 'CUDA'
else:
    device = torch.device('cpu')
    device_name = 'CPU'
torch.set_default_device(device)  # new tensors are created on this device by default
print(f'\nDevice: {device_name}')
# Define dataset, dataloader and transform
imsize = int(128 / 0.8)  # resize to 160 so the 0.8 crop below is exactly 128 px
batch_size = 10
fivecrop_transform = transforms.Compose([
    transforms.Resize([imsize, imsize]),
    transforms.Grayscale(1),
    transforms.FiveCrop(int(imsize * 0.8)),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    transforms.Normalize(0, 1)
])
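# FiveCrop yields the four corner crops plus the centre crop, so each sample
# becomes a (5, 1, 128, 128) tensor; the per-crop logits are averaged after
# the forward pass below. Normalize(0, 1) subtracts mean 0 and divides by
# std 1, leaving the [0, 1] range produced by ToTensor unchanged.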
train_dataset = datasets.CelebA(
    root='',
    split='all',
    target_type='attr',
    transform=fivecrop_transform,
    download=True,
)
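# With target_type='attr', each label is a 40-dimensional binary vector of
# facial attributes; a single attribute is selected below.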
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device)
)
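# The explicit generator matters because, with torch.set_default_device set to
# MPS/CUDA, the sampler's shuffling runs on that device and needs a generator
# on the same device rather than the default CPU one.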
# Index of the 'Male' attribute among CelebA's 40 binary attributes
factor = 20
# Define model, optimiser and scheduler
torch.manual_seed(2687)
resnet = model.resnetModel_128()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    resnet.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.001
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=1,
    gamma=0.1
)
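# StepLR with step_size=1 and gamma=0.1 multiplies the learning rate by 0.1
# after every epoch: 0.01 during epoch 1, 0.001 during epoch 2.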
def mins_to_hours(mins):
    # Split a duration in minutes into (hours, remaining minutes), e.g. 135 -> (2, 15)
    hours = int(mins / 60)
    rem_mins = mins % 60
    return hours, rem_mins
epochs = 2
train_losses = []
train_accuracy = []
for i in range(epochs):
    epoch_time = 0
    for j, (X_train, y_train) in enumerate(train_loader):
        batch_start = time.time()
        X_train = X_train.to(device)
        y_train = y_train[:, factor].to(device)  # keep labels on the same device as the model
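        # Fold the five crops into the batch dimension for a single forward
        # pass, then average each image's logits over its crops.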
        bs, ncrops, c, h, w = X_train.size()
        y_pred_crops = resnet(X_train.view(-1, c, h, w))
        y_pred = y_pred_crops.view(bs, ncrops, -1).mean(1)
        loss = criterion(y_pred, y_train)
        predicted = torch.max(y_pred.detach(), 1)[1]
        train_batch_accuracy = (predicted == y_train).sum() / len(X_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accuracy.append(train_batch_accuracy.item())
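        # Estimate time remaining by extrapolating the running average batch
        # time over the batches left in this epoch and the remaining epochs.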
        batch_end = time.time()
        batch_time = batch_end - batch_start
        epoch_time += batch_time
        avg_batch_time = epoch_time / (j + 1)
        batches_remaining = len(train_loader) - (j + 1)
        epoch_mins_remaining = round(batches_remaining * avg_batch_time / 60)
        epoch_time_remaining = mins_to_hours(epoch_mins_remaining)
        full_epoch = avg_batch_time * len(train_loader)
        epochs_remaining = epochs - (i + 1)
        rem_epoch_mins_remaining = epoch_mins_remaining + round(full_epoch * epochs_remaining / 60)
        rem_epoch_time_remaining = mins_to_hours(rem_epoch_mins_remaining)
        # Report progress every 10 batches
        if (j + 1) % 10 == 0:
            print(f'\nEpoch: {i+1}/{epochs} | Train Batch: {j+1}/{len(train_loader)}')
            print(f'Current epoch: {epoch_time_remaining[0]} hours {epoch_time_remaining[1]} minutes remaining')
            print(f'Remaining epochs: {rem_epoch_time_remaining[0]} hours {rem_epoch_time_remaining[1]} minutes')
            print(f'Train Loss: {loss.item():.4f}')
            print(f'Train Accuracy: {train_batch_accuracy.item():.4f}')
    scheduler.step()
    # Save a checkpoint at the end of each epoch
    trained_model_name = f'{resnet.model_name}_epoch_{i+1}.pt'
    torch.save(
        resnet.state_dict(),
        trained_model_name
    )
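# To reload a saved checkpoint later (a minimal sketch; assumes the same
# modules.model architecture is importable):
#     resnet = model.resnetModel_128()
#     resnet.load_state_dict(torch.load(trained_model_name))
#     resnet.eval()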