# trying-deepfake/train.py
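"""Train the GenConViT deepfake-detection model.

Supports two variants, the encoder-decoder ("ed") and the variational
autoencoder ("vae"), can resume from a saved checkpoint, and optionally
evaluates the final weights on the test split.
"""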
import os
import torch
from torch import nn
import torch.optim as optim
from torch.optim import lr_scheduler
import time
from time import perf_counter
import pickle
from model.config import load_config
from model.genconvit_ed import GenConViTED
from model.genconvit_vae import GenConViTVAE
from dataset.loader import load_data, load_checkpoint
import optparse

config = load_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
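
# The loaded config is expected to provide "learning_rate", "weight_decay",
# "min_val_loss", and "batch_size"; all four are read below.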

def load_pretrained(model, optimizer, pretrained_model_filename):
    assert os.path.isfile(
        pretrained_model_filename
    ), "Saved model file does not exist. Exiting."
    model, optimizer, start_epoch, min_loss = load_checkpoint(
        model, optimizer, filename=pretrained_model_filename
    )
    # Transfer the optimizer state tensors to the active device individually.
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)
    return model, optimizer, start_epoch, min_loss
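
# Note: load_checkpoint (from dataset.loader) is assumed to return
# (model, optimizer, start_epoch, min_loss), reading the same keys that
# train_model saves below ("state_dict", "optimizer", "epoch", "min_loss").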


def train_model(
    dir_path, mod, num_epochs, pretrained_model_filename, test_model, batch_size
):
    print("Loading data...")
    dataloaders, dataset_sizes = load_data(dir_path, batch_size)
    print("Done.")

    # Select the model variant and its matching train/valid loops.
    if mod == "ed":
        from train.train_ed import train, valid

        model = GenConViTED(config)
    else:
        from train.train_vae import train, valid

        model = GenConViTVAE(config)

    optimizer = optim.Adam(
        model.parameters(),
        lr=float(config["learning_rate"]),
        weight_decay=float(config["weight_decay"]),
    )
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
    mse = nn.MSELoss()
    min_val_loss = int(config["min_val_loss"])

    # Decay the learning rate by 10x every 15 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    # Optionally resume from a saved checkpoint.
    start_epoch = 0
    if pretrained_model_filename:
        model, optimizer, start_epoch, min_loss = load_pretrained(
            model, optimizer, pretrained_model_filename
        )

    model.to(device)
    torch.manual_seed(1)

    train_loss, train_acc, valid_loss, valid_acc = [], [], [], []
    since = time.time()

    for epoch in range(start_epoch, num_epochs):
        train_loss, train_acc, epoch_loss = train(
            model,
            device,
            dataloaders["train"],
            criterion,
            optimizer,
            epoch,
            train_loss,
            train_acc,
            mse,
        )
        valid_loss, valid_acc = valid(
            model,
            device,
            dataloaders["validation"],
            criterion,
            epoch,
            valid_loss,
            valid_acc,
            mse,
        )
        scheduler.step()

    time_elapsed = time.time() - since
    print(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )

    print("\nSaving model...\n")
    os.makedirs("weight", exist_ok=True)  # ensure the output directory exists
    file_path = os.path.join(
        "weight",
        f'genconvit_{mod}_{time.strftime("%b_%d_%Y_%H_%M_%S", time.localtime())}',
    )

    # Save the loss/accuracy histories alongside the model weights.
    with open(f"{file_path}.pkl", "wb") as f:
        pickle.dump([train_loss, train_acc, valid_loss, valid_acc], f)

    state = {
        "epoch": num_epochs + 1,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "min_loss": epoch_loss,
    }
    weight = f"{file_path}.pth"
    torch.save(state, weight)
    print("Done.")

    if test_model:
        test(model, dataloaders, dataset_sizes, mod, weight)


def test(model, dataloaders, dataset_sizes, mod, weight):
    print("\nRunning test...\n")
    # Reload the saved weights and switch to evaluation mode.
    checkpoint = torch.load(weight, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in dataloaders["test"]:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # The VAE variant returns a tuple; the first element is assumed
            # to be the class logits.
            if mod == "ed":
                output = model(inputs).float()
            else:
                output = model(inputs)[0].float()
            _, prediction = torch.max(output, 1)
            # Count predictions that match the ground-truth labels.
            correct += (prediction == labels).sum().item()
            seen += len(inputs)
            print(f"Prediction: {correct}/{seen}")

    print(
        f'Prediction: {correct}/{dataset_sizes["test"]} '
        f'{(correct / dataset_sizes["test"]) * 100:.2f}%'
    )


def gen_parser():
    parser = optparse.OptionParser("Train GenConViT model.")
    parser.add_option(
        "-e",
        "--epoch",
        type=int,
        dest="epoch",
        help="Number of epochs used for training the GenConViT model.",
    )
    parser.add_option("-v", "--version", dest="version", help="Version 0.1.")
    parser.add_option("-d", "--dir", dest="dir", help="Training data path.")
    parser.add_option(
        "-m",
        "--model",
        dest="model",
        help="Model variant: ed (encoder-decoder) or vae (variational autoencoder).",
    )
    parser.add_option(
        "-p",
        "--pretrained",
        dest="pretrained",
        help="Saved model file name, to continue training from a previous checkpoint.",
    )
    parser.add_option("-t", "--test", dest="test", help="Run test on the test dataset.")
    parser.add_option("-b", "--batch_size", dest="batch_size", help="Batch size.")
    (options, _) = parser.parse_args()

    dir_path = options.dir
    epoch = options.epoch
    mod = "ed" if options.model == "ed" else "vae"
    test_model = "y" if options.test else None
    pretrained_model_filename = options.pretrained if options.pretrained else None
    batch_size = options.batch_size if options.batch_size else config["batch_size"]

    return dir_path, mod, epoch, pretrained_model_filename, test_model, int(batch_size)
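

# Example invocation (the paths and values here are illustrative only):
#   python train.py -d data/ -m ed -e 10 -b 32 -t y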


def main():
    start_time = perf_counter()
    path, mod, epoch, pretrained_model_filename, test_model, batch_size = gen_parser()
    train_model(path, mod, epoch, pretrained_model_filename, test_model, batch_size)
    end_time = perf_counter()
    print(f"\n\n--- {end_time - start_time} seconds ---")


if __name__ == "__main__":
    main()