# SivaResearch's picture
# demo
# b6d5990
# raw
# history blame contribute delete
# No virus
# 3.42 kB
import argparse
import ast
import os

import numpy as np
import onnx
import torch
import torch.nn as nn
from onnx2pytorch import ConvertModel

from models import image
from models.TMC import ETMC
# Load the ONNX image model and convert it to a PyTorch module.
# The path is joined from its components so it resolves on every OS instead
# of relying on a hard-coded Windows backslash separator.
onnx_model = onnx.load(os.path.join('checkpoints', 'efficientnet.onnx'))
pytorch_model = ConvertModel(onnx_model)
# Audio-model hyperparameters. get_args() registers every key below as a
# command-line option whose default is the value given here.
audio_args = {
    'nb_samp': 64600,
    'first_conv': 1024,
    'in_channels': 1,
    'filts': [20, [20, 20], [20, 128], [128, 128]],
    'blocks': [2, 4],
    'nb_fc_node': 1024,
    'gru_node': 1024,
    'nb_gru_layer': 3,
    'nb_classes': 2,
}

_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
# The empty string stays falsy so the old `--flag ""` workaround keeps working.
_FALSE_STRINGS = frozenset({"", "0", "false", "f", "no", "n", "off"})


def _str2bool(text):
    """Convert a command-line token such as ``true``/``False``/``0`` to a bool."""
    if isinstance(text, bool):
        return text
    lowered = text.strip().lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def _literal_parser(expected_type):
    """Return an argparse converter that reads a Python literal of *expected_type*."""

    def parse(text):
        try:
            value = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError) as err:
            raise argparse.ArgumentTypeError(
                f"expected a {expected_type.__name__} literal, got {text!r}"
            ) from err
        if not isinstance(value, expected_type):
            raise argparse.ArgumentTypeError(
                f"expected a {expected_type.__name__} literal, got {text!r}"
            )
        return value

    return parse


def _converter_for(default):
    """Pick the argparse ``type`` callable that matches a default value."""
    # bool must be tested before anything else: bool("False") is True, so the
    # bare ``bool`` constructor cannot be used as a converter.
    if isinstance(default, bool):
        return _str2bool
    # list("[2, 4]") would split the text into characters; parse a literal.
    if isinstance(default, (list, tuple, dict)):
        return _literal_parser(type(default))
    return type(default)


def get_args(parser):
    """Register all options of this script on *parser*.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend. It is mutated in
            place; nothing is returned.

    Boolean options accept ``true``/``false``, ``yes``/``no``, ``on``/``off``,
    ``1``/``0`` (case-insensitive). Container-valued audio options such as
    ``--filts`` accept a Python literal, e.g. ``--blocks "[2, 4]"``.
    """
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--data_dir", type=str, default="datasets/train/fakeavceleb*")
    parser.add_argument("--LOAD_SIZE", type=int, default=256)
    parser.add_argument("--FINE_SIZE", type=int, default=224)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--hidden", nargs="*", type=int, default=[])
    parser.add_argument("--hidden_sz", type=int, default=768)
    parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"])
    parser.add_argument("--img_hidden_sz", type=int, default=1024)
    # Kept as an int flag (0/1 on the command line) for backward compatibility.
    parser.add_argument("--include_bn", type=int, default=True)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lr_factor", type=float, default=0.3)
    parser.add_argument("--lr_patience", type=int, default=10)
    parser.add_argument("--max_epochs", type=int, default=500)
    parser.add_argument("--n_workers", type=int, default=12)
    parser.add_argument("--name", type=str, default="MMDF")
    parser.add_argument("--num_image_embeds", type=int, default=1)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--savedir", type=str, default="./savepath/")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--n_classes", type=int, default=2)
    parser.add_argument("--annealing_epoch", type=int, default=10)
    parser.add_argument("--device", type=str, default='cpu')

    # These used type=bool, which turned every non-empty string (including
    # "False") into True; _str2bool parses the text instead.
    parser.add_argument("--pretrained_image_encoder", type=_str2bool, default=False)
    parser.add_argument("--freeze_image_encoder", type=_str2bool, default=False)
    parser.add_argument("--pretrained_audio_encoder", type=_str2bool, default=False)
    parser.add_argument("--freeze_audio_encoder", type=_str2bool, default=False)
    parser.add_argument("--augment_dataset", type=_str2bool, default=True)

    # Expose every audio hyperparameter, converting according to its default.
    for key, value in audio_args.items():
        parser.add_argument(f"--{key}", type=_converter_for(value), default=value)
def load_spec_modality_model(args):
    """Build the audio encoder and load its weights from the checkpoint file.

    Args:
        args: Parsed options, forwarded unchanged to ``image.RawNet``.

    Returns:
        The ``image.RawNet`` instance with the checkpoint's state dict loaded
        (strict key matching) and switched to eval mode.
    """
    spec_encoder = image.RawNet(args)
    # Join the path components instead of embedding a backslash: the previous
    # literal only worked on Windows and contained the invalid escape "\R".
    ckpt_path = os.path.join('checkpoints', 'RawNet2.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    spec_encoder.load_state_dict(ckpt, strict=True)
    spec_encoder.eval()
    return spec_encoder
# Load models.
parser = argparse.ArgumentParser(description="Train Models")
get_args(parser)
args, remaining_args = parser.parse_known_args()
# Reject unknown options explicitly: an assert would be skipped under
# `python -O`, letting typos on the command line pass silently.
if remaining_args:
    parser.error(f"unrecognized arguments: {' '.join(remaining_args)}")

spec_model = load_spec_modality_model(args)

print(f"Image model is: {pytorch_model}")
print(f"Audio model is: {spec_model}")

# Bundle both encoders' weights into a single checkpoint file. The path is
# joined from components so it is valid on every OS, not only Windows.
PATH = os.path.join('checkpoints', 'model.pth')
torch.save(
    {
        'spec_encoder': spec_model.state_dict(),
        'rgb_encoder': pytorch_model.state_dict(),
    },
    PATH,
)
print("Model saved.")