# Source: prismer/experts/model_bank.py (commit 82b5fea, "Fix resolution", by shikunl)
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE
import torch
import torchvision.transforms as transforms
def _parse_expert_args(extra_arguments=()):
    """Build the argparse namespace handed to the expert constructors.

    Args:
        extra_arguments: iterable of ``(flags, options)`` pairs; each pair is
            forwarded as ``parser.add_argument(*flags, **options)``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    # NOTE(review): --mode/--port are not read by anything in this file;
    # presumably they are declared only because the launching environment
    # puts them on the command line -- confirm before removing.
    parser.add_argument("--mode", default="client")
    parser.add_argument("--port", default=2)
    for flags, options in extra_arguments:
        parser.add_argument(*flags, **options)

    # The command line being parsed belongs to the host script, not to this
    # module. parse_known_args() fills in the declared options exactly like
    # parse_args() but ignores flags it does not know instead of exiting.
    args, _ = parser.parse_known_args()
    return args


def _load_detectron2_model(setup_cfg, config_file, weights, extra_arguments=()):
    """Instantiate a detectron2-wrapped expert and return its underlying model.

    Args:
        setup_cfg: the expert's own ``setup_cfg(args)`` config factory.
        config_file: path of the yaml config, stored as ``args.config_file``.
        weights: checkpoint path, passed through ``args.opts`` as MODEL.WEIGHTS.
        extra_arguments: additional ``(flags, options)`` argparse declarations.

    Returns:
        The ``model`` attribute of the ``DefaultPredictor`` built from the config.
    """
    from detectron2.engine.defaults import DefaultPredictor

    args = _parse_expert_args(extra_arguments)
    args.config_file = config_file
    args.opts = ['MODEL.WEIGHTS', weights]
    cfg = setup_cfg(args)
    return DefaultPredictor(cfg).model


def load_expert_model(task=None):
    """Load the expert model for ``task`` together with its input transform.

    Expert packages are imported lazily inside each branch, so only the
    dependencies of the requested expert are needed. Weight and config paths
    are relative and therefore resolved against the current working directory.

    Args:
        task: one of 'depth', 'seg_coco', 'seg_ade', 'obj_detection',
            'ocr_detection', 'normal' or 'edge'.

    Returns:
        A ``(model, transform)`` tuple. The model has been switched to eval
        mode. For an unsupported task a message is printed and
        ``(None, None)`` is returned.
    """
    if task == 'depth':
        # DPT model is a standard pytorch model class
        from experts.depth.models import DPTDepthModel

        model = DPTDepthModel(path='experts/expert_weights/dpt_hybrid-midas-501f0c75.pt',
                              backbone="vitb_rn50_384",
                              non_negative=True,
                              enable_attention_hooks=False)
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])

    elif task == 'seg_coco':
        # Mask2Former is wrapped in detectron2,
        # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
        from experts.segmentation.utils import setup_cfg

        model = _load_detectron2_model(
            setup_cfg,
            'experts/segmentation/configs/coco/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml',
            'experts/expert_weights/model_final_f07440.pkl',
        )
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
        ])

    elif task == 'seg_ade':
        # Mask2Former is wrapped in detectron2,
        # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
        from experts.segmentation.utils import setup_cfg

        model = _load_detectron2_model(
            setup_cfg,
            'experts/segmentation/configs/ade20k/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_160k.yaml',
            'experts/expert_weights/model_final_e0c58e.pkl',
        )
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
        ])

    elif task == 'obj_detection':
        # UniDet is wrapped in detectron2,
        # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
        from experts.obj_detection.utils import setup_cfg

        model = _load_detectron2_model(
            setup_cfg,
            'experts/obj_detection/configs/Unified_learned_OCIM_RS200_6x+2x.yaml',
            'experts/expert_weights/Unified_learned_OCIM_RS200_6x+2x.pth',
            extra_arguments=[
                (("--confidence-threshold",), {"type": float, "default": 0.5}),
            ],
        )
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
        ])

    elif task == 'ocr_detection':
        from experts.ocr_detection.charnet.modeling.model import CharNet

        model = CharNet()
        # map_location='cpu' matches the 'edge' branch and keeps loading
        # independent of the device the checkpoint tensors were saved from.
        model.load_state_dict(torch.load('experts/expert_weights/icdar2015_hourglass88.pth',
                                         map_location='cpu'))
        # No resize here, unlike the other experts: the image keeps its size.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    elif task == 'normal':
        # NLL-AngMF model is a standard pytorch model class
        from experts.normal.models.NNET import NNET
        from experts.normal.utils import utils

        # NNET is configured through an argparse namespace.
        args = _parse_expert_args([
            (('--architecture',), {"default": 'BN', "type": str, "help": '{BN, GN}'}),
            (("--pretrained",), {"default": 'scannet', "type": str, "help": "{nyu, scannet}"}),
            (('--sampling_ratio',), {"type": float, "default": 0.4}),
            (('--importance_ratio',), {"type": float, "default": 0.7}),
        ])
        model = NNET(args)
        model = utils.load_checkpoint('experts/expert_weights/scannet.pt', model)
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    elif task == 'edge':
        # DexiNed model; its state dict is loaded onto the CPU
        from experts.edge.model import DexiNed

        model = DexiNed()
        model.load_state_dict(torch.load('experts/expert_weights/10_model.pth', map_location='cpu'))
        # std of 1.0 means the normalisation only subtracts the channel means.
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.0, 1.0, 1.0]),
        ])

    else:
        print('Task not supported')
        model = None
        transform = None

    # Only a real model can be put in eval mode; calling .eval() on the None
    # placeholder of the unsupported-task branch would raise AttributeError.
    if model is not None:
        model.eval()
    return model, transform