# DSTK / semantic_tokenizer / f40ms / patch_utils.py
# Uploaded by: gooorillax
# Commit message: first push of codes and models for g2p, t2u, tokenizer and detokenizer
# Commit: cd8454d
# copied from https://gitee.com/ascend/MindSpeed/blob/master/mindspeed/patch_utils.py
import importlib
import sys
import types
def get_func_name(func):
    """Return the dotted name identifying ``func``.

    A string is taken to already be a dotted name and is returned unchanged.
    Any other object is named ``<func.__module__>.<func.__qualname__>``.
    """
    if not isinstance(func, str):
        module_name = func.__module__
        qualified_name = func.__qualname__
        return '.'.join([module_name, qualified_name])
    return func
def dummy_function_wrapper(func_name):
    """Build a placeholder callable standing in for a function that does not exist.

    The returned callable accepts any arguments and always raises
    ``RuntimeError`` naming ``func_name``. Its ``__name__`` is
    ``'dummy_function'``; because that does not end in "wrapper" or
    "decorator", ``Patch.set_patch_func`` treats it as a patch function
    rather than a wrapper.
    """
    def dummy_function(*args, **kwargs):
        message = f'function {func_name} no exist'
        raise RuntimeError(message)

    return dummy_function
class Patch:
    """One monkey-patch target together with its replacement and wrappers.

    ``orig_func_name`` is a dotted path. Its last component is the attribute
    to replace; the prefix names the owner, which is a module or a module
    followed by class names. A path with no dot names a module only: it is
    resolved (and possibly created as a dummy), but no attribute is replaced.

    A callable given to ``set_patch_func`` whose ``__name__`` ends with
    ``'wrapper'`` or ``'decorator'`` is stored as a wrapper. Anything else
    becomes the patch function. On application, wrappers are applied in
    registration order. They wrap the patch function, or the original
    attribute when no patch function is set.
    """

    # Sentinel meaning "apply_patch has not installed anything yet". None
    # cannot serve here because the installed object may legitimately be None.
    _NOT_APPLIED = object()
    # Object installed by the most recent apply_patch. It is defined on the
    # class so it exists before the first application.
    _last_applied = _NOT_APPLIED

    def __init__(self, orig_func_name, new_func, create_dummy):
        """Create an (unapplied) patch record.

        Args:
            orig_func_name: dotted path of the attribute to patch.
            new_func: patch function or wrapper. None installs a placeholder
                that raises RuntimeError when called.
            create_dummy: if true, missing modules/attributes are fabricated
                by ``parse_path`` instead of raising.
        """
        split_name = orig_func_name.rsplit('.', 1)
        if len(split_name) == 1:
            self.orig_module_name, self.orig_func_name = orig_func_name, None
        else:
            self.orig_module_name, self.orig_func_name = split_name
        self.orig_module = None
        self.orig_func = None
        self.patch_func = None
        self.wrappers = []
        if new_func is None:
            new_func = dummy_function_wrapper(orig_func_name)
        self.set_patch_func(new_func)
        self.is_applied = False
        self.create_dummy = create_dummy

    @property
    def orig_func_id(self):
        """``id()`` of the object that was resolved as the original attribute."""
        return id(self.orig_func)

    @property
    def patch_func_id(self):
        """``id()`` of the registered patch function (``id(None)`` if unset)."""
        return id(self.patch_func)

    def set_patch_func(self, new_func, force_patch=False):
        """Register ``new_func`` as a wrapper or as the patch function.

        Callables named ``*wrapper`` / ``*decorator`` are appended to
        ``wrappers``. Anything else replaces ``patch_func``.

        Raises:
            RuntimeError: a patch function is already set and ``force_patch``
                is false.
        """
        if hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')):
            self.wrappers.append(new_func)
        else:
            if self.patch_func and not force_patch:
                raise RuntimeError('the patch of {} exist !'.format(self.orig_func_name))
            self.patch_func = new_func
        # Any change means the installed object (if any) is out of date.
        self.is_applied = False

    def apply_patch(self):
        """Install the patch. This is a no-op while it is already applied.

        The target is resolved with ``parse_path``. The patch function, or the
        original when only wrappers were registered, is passed through every
        wrapper, and the result is set on the owner. Every module in
        ``sys.modules`` that still binds the replaced object under the same
        name (e.g. via ``from module import name``) is re-pointed as well.

        When re-applying after more patches were registered, the result is
        rebuilt from the original object rather than from the previously
        installed one. This prevents wrappers from being stacked twice.
        """
        if self.is_applied:
            return
        module, current = Patch.parse_path(self.orig_module_name, self.orig_func_name, self.create_dummy)
        previous = self._last_applied
        if previous is Patch._NOT_APPLIED or current is not previous:
            # Either this is the first application, or something else replaced
            # our object since we installed it. In both cases, treat whatever
            # is there now as the original.
            self.orig_module, self.orig_func = module, current
            previous = Patch._NOT_APPLIED

        final_patch_func = self.orig_func if self.patch_func is None else self.patch_func
        for wrapper in self.wrappers:
            final_patch_func = wrapper(final_patch_func)

        if self.orig_func_name is not None:
            setattr(self.orig_module, self.orig_func_name, final_patch_func)
            self._rebind_module_references(final_patch_func, previous)
        self._last_applied = final_patch_func
        self.is_applied = True

    def _rebind_module_references(self, new_obj, previous):
        """Point module attributes still holding the original (or our previous install) at ``new_obj``."""
        # A None original is only a placeholder (see parse_path). Matching on
        # it would rebind unrelated modules whose attribute happens to be None.
        stale = [obj for obj in (self.orig_func, previous)
                 if obj is not None and obj is not Patch._NOT_APPLIED]
        if not stale:
            return
        missing = object()
        # Snapshot: attribute access may import modules and mutate sys.modules.
        for loaded in list(sys.modules.values()):
            held = getattr(loaded, self.orig_func_name, missing)
            if any(held is obj for obj in stale):
                setattr(loaded, self.orig_func_name, new_obj)

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """Resolve ``module_path`` and return ``(owner, current_attribute)``.

        The prefixes of ``module_path`` are imported one at a time. When a
        prefix cannot be imported:

        * If its parent has an attribute of that name (e.g. a class defined
          in the parent module), the remaining components are resolved with
          ``getattr`` and the resulting object is the owner.
        * Otherwise, with ``create_dummy`` an empty placeholder module is
          registered in ``sys.modules`` and attached to its parent. Without
          ``create_dummy`` the ``ModuleNotFoundError`` propagates.

        Args:
            module_path: dotted path of the owner.
            function_name: attribute to look up on the owner, or None to
                resolve the module only.
            create_dummy: fabricate missing modules/attributes instead of
                raising.

        Returns:
            ``(owner, attribute)``, where attribute is None when
            ``function_name`` is None. If a module owner lacks
            ``function_name``, the attribute is first created with the value
            None. If a non-module owner lacks it and ``create_dummy`` is set,
            a placeholder from ``dummy_function_wrapper`` is returned.

        Raises:
            ModuleNotFoundError: a module component is missing and
                ``create_dummy`` is false.
            RuntimeError: a non-module owner lacks ``function_name`` and
                ``create_dummy`` is false.
            AttributeError: a component after the first non-module one is
                missing.
        """
        from importlib.machinery import ModuleSpec
        modules = module_path.split('.')
        for i in range(1, len(modules) + 1):
            parent = '.'.join(modules[:i - 1])
            path = '.'.join(modules[:i])
            name = modules[i - 1]
            try:
                importlib.import_module(path)
            except ModuleNotFoundError:
                parent_module = importlib.import_module(parent) if parent else None
                if parent_module is None or not hasattr(parent_module, name):
                    if not create_dummy:
                        raise
                    dummy = types.ModuleType(path)
                    dummy.__file__ = 'mindspeed.dummy_module.py'
                    dummy.__spec__ = ModuleSpec(path, None)
                    sys.modules[path] = dummy
                    if parent_module is not None:
                        setattr(parent_module, name, dummy)
                else:
                    owner = getattr(parent_module, name)
                    # The rest of the path (e.g. nested classes) is attribute
                    # access, not import; resolve it fully before the lookup.
                    for attr in modules[i:]:
                        owner = getattr(owner, attr)
                    if hasattr(owner, function_name):
                        return owner, getattr(owner, function_name)
                    if create_dummy:
                        return owner, dummy_function_wrapper(function_name)
                    raise RuntimeError('no exist {} of {}'.format(function_name, owner))

        owner = sys.modules[module_path]
        if function_name is None:
            return owner, None
        if not hasattr(owner, function_name):
            setattr(owner, function_name, None)
        return owner, getattr(owner, function_name)
class MindSpeedPatchesManager:
    """Process-wide registry of ``Patch`` records keyed by dotted target path."""

    # Dotted path passed to register_patch -> its Patch record.
    patches_info = {}

    @staticmethod
    def register_patch(orig_func_name, new_func=None, force_patch=False, create_dummy=False):
        """Register ``new_func`` (patch function or wrapper) for ``orig_func_name``.

        The first registration for a path creates its ``Patch`` (with
        ``create_dummy``). Later registrations forward ``new_func`` and
        ``force_patch`` to that record's ``set_patch_func``; ``create_dummy``
        is not consulted for them.
        """
        registry = MindSpeedPatchesManager.patches_info
        if orig_func_name in registry:
            registry[orig_func_name].set_patch_func(new_func, force_patch)
            return
        registry[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)

    @staticmethod
    def apply_patches():
        """Call ``apply_patch`` on every registered record, in insertion order."""
        registry = MindSpeedPatchesManager.patches_info
        for _target, record in registry.items():
            record.apply_patch()