|
|
from typing import Dict
|
|
|
|
|
|
from transformers.configuration_utils import PretrainedConfig
|
|
|
from optimum.exporters.onnx.model_configs import ViTOnnxConfig
|
|
|
|
|
|
# Variant identifiers accepted by ``EfficientNetConfig``: efficientnet_b0 through
# efficientnet_b8, followed by efficientnet_l2 (order matters only for display in
# the validation error message).
MODEL_NAMES = [f'efficientnet_b{variant}' for variant in range(9)] + ['efficientnet_l2']
|
|
|
|
|
|
|
|
|
class EfficientNetConfig(PretrainedConfig):
    """``PretrainedConfig`` subclass describing one EfficientNet variant.

    This class only records settings; whatever builds a model from them is not
    part of this module.

    Args:
        model_name: Variant identifier; must be one of ``MODEL_NAMES``.
        pretrained: Stored as ``self.pretrained``. Presumably tells the consumer
            whether to load pretrained weights -- confirm against the model code.
        num_classes: Stored as ``self.num_classes``.
        global_pool: Stored as ``self.global_pool``; not validated here.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not an entry of ``MODEL_NAMES``.
    """

    model_type = 'efficientnet'

    def __init__(
        self,
        model_name: str = 'efficientnet_b0',
        pretrained: bool = False,
        num_classes: int = 1000,
        global_pool: str = 'avg',
        **kwargs,
    ):
        # Reject unknown variants up front, before any attribute is written.
        if model_name not in MODEL_NAMES:
            raise ValueError(f'`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}')

        # Record the EfficientNet-specific fields (left-to-right assignment keeps
        # the same attribute-setting order: model_name, pretrained, num_classes,
        # global_pool), then let the base class consume the remaining kwargs.
        self.model_name, self.pretrained = model_name, pretrained
        self.num_classes, self.global_pool = num_classes, global_pool

        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
|
class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for EfficientNet, derived from ``ViTOnnxConfig``.

    Everything except ``outputs`` is inherited from ``ViTOnnxConfig`` unchanged.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the parent's output spec, adjusted for image classification.

        When ``self.task`` is ``"image-classification"``, the ``logits`` entry is
        replaced with ``{0: "batch_size", 1: "num_classes"}``. For every other task,
        the parent's mapping is returned as-is.
        """
        spec = super().outputs
        if self.task != "image-classification":
            return spec

        # NOTE(review): this writes into the mapping returned by the parent
        # property; presumably the parent hands back a fresh copy each call -- confirm.
        spec["logits"] = {0: "batch_size", 1: "num_classes"}
        return spec
|
|
|
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ['EfficientNetConfig', 'EfficientNetOnnxConfig']