from transformers import PreTrainedModel | |
from uniformer_finetune_config import UniformerXXSFinetuneConfig | |
from uniformer_xs import UniformerXXSFinetune | |
class UniformerXXSFinetuneModel(PreTrainedModel):
    """`PreTrainedModel` wrapper around `UniformerXXSFinetune`.

    The wrapped network is built from the `out_class` field of the supplied
    configuration and stored as `self.model`.
    """

    # Configuration class associated with this model type.
    config_class = UniformerXXSFinetuneConfig

    def __init__(self, config):
        """Build the wrapped network from `config`.

        Args:
            config: Configuration object; only its `out_class` attribute is
                read here and forwarded to `UniformerXXSFinetune`.
        """
        super().__init__(config)
        self.model = UniformerXXSFinetune(out_class=config.out_class)

    def forward(self, tensor):
        """Run the wrapped network on `tensor` and return its output unchanged.

        Args:
            tensor: Input passed straight through to the wrapped network.
                NOTE(review): expected shape/dtype are defined by
                `UniformerXXSFinetune` and are not visible here -- confirm.

        Returns:
            Whatever the wrapped network returns for `tensor`.
        """
        # Call the submodule itself rather than `.forward` so that any
        # forward / forward-pre hooks registered on it are executed.
        return self.model(tensor)