test / modules /patches.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
raw
history blame contribute delete
No virus
2.43 kB
from collections import defaultdict
from typing import Optional
from modules.errors import log
def patch(key, obj, field, replacement, add_if_not_exists:bool = False):
"""Replaces a function in a module or a class.
Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
Arguments:
key: identifying information for who is doing the replacement. You can use __name__.
obj: the module or the class
field: name of the function as a string
replacement: the new function
Returns:
the original function
"""
patch_key = (obj, field)
if patch_key in originals[key]:
log.error(f"Patch already applied: field={field}")
if not hasattr(obj, field) and not add_if_not_exists:
log.error(f"Patch no attribute: type={type(obj)} name='{type.__name__}' fiel'{field}'")
return None
original_func = getattr(obj, field, None)
originals[key][patch_key] = original_func
setattr(obj, field, replacement)
return original_func
def undo(key, obj, field):
"""Undoes the peplacement by the patch().
If the function is not replaced, raises an exception.
Arguments:
key: identifying information for who is doing the replacement. You can use __name__.
obj: the module or the class
field: name of the function as a string
Returns:
Always None
"""
patch_key = (obj, field)
if patch_key not in originals[key]:
log.error(f"Patch no patch to undo: field={field}")
return
original_func = originals[key].pop(patch_key)
if original_func is None:
delattr(obj, field)
setattr(obj, field, original_func)
return None
def original(key, obj, field):
"""Returns the original function for the patch created by the patch() function"""
patch_key = (obj, field)
return originals[key].get(patch_key, None)
def patch_method(cls, key:Optional[str]=None):
def decorator(func):
patch(func.__module__ if key is None else key, cls, func.__name__, func)
return decorator
def add_method(cls, key:Optional[str]=None):
def decorator(func):
patch(func.__module__ if key is None else key, cls, func.__name__, func, True)
return decorator
originals = defaultdict(dict)