Spaces:
Sleeping
Sleeping
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE
import torch | |
import torchvision.transforms as transforms | |
def _build_detectron2_model(setup_cfg, config_file, weights_path, with_confidence_threshold=False):
    """Build a detectron2-wrapped expert and return the predictor's underlying model.

    Args:
        setup_cfg: The expert-specific ``setup_cfg(args)`` callable that turns
            the argparse namespace into a detectron2 config.
        config_file: Path to the expert's YAML config file.
        weights_path: Path to the checkpoint assigned to ``MODEL.WEIGHTS``.
        with_confidence_threshold: When true, also register the
            ``--confidence-threshold`` option (default 0.5) on the parser.

    Returns:
        The ``.model`` attribute of a ``DefaultPredictor`` built from the config.
    """
    import argparse
    from detectron2.engine.defaults import DefaultPredictor

    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="client")
    parser.add_argument("--port", default=2)
    if with_confidence_threshold:
        parser.add_argument("--confidence-threshold", type=float, default=0.5)
    # NOTE: parse_args() without arguments reads the process command line,
    # so unrelated flags passed to the host script will be rejected here.
    args = parser.parse_args()
    args.config_file = config_file
    args.opts = ['MODEL.WEIGHTS', weights_path]
    cfg = setup_cfg(args)
    return DefaultPredictor(cfg).model


def load_expert_model(task=None):
    """Load the expert model for ``task`` together with its input transform.

    Supported tasks: 'depth', 'seg_coco', 'seg_ade', 'obj_detection',
    'ocr_detection', 'normal' and 'edge'. Expert-specific packages are
    imported lazily inside the matching branch, so only the dependencies of
    the requested expert need to be importable.

    Args:
        task: Name of the expert to load.

    Returns:
        A ``(model, transform)`` tuple. The model has been switched to eval
        mode. For an unsupported task a message is printed and
        ``(None, None)`` is returned.
    """
    if task == 'depth':
        # DPT model is a standard pytorch model class
        from experts.depth.models import DPTDepthModel

        model = DPTDepthModel(path='experts/expert_weights/dpt_hybrid-midas-501f0c75.pt',
                              backbone="vitb_rn50_384",
                              non_negative=True,
                              enable_attention_hooks=False)
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])

    elif task == 'seg_coco':
        # Mask2Former is wrapped in detectron2,
        # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
        from experts.segmentation.utils import setup_cfg

        model = _build_detectron2_model(
            setup_cfg,
            'experts/segmentation/configs/coco/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_100ep.yaml',
            'experts/expert_weights/model_final_f07440.pkl',
        )
        transform = transforms.Compose([
            transforms.Resize(size=479, max_size=480),
        ])

    elif task == 'seg_ade':
        # Mask2Former is wrapped in detectron2,
        # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
        from experts.segmentation.utils import setup_cfg

        model = _build_detectron2_model(
            setup_cfg,
            'experts/segmentation/configs/ade20k/panoptic-segmentation/swin/maskformer2_swin_large_IN21k_384_bs16_160k.yaml',
            'experts/expert_weights/model_final_e0c58e.pkl',
        )
        transform = transforms.Compose([
            transforms.Resize(size=479, max_size=480),
        ])

    elif task == 'obj_detection':
        # UniDet is wrapped in detectron2,
        # the model takes input in the format of: {"image": image (BGR), "height": height, "width": width}
        from experts.obj_detection.utils import setup_cfg

        model = _build_detectron2_model(
            setup_cfg,
            'experts/obj_detection/configs/Unified_learned_OCIM_RS200_6x+2x.yaml',
            'experts/expert_weights/Unified_learned_OCIM_RS200_6x+2x.pth',
            with_confidence_threshold=True,
        )
        transform = transforms.Compose([
            transforms.Resize(size=479, max_size=480),
        ])

    elif task == 'ocr_detection':
        from experts.ocr_detection.charnet.modeling.model import CharNet

        model = CharNet()
        # Map the checkpoint to CPU (as the 'edge' branch does) so loading does
        # not depend on the device the weights were saved from.
        model.load_state_dict(torch.load('experts/expert_weights/icdar2015_hourglass88.pth',
                                         map_location='cpu'))
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    elif task == 'normal':
        # NLL-AngMF model is a standard pytorch model class
        import argparse
        from experts.normal.models.NNET import NNET
        from experts.normal.utils import utils

        parser = argparse.ArgumentParser()
        parser.add_argument("--mode", default="client")
        parser.add_argument("--port", default=2)
        parser.add_argument('--architecture', default='BN', type=str, help='{BN, GN}')
        parser.add_argument("--pretrained", default='scannet', type=str, help="{nyu, scannet}")
        parser.add_argument('--sampling_ratio', type=float, default=0.4)
        parser.add_argument('--importance_ratio', type=float, default=0.7)
        # NOTE: reads the process command line, same as the detectron2 experts.
        args = parser.parse_args()
        model = NNET(args)
        model = utils.load_checkpoint('experts/expert_weights/scannet.pt', model)
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    elif task == 'edge':
        # DexiNed model is a standard pytorch model class
        from experts.edge.model import DexiNed

        model = DexiNed()
        model.load_state_dict(torch.load('experts/expert_weights/10_model.pth', map_location='cpu'))
        transform = transforms.Compose([
            transforms.Resize([480, 480]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.0, 1.0, 1.0]),
        ])

    else:
        print('Task not supported')
        # Return early: there is no model to put into eval mode, and calling
        # .eval() on None would raise AttributeError.
        return None, None

    model.eval()
    return model, transform