"""
Benchmark MattingNetwork inference speed on a random CUDA input frame.

Example:

    python inference_speed_test.py \
        --model-variant mobilenetv3 \
        --resolution 1920 1080 \
        --downsample-ratio 0.25 \
        --precision float32
"""
| import argparse | |
| import torch | |
| from tqdm import tqdm | |
| from model.model import MattingNetwork | |
# Enable cuDNN benchmark mode so cuDNN can select convolution algorithms for the
# input shapes it sees. The benchmark reuses one fixed-shape input for every
# iteration, so any selection cost falls on the first forward passes.
torch.backends.cudnn.benchmark = True
class InferenceSpeedTest:
    """Command-line benchmark that times MattingNetwork inference on random input.

    Constructing an instance does all of the work: it parses the command line,
    builds the model on CUDA, compiles it with TorchScript, and then runs the
    timed loop. ``tqdm`` reports the iteration rate.
    """

    # Maps the ``--precision`` CLI value to the torch dtype used for both the
    # model and the input tensor. It also supplies argparse ``choices``, so an
    # unsupported value is rejected with a usage error at parse time instead
    # of a late KeyError in ``init_model``.
    _PRECISIONS = {'float32': torch.float32, 'float16': torch.float16}

    def __init__(self):
        self.parse_args()
        self.init_model()
        self.loop()

    def parse_args(self, argv=None):
        """Parse benchmark options into ``self.args``.

        Args:
            argv: Argument list to parse. ``None`` (the default) makes argparse
                read ``sys.argv[1:]``, which is how the constructor uses it.

        Raises:
            SystemExit: On invalid arguments (argparse usage error), including a
                ``--num-iterations`` value below 1.
        """
        parser = argparse.ArgumentParser(description='Measure MattingNetwork inference speed.')
        parser.add_argument('--model-variant', type=str, required=True,
                            help='Variant name passed to MattingNetwork (e.g. mobilenetv3).')
        parser.add_argument('--resolution', type=int, required=True, nargs=2,
                            metavar=('WIDTH', 'HEIGHT'),
                            help='Size of the random input frame, as WIDTH HEIGHT.')
        parser.add_argument('--downsample-ratio', type=float, required=True,
                            help='Downsample ratio passed to the model on every forward call.')
        parser.add_argument('--precision', type=str, default='float32',
                            choices=list(self._PRECISIONS),
                            help='Dtype for the model and input tensor.')
        # Kept so existing command lines keep working; nothing in this script reads it.
        parser.add_argument('--disable-refiner', action='store_true',
                            help='Accepted for compatibility; currently ignored by this script.')
        parser.add_argument('--num-iterations', type=int, default=1000,
                            help='Number of timed forward passes (default: 1000).')
        args = parser.parse_args(argv)
        if args.num_iterations < 1:
            parser.error('--num-iterations must be at least 1')
        self.args = args

    def init_model(self):
        """Build MattingNetwork on CUDA in the requested dtype and compile it with TorchScript."""
        self.device = 'cuda'
        self.precision = self._PRECISIONS[self.args.precision]
        model = MattingNetwork(self.args.model_variant)
        # eval() comes before scripting because torch.jit.freeze requires an
        # eval-mode module.
        model = model.to(device=self.device, dtype=self.precision).eval()
        self.model = torch.jit.freeze(torch.jit.script(model))

    def loop(self):
        """Run the timed forward passes on one fixed random frame, carrying recurrent state."""
        w, h = self.args.resolution
        src = torch.randn((1, 3, h, w), device=self.device, dtype=self.precision)
        # Four recurrent-state inputs. They are None on the first call; after
        # that, they are whatever the model returned beyond its first two outputs.
        rec = [None] * 4
        with torch.no_grad():
            for _ in tqdm(range(self.args.num_iterations)):
                # The first two outputs (presumably foreground and alpha) are not
                # needed for timing. The remaining outputs are fed back as state.
                _fgr, _pha, *rec = self.model(src, *rec, self.args.downsample_ratio)
                # CUDA work runs asynchronously. Synchronizing each iteration makes
                # tqdm's rate reflect completed GPU work, not just kernel launches.
                torch.cuda.synchronize()
# Script entry point. InferenceSpeedTest's constructor parses the command line,
# builds the model, and runs the timed loop, so creating the instance runs the benchmark.
if __name__ == '__main__':
    InferenceSpeedTest()