# DeCRED-small / auto_wrappers.py
# Uploaded by Lakoc — commit c0f8a54 (verified): "Upload JointCTCAttentionEncoderDecoder".
# Original file size: 6.04 kB.
import copy
import functools
import os

from transformers import AutoConfig, AutoModelForCTC, PretrainedConfig
from transformers.dynamic_module_utils import (
    get_class_from_dynamic_module,
    resolve_trust_remote_code,
)
from transformers.models.auto.auto_factory import _get_model_class

from .extractors import Conv2dFeatureExtractor
class FeatureExtractionInitModifier(type):
    """Metaclass that patches a model's ``__init__`` to optionally install a 2-D feature extractor.

    Typically invoked directly on an existing model class, e.g.
    ``FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})``,
    which yields a subclass whose ``__init__`` first runs the original initializer
    and then, when ``self.config.expect_2d_input`` is truthy, replaces
    ``getattr(self, self.base_model_prefix).feature_extractor`` with
    ``Conv2dFeatureExtractor(self.config)``.
    """

    def __new__(cls, name, bases, dct):
        """Create the class and wrap its (possibly inherited) ``__init__``.

        Args:
            name: Name of the new class.
            bases: Base classes; ``(model_class,)`` when the metaclass is called directly.
            dct: Class namespace; empty when the metaclass is called directly.

        Returns:
            The newly created class, whose ``__init__`` is the patched initializer.
        """
        # Called directly with an empty namespace, ``type.__new__`` would record
        # *this* module as ``__module__``. Inherit the wrapped class's module
        # instead, so anything that locates a class's source via ``__module__``
        # still points at the real model definition.
        if "__module__" not in dct and bases:
            dct = {**dct, "__module__": bases[0].__module__}
        new_cls = super().__new__(cls, name, bases, dct)
        original_init = new_cls.__init__

        # wraps() keeps the original initializer's name, docstring and (via
        # ``__wrapped__``) its signature visible to introspection.
        @functools.wraps(original_init)
        def new_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # Configs that do not define ``expect_2d_input`` keep the original
            # feature extractor instead of failing with AttributeError.
            if getattr(self.config, "expect_2d_input", False):
                backbone = getattr(self, self.base_model_prefix)
                backbone.feature_extractor = Conv2dFeatureExtractor(self.config)

        new_cls.__init__ = new_init
        return new_cls
class CustomAutoModelForCTC(AutoModelForCTC):
    """``AutoModelForCTC`` variant whose models are built through ``FeatureExtractionInitModifier``.

    The model class is resolved either from ``config.auto_map`` (when remote code
    is trusted) or from ``cls._model_mapping``. It is wrapped with
    ``FeatureExtractionInitModifier`` only at instantiation time. The feature
    extractor can therefore be swapped for ``Conv2dFeatureExtractor`` when the
    config sets ``expect_2d_input``, while the class registered for auto-loading
    remains the unpatched one.
    """

    @staticmethod
    def _wrap_with_2d_support(model_class):
        """Return a subclass of ``model_class`` whose ``__init__`` is patched by FeatureExtractionInitModifier."""
        return FeatureExtractionInitModifier(model_class.__name__, (model_class,), {})

    @classmethod
    def _unsupported_config_error(cls, config):
        """Build the ValueError raised when no model class can be resolved for ``config``."""
        return ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a pretrained CTC model, patched to honour ``config.expect_2d_input``.

        Args:
            pretrained_model_name_or_path: Model id or local directory to load from.
            *model_args: Extra positional arguments forwarded to the model's ``from_pretrained``.
            **kwargs: Optional ``config`` and ``trust_remote_code``. Hub options
                (``cache_dir``, ``revision``, ``token``, ...) are forwarded to config
                loading, remote-code resolution and the model's ``from_pretrained``.
                Everything else goes to config loading, and whatever it does not
                consume goes to the model.

        Returns:
            The instantiated (patched) model.

        Raises:
            ValueError: If neither trusted remote code nor ``cls._model_mapping``
                provides a model class for the config.
        """
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True
        hub_kwargs_names = (
            "cache_dir",
            "code_revision",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "subfolder",
            # "token" must be kept here: if left in kwargs, config loading
            # consumes it and downloads of code/weights run unauthenticated.
            "token",
            "use_auth_token",
        )
        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}

        if not isinstance(config, PretrainedConfig):
            # torch_dtype="auto" is meaningless on a config object, so keep it out
            # of config loading and hand it back to the model afterwards.
            # (A flag suffices; deep-copying all kwargs just for this is costly.)
            torch_dtype_is_auto = kwargs.get("torch_dtype", None) == "auto"
            if torch_dtype_is_auto:
                kwargs.pop("torch_dtype")
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                trust_remote_code=trust_remote_code,
                **hub_kwargs,
                **kwargs,
            )
            if torch_dtype_is_auto:
                kwargs["torch_dtype"] = "auto"

        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
            )
            # code_revision only concerns fetching the modeling code.
            hub_kwargs.pop("code_revision", None)
            # Register the *unpatched* class, as from_config does. cls.register
            # writes into cls._model_mapping, which this class inherits from
            # AutoModelForCTC. Registering the patched subclass would leak it to
            # other users of that mapping and get it wrapped again on later
            # local-mapping lookups.
            if os.path.isdir(pretrained_model_name_or_path):
                model_class.register_for_auto_class(cls.__name__)
            else:
                cls.register(config.__class__, model_class, exist_ok=True)
        elif has_local_code:
            model_class = _get_model_class(config, cls._model_mapping)
        else:
            raise cls._unsupported_config_error(config)

        return cls._wrap_with_2d_support(model_class).from_pretrained(
            pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
        )

    @classmethod
    def from_config(cls, config, **kwargs):
        """Instantiate a fresh (patched) CTC model from ``config``.

        Args:
            config: Model configuration. Its ``auto_map`` / ``_name_or_path`` decide
                whether remote code is used.
            **kwargs: Optional ``trust_remote_code``. The remaining kwargs are passed
                to ``get_class_from_dynamic_module`` (remote branch) and, minus
                ``code_revision``, to ``_from_config``.

        Returns:
            The instantiated (patched) model.

        Raises:
            ValueError: If neither trusted remote code nor ``cls._model_mapping``
                provides a model class for the config.
        """
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, config._name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            # A "repo--module.Class" reference points at code in another repo.
            if "--" in class_ref:
                repo_id, class_ref = class_ref.split("--")
            else:
                repo_id = config.name_or_path
            model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            # Register the unpatched class; wrapping happens only at instantiation.
            if os.path.isdir(config._name_or_path):
                model_class.register_for_auto_class(cls.__name__)
            else:
                cls.register(config.__class__, model_class, exist_ok=True)
            kwargs.pop("code_revision", None)
        elif has_local_code:
            model_class = _get_model_class(config, cls._model_mapping)
        else:
            raise cls._unsupported_config_error(config)

        return cls._wrap_with_2d_support(model_class)._from_config(config, **kwargs)