# my-cool-model / utils / benchmark.py
# crapthings's picture
# Upload folder using huggingface_hub
# f7f604d
import argparse
import torch
import os
import sys
from thop import profile, clever_format
import warnings
warnings.filterwarnings("ignore")
filepath = os.path.split(__file__)[0]
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)
from lib import *
from lib.optim import *
from data.dataloader import *
from utils.misc import *
def _args():
    """Parse the benchmark script's command-line options from ``sys.argv``.

    Recognised options:
        --config      Path to the model config file
                      (default: 'configs/InSPyReNet_SwinB.yaml').
        --input_size  One or more integers giving the trailing dimensions of
                      the dummy input tensor (default: [384, 384]).
        --verbose     Boolean flag, off by default. It is parsed here, but
                      benchmark() in this file does not read it.

    Returns:
        argparse.Namespace with ``config``, ``input_size`` and ``verbose``.
    """
    cli = argparse.ArgumentParser()

    # (flag, add_argument keyword options), registered in declaration order.
    option_specs = (
        ('--config', dict(type=str, default='configs/InSPyReNet_SwinB.yaml')),
        ('--input_size', dict(type=int, nargs='+', default=[384, 384])),
        ('--verbose', dict(action='store_true', default=False)),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def benchmark(opt, args, num_iters=10, num_warmup=3):
    """Print MACs, parameter count and mean forward-pass time of a model.

    The model class is looked up by name from ``opt.Model.name``, built with
    ``**opt.Model``, wrapped in ``Simplify`` and moved to the GPU. A random
    input of shape ``(1, 3, *args.input_size)`` is then used both for thop's
    MAC/parameter count and for the timed forward passes.

    Args:
        opt: Config object exposing ``opt.Model`` (a mapping usable with
            ``**``) and ``opt.Model.name``.
        args: Parsed CLI namespace; only ``args.input_size`` is read.
        num_iters: Number of timed forward passes to average over.
        num_warmup: Number of untimed forward passes run before timing.

    Raises:
        ValueError: If ``num_iters`` is smaller than 1 or ``num_warmup`` is
            negative.

    Note:
        The value printed after ``'Throughput:'`` is the mean time of one
        forward pass in milliseconds (i.e. a latency). The label is kept
        unchanged so the script's output format stays stable.
    """
    if num_iters < 1:
        raise ValueError('num_iters must be at least 1')
    if num_warmup < 0:
        raise ValueError('num_warmup must not be negative')

    # NOTE(review): eval() executes whatever expression the config stores in
    # Model.name -- only run this script with trusted config files.
    model = Simplify(eval(opt.Model.name)(**opt.Model))
    model = model.cuda()
    # Benchmark in inference mode: nn.Module instances start in training
    # mode, where dropout is active and batch-norm updates its running
    # statistics. Assumes the wrapped model is a torch.nn.Module.
    model.eval()

    dummy_input = torch.rand(1, 3, *args.input_size).cuda()

    macs, params = profile(model, inputs=(dummy_input, ), verbose=False)
    macs, params = clever_format([macs, params], "%.3f")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        # Untimed passes so one-off start-up costs do not leak into the
        # measured region.
        for _ in range(num_warmup):
            model(dummy_input)
        torch.cuda.synchronize()

        start.record()
        for _ in range(num_iters):
            model(dummy_input)
        end.record()

    # Waits for everything to finish running; elapsed_time() needs both
    # events to have completed.
    torch.cuda.synchronize()

    print('Model:', opt.Model.name)
    print('MACs:', macs, 'Params:', params)
    print('Throughput:', start.elapsed_time(end) / num_iters, 'msec')
if __name__ == '__main__':
    # Script entry point: read the CLI options, load the config file they
    # point to, then run the benchmark with both.
    cli_args = _args()
    benchmark(load_config(cli_args.config), cli_args)