rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
2.23 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
import pkgutil
import warnings
from collections import namedtuple
import torch
if torch.__version__ != 'parrots':

    def load_ext(name, funcs):
        """Import the extension module ``mmcv.<name>`` and verify its API.

        Args:
            name (str): Module name relative to the ``mmcv`` package.
            funcs (Iterable[str]): Attribute names the imported module must
                expose.

        Returns:
            module: The imported ``mmcv.<name>`` module.

        Raises:
            ImportError: If ``mmcv.<name>`` cannot be imported.
            AssertionError: If any name in ``funcs`` is missing from the
                imported module.
        """
        ext = importlib.import_module('mmcv.' + name)
        for fun in funcs:
            # Raised explicitly rather than through an ``assert`` statement
            # so the check is not stripped under ``python -O``. The exception
            # type and message are the same as an assert would produce.
            if not hasattr(ext, fun):
                raise AssertionError(f'{fun} miss in module {name}')
        return ext
else:
    from parrots import extension
    from parrots.base import ParrotsException

    # Ops that ``load_ext`` exposes through ``ext_fun.op``; every op not
    # listed here is exposed through ``ext_fun.op_`` instead.
    # NOTE(review): presumably ``.op`` returns its result while ``.op_``
    # writes into its arguments -- confirm against the parrots documentation.
    has_return_value_ops = [
        'nms',
        'softnms',
        'nms_match',
        'nms_rotated',
        'top_pool_forward',
        'top_pool_backward',
        'bottom_pool_forward',
        'bottom_pool_backward',
        'left_pool_forward',
        'left_pool_backward',
        'right_pool_forward',
        'right_pool_backward',
        'fused_bias_leakyrelu',
        'upfirdn2d',
        'ms_deform_attn_forward',
        'pixel_group',
        'contour_expand',
        'diff_iou_rotated_sort_vertices_forward',
    ]

    def get_fake_func(name, e):
        """Build a stand-in callable for an op that failed to load.

        Args:
            name (str): Name of the op, used in the warning text.
            e (Exception): The exception caught while loading the op.

        Returns:
            Callable: A function accepting any arguments that emits a warning
            and then re-raises ``e``.
        """

        def fake_func(*args, **kwargs):
            warnings.warn(f'{name} is not supported in parrots now')
            raise e

        return fake_func

    def load_ext(name, funcs):
        """Load each op in ``funcs`` through ``parrots.extension``.

        Args:
            name (str): Extension name passed to ``extension.load``.
            funcs (Sequence[str]): Op names to load; they also become the
                field names of the returned namedtuple.

        Returns:
            ExtModule: A namedtuple with one field per entry of ``funcs``.
            Each field holds the loaded op (``.op`` for names listed in
            ``has_return_value_ops``, ``.op_`` otherwise) or, when loading
            raised ``ParrotsException``, a stand-in that re-raises that
            exception when called.
        """
        ExtModule = namedtuple('ExtModule', funcs)
        ext_list = []
        # Directory two levels above this file's resolved path.
        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        for fun in funcs:
            try:
                ext_fun = extension.load(fun, name, lib_dir=lib_root)
            except ParrotsException as e:
                # Failures whose message contains 'No element registered'
                # are not warned about at load time; any other failure is.
                if 'No element registered' not in e.message:
                    warnings.warn(e.message)
                # Defer the error until the op is actually called.
                ext_fun = get_fake_func(fun, e)
                ext_list.append(ext_fun)
            else:
                if fun in has_return_value_ops:
                    ext_list.append(ext_fun.op)
                else:
                    ext_list.append(ext_fun.op_)
        return ExtModule(*ext_list)
def check_ops_exist() -> bool:
    """Report whether the ``mmcv._ext`` module can be located.

    Returns:
        bool: ``True`` if a module spec with a loader is found for
        ``mmcv._ext``. ``False`` if no such spec is found, or if looking it
        up raises ``ImportError`` (for example because the parent ``mmcv``
        package itself cannot be imported).
    """
    # ``pkgutil.find_loader`` is deprecated since Python 3.12 and removed in
    # 3.14; ``importlib.util.find_spec`` is the documented replacement.
    # Imported locally because ``import importlib`` alone does not guarantee
    # that the ``util`` submodule has been loaded.
    from importlib.util import find_spec

    try:
        ext_spec = find_spec('mmcv._ext')
    except ImportError:
        return False
    return ext_spec is not None and ext_spec.loader is not None