# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
import inspect
from types import FunctionType, MethodType
from typing import List, Union

from peft import PeftModel
from torch.nn import Module

from swift.utils import get_logger

logger = get_logger()


def _inspection_target(model: Module) -> Module:
    """Return the module whose ``forward`` signature should be inspected.

    A ``PeftModel`` is inspected through its ``model`` attribute; any other
    module is returned unchanged.
    """
    return model.model if isinstance(model, PeftModel) else model


def can_return_loss(model: Module) -> bool:
    """Check if a given model can return loss.

    Args:
        model: The module to inspect. For a ``PeftModel`` the ``forward`` of
            ``model.model`` is inspected instead of the wrapper's own.

    Returns:
        True only when the inspected ``forward`` declares a ``return_loss``
        parameter whose default value is exactly ``True``.
    """
    forward_params = inspect.signature(_inspection_target(model).forward).parameters
    return_loss = forward_params.get('return_loss')
    return return_loss is not None and return_loss.default is True


def find_labels(model: Module) -> List[str]:
    """Find the labels used by a given model.

    Args:
        model: The module to inspect. For a ``PeftModel`` the ``forward`` of
            ``model.model`` is inspected instead of the wrapper's own.

    Returns:
        The names of the ``forward`` parameters that contain ``'label'``, in
        signature order. For question-answering models ``start_positions``
        and ``end_positions`` are included as well.
    """
    target = _inspection_target(model)
    forward_params = inspect.signature(target.forward).parameters
    # Look at both the outer object and the inspected module: a generic
    # ``PeftModel`` wrapper does not carry 'QuestionAnswering' in its own class
    # name even when the module it wraps is a question-answering model.
    class_names = (model.__class__.__name__, target.__class__.__name__)
    if any('QuestionAnswering' in name for name in class_names):
        return [p for p in forward_params if 'label' in p or p in ('start_positions', 'end_positions')]
    return [p for p in forward_params if 'label' in p]


def get_function(method_or_function: Union[MethodType, FunctionType]) -> FunctionType:
    """Return the plain function behind a bound method.

    Args:
        method_or_function: A bound method or a plain function.

    Returns:
        ``method_or_function.__func__`` when given a bound method, otherwise
        the argument itself.
    """
    if isinstance(method_or_function, MethodType):
        return method_or_function.__func__
    return method_or_function


def is_instance_of_ms_model(model: Module) -> bool:
    """Check whether ``model`` derives from a modelscope ``Model`` class.

    The check compares class and module names along the MRO instead of using
    ``isinstance``, so that modelscope does not have to be imported here
    (importing it causes a circular dependency problem).

    Returns:
        True if any class in the MRO of ``model``'s class is named ``Model``
        and lives in a module whose name starts with ``modelscope``.
    """
    return any(
        m_cls.__name__ == 'Model' and m_cls.__module__.startswith('modelscope')
        for m_cls in model.__class__.__mro__)