import argparse
import ast

import numpy as np
import onnx
import torch
import torch.nn as nn
from onnx2pytorch import ConvertModel

from models import image
from models.TMC import ETMC

# Load the ONNX image-encoder checkpoint and convert it into a PyTorch module
# so its weights can be stored in the combined checkpoint written at the end
# of this script.
# The path uses a forward slash, which resolves on every platform; the earlier
# backslash-separated literal only named the intended file on Windows.
onnx_model = onnx.load('checkpoints/efficientnet.onnx')
pytorch_model = ConvertModel(onnx_model)

# Default settings for the audio encoder. ``get_args`` registers one
# ``--<key>`` command-line option per entry and uses the value as its default.
# NOTE(review): the key names suggest a RawNet2-style configuration (sample
# count, filter sizes, GRU sizes) -- confirm against the RawNet definition.
audio_args = {
    'nb_samp': 64600,
    'first_conv': 1024,
    'in_channels': 1,
    'filts': [20, [20, 20], [20, 128], [128, 128]],
    'blocks': [2, 4],
    'nb_fc_node': 1024,
    'gru_node': 1024,
    'nb_gru_layer': 3,
    'nb_classes': 2,
}

def get_args(parser):
    """Register every command-line option of this script on ``parser``.

    Besides the fixed options listed below, one ``--<key>`` option is added
    for each entry of the module-level ``audio_args`` dict, with the dict
    value as its default.

    Args:
        parser: The ``argparse.ArgumentParser`` to populate. It is modified
            in place.

    Returns:
        None.
    """
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--data_dir", type=str, default="datasets/train/fakeavceleb*")
    parser.add_argument("--LOAD_SIZE", type=int, default=256)
    parser.add_argument("--FINE_SIZE", type=int, default=224)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--hidden", nargs="*", type=int, default=[])
    parser.add_argument("--hidden_sz", type=int, default=768)
    parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"])
    parser.add_argument("--img_hidden_sz", type=int, default=1024)
    # Kept as an int-typed option with a boolean default for compatibility
    # with existing invocations such as ``--include_bn 0``.
    parser.add_argument("--include_bn", type=int, default=True)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lr_factor", type=float, default=0.3)
    parser.add_argument("--lr_patience", type=int, default=10)
    parser.add_argument("--max_epochs", type=int, default=500)
    parser.add_argument("--n_workers", type=int, default=12)
    parser.add_argument("--name", type=str, default="MMDF")
    parser.add_argument("--num_image_embeds", type=int, default=1)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--savedir", type=str, default="./savepath/")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--n_classes", type=int, default=2)
    parser.add_argument("--annealing_epoch", type=int, default=10)
    parser.add_argument("--device", type=str, default='cpu')
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True, so the boolean switches go through ``_str_to_bool`` instead.
    parser.add_argument("--pretrained_image_encoder", type=_str_to_bool, default=False)
    parser.add_argument("--freeze_image_encoder", type=_str_to_bool, default=False)
    parser.add_argument("--pretrained_audio_encoder", type=_str_to_bool, default=False)
    parser.add_argument("--freeze_audio_encoder", type=_str_to_bool, default=False)
    parser.add_argument("--augment_dataset", type=_str_to_bool, default=True)

    for key, value in audio_args.items():
        parser.add_argument(f"--{key}", type=_cli_type_for(value), default=value)


def _str_to_bool(value):
    """Convert a command-line token such as "true" or "0" into a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays False, matching what ``bool("")`` used to give.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _literal_list(value):
    """Parse a Python-literal list such as "[2, 4]" from the command line."""
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError) as err:
        raise argparse.ArgumentTypeError(
            f"expected a list literal, got {value!r}"
        ) from err
    if not isinstance(parsed, list):
        raise argparse.ArgumentTypeError(f"expected a list literal, got {value!r}")
    return parsed


def _cli_type_for(default):
    """Pick the argparse ``type`` callable that matches a default value."""
    # bool is checked first because it is a subclass of int.
    if isinstance(default, bool):
        return _str_to_bool
    # ``list("[2, 4]")`` would split the text into single characters.
    if isinstance(default, list):
        return _literal_list
    return type(default)

def load_spec_modality_model(args):
    """Build the audio encoder and load its weights from the RawNet2 checkpoint.

    Args:
        args: Parsed command-line namespace, forwarded unchanged to
            ``image.RawNet``.

    Returns:
        The ``image.RawNet`` instance with the checkpoint weights loaded and
        switched to evaluation mode.
    """
    spec_encoder = image.RawNet(args)
    # A forward slash resolves on every platform. The previous
    # backslash-separated literal contained the invalid escape "\R" and only
    # named the intended file on Windows.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    ckpt = torch.load('checkpoints/RawNet2.pth', map_location=torch.device('cpu'))
    # strict=True: fail if the checkpoint keys do not match the model exactly.
    spec_encoder.load_state_dict(ckpt, strict=True)
    spec_encoder.eval()
    return spec_encoder

# Parse the command line; the resulting namespace configures the audio encoder.
parser = argparse.ArgumentParser(description="Train Models")
get_args(parser)
args, remaining_args = parser.parse_known_args()
if remaining_args:
    # Explicit check rather than ``assert`` so unknown arguments are still
    # rejected when Python runs with -O (which strips assert statements).
    parser.error(f"unrecognized arguments: {' '.join(remaining_args)}")

spec_model = load_spec_modality_model(args)

print(f"Image model is: {pytorch_model}")
print(f"Audio model is: {spec_model}")

# Forward slash keeps the output path valid on every platform.
PATH = 'checkpoints/model.pth'

# Store both encoders' weights in one checkpoint, keyed by modality.
torch.save({
    'spec_encoder': spec_model.state_dict(),
    'rgb_encoder': pytorch_model.state_dict(),
}, PATH)

print("Model saved.")