Spaces:
Paused
A newer version of the Gradio SDK is available:
5.5.0
์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ ๊ณต์ ํ๊ธฐ[[sharing-custom-models]]
๐ค Transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ์ฝ๊ฒ ํ์ฅํ ์ ์๋๋ก ์ค๊ณ๋์์ต๋๋ค. ๋ชจ๋ ๋ชจ๋ธ์ ์ถ์ํ ์์ด ์ ์ฅ์์ ์ง์ ๋ ํ์ ํด๋์ ์์ ํ ์ฝ๋ฉ๋์ด ์์ผ๋ฏ๋ก, ์์ฝ๊ฒ ๋ชจ๋ธ๋ง ํ์ผ์ ๋ณต์ฌํ๊ณ ํ์์ ๋ฐ๋ผ ์กฐ์ ํ ์ ์์ต๋๋ค.
์์ ํ ์๋ก์ด ๋ชจ๋ธ์ ๋ง๋๋ ๊ฒฝ์ฐ์๋ ์ฒ์๋ถํฐ ์์ํ๋ ๊ฒ์ด ๋ ์ฌ์ธ ์ ์์ต๋๋ค. ์ด ํํ ๋ฆฌ์ผ์์๋ Transformers ๋ด์์ ์ฌ์ฉํ ์ ์๋๋ก ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ๊ณผ ๊ตฌ์ฑ์ ์์ฑํ๋ ๋ฐฉ๋ฒ๊ณผ ๐ค Transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์๋ ๊ฒฝ์ฐ์๋ ๋๊ตฌ๋ ์ฌ์ฉํ ์ ์๋๋ก (์์กด์ฑ๊ณผ ํจ๊ป) ์ปค๋ฎค๋ํฐ์ ๊ณต์ ํ๋ ๋ฐฉ๋ฒ์ ๋ฐฐ์ธ ์ ์์ต๋๋ค.
timm ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ResNet ํด๋์ค๋ฅผ [PreTrainedModel
]๋ก ๋ํํ ResNet ๋ชจ๋ธ์ ์๋ก ๋ชจ๋ ๊ฒ์ ์ค๋ช
ํฉ๋๋ค.
์ฌ์ฉ์ ์ ์ ๊ตฌ์ฑ ์์ฑํ๊ธฐ[[writing-a-custom-configuration]]
๋ชจ๋ธ์ ๋ค์ด๊ฐ๊ธฐ ์ ์ ๋จผ์ ๊ตฌ์ฑ์ ์์ฑํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
๋ชจ๋ธ์ configuration
์ ๋ชจ๋ธ์ ๋ง๋ค๊ธฐ ์ํด ํ์ํ ๋ชจ๋ ์ค์ํ ๊ฒ๋ค์ ํฌํจํ๊ณ ์๋ ๊ฐ์ฒด์
๋๋ค.
๋ค์ ์น์
์์ ๋ณผ ์ ์๋ฏ์ด, ๋ชจ๋ธ์ config
๋ฅผ ์ฌ์ฉํด์๋ง ์ด๊ธฐํํ ์ ์๊ธฐ ๋๋ฌธ์ ์๋ฒฝํ ๊ตฌ์ฑ์ด ํ์ํฉ๋๋ค.
์๋ ์์์์๋ ResNet ํด๋์ค์ ์ธ์(argument)๋ฅผ ์กฐ์ ํด๋ณด๊ฒ ์ต๋๋ค. ๋ค๋ฅธ ๊ตฌ์ฑ์ ๊ฐ๋ฅํ ResNet ์ค ๋ค๋ฅธ ์ ํ์ ์ ๊ณตํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ ๋ช ๊ฐ์ง ์ ํจ์ฑ์ ํ์ธํ ํ ํด๋น ์ธ์๋ฅผ ์ ์ฅํฉ๋๋ค.
from transformers import PretrainedConfig
from typing import List
class ResnetConfig(PretrainedConfig):
model_type = "resnet"
def __init__(
self,
block_type="bottleneck",
layers: List[int] = [3, 4, 6, 3],
num_classes: int = 1000,
input_channels: int = 3,
cardinality: int = 1,
base_width: int = 64,
stem_width: int = 64,
stem_type: str = "",
avg_down: bool = False,
**kwargs,
):
if block_type not in ["basic", "bottleneck"]:
raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.")
if stem_type not in ["", "deep", "deep-tiered"]:
raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")
self.block_type = block_type
self.layers = layers
self.num_classes = num_classes
self.input_channels = input_channels
self.cardinality = cardinality
self.base_width = base_width
self.stem_width = stem_width
self.stem_type = stem_type
self.avg_down = avg_down
super().__init__(**kwargs)
์ฌ์ฉ์ ์ ์ configuration
์ ์์ฑํ ๋ ๊ธฐ์ตํด์ผ ํ ์ธ ๊ฐ์ง ์ค์ํ ์ฌํญ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
PretrainedConfig
์ ์์ํด์ผ ํฉ๋๋ค.PretrainedConfig
์__init__
์ ๋ชจ๋ kwargs๋ฅผ ํ์ฉํด์ผ ํ๊ณ ,- ์ด๋ฌํ
kwargs
๋ ์์ ํด๋์ค__init__
์ ์ ๋ฌ๋์ด์ผ ํฉ๋๋ค.
์์์ ๐ค Transformers ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ๋ชจ๋ ๊ธฐ๋ฅ์ ๊ฐ์ ธ์ค๋ ๊ฒ์
๋๋ค.
์ด๋ฌํ ์ ์ผ๋ก๋ถํฐ ๋น๋กฏ๋๋ ๋ ๊ฐ์ง ์ ์ฝ ์กฐ๊ฑด์ PretrainedConfig
์ ์ค์ ํ๋ ๊ฒ๋ณด๋ค ๋ ๋ง์ ํ๋๊ฐ ์์ต๋๋ค.
from_pretrained
๋ฉ์๋๋ก ๊ตฌ์ฑ์ ๋ค์ ๋ก๋ํ ๋ ํด๋น ํ๋๋ ๊ตฌ์ฑ์์ ์๋ฝํ ํ ์์ ํด๋์ค๋ก ๋ณด๋ด์ผ ํฉ๋๋ค.
๋ชจ๋ธ์ auto ํด๋์ค์ ๋ฑ๋กํ์ง ์๋ ํ, configuration
์์ model_type
์ ์ ์(์ฌ๊ธฐ์ model_type="resnet"
)ํ๋ ๊ฒ์ ํ์ ์ฌํญ์ด ์๋๋๋ค (๋ง์ง๋ง ์น์
์ฐธ์กฐ).
์ด๋ ๊ฒ ํ๋ฉด ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋ค๋ฅธ ๋ชจ๋ธ ๊ตฌ์ฑ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ๊ตฌ์ฑ์ ์ฝ๊ฒ ๋ง๋ค๊ณ ์ ์ฅํ ์ ์์ต๋๋ค. ๋ค์์ resnet50d ๊ตฌ์ฑ์ ์์ฑํ๊ณ ์ ์ฅํ๋ ๋ฐฉ๋ฒ์ ๋๋ค:
resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
resnet50d_config.save_pretrained("custom-resnet")
์ด๋ ๊ฒ ํ๋ฉด custom-resnet
ํด๋ ์์ config.json
์ด๋ผ๋ ํ์ผ์ด ์ ์ฅ๋ฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ from_pretrained
๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ๊ตฌ์ฑ์ ๋ค์ ๋ก๋ํ ์ ์์ต๋๋ค.
resnet50d_config = ResnetConfig.from_pretrained("custom-resnet")
๊ตฌ์ฑ์ Hub์ ์ง์ ์
๋ก๋ํ๊ธฐ ์ํด [PretrainedConfig
] ํด๋์ค์ [~PretrainedConfig.push_to_hub
]์ ๊ฐ์ ๋ค๋ฅธ ๋ฉ์๋๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ ์์ฑํ๊ธฐ[[writing-a-custom-model]]
์ด์ ResNet ๊ตฌ์ฑ์ด ์์ผ๋ฏ๋ก ๋ชจ๋ธ์ ์์ฑํ ์ ์์ต๋๋ค.
์ค์ ๋ก๋ ๋ ๊ฐ๋ฅผ ์์ฑํ ๊ฒ์
๋๋ค. ํ๋๋ ์ด๋ฏธ์ง ๋ฐฐ์น์์ hidden features๋ฅผ ์ถ์ถํ๋ ๊ฒ([BertModel
]๊ณผ ๊ฐ์ด), ๋ค๋ฅธ ํ๋๋ ์ด๋ฏธ์ง ๋ถ๋ฅ์ ์ ํฉํ ๊ฒ์
๋๋ค([BertForSequenceClassification
]๊ณผ ๊ฐ์ด).
์ด์ ์ ์ธ๊ธํ๋ฏ์ด ์ด ์์ ์์๋ ๋จ์ํ๊ฒ ํ๊ธฐ ์ํด ๋ชจ๋ธ์ ๋์จํ ๋ํผ(loose wrapper)๋ง ์์ฑํ ๊ฒ์
๋๋ค.
์ด ํด๋์ค๋ฅผ ์์ฑํ๊ธฐ ์ ์ ๋ธ๋ก ์ ํ๊ณผ ์ค์ ๋ธ๋ก ํด๋์ค ๊ฐ์ ๋งคํ ์์
๋ง ํ๋ฉด ๋ฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ ResNet
ํด๋์ค๋ก ์ ๋ฌ๋์ด configuration
์ ํตํด ๋ชจ๋ธ์ด ์ ์ธ๋ฉ๋๋ค:
from transformers import PreTrainedModel
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from .configuration_resnet import ResnetConfig
BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
class ResnetModel(PreTrainedModel):
config_class = ResnetConfig
def __init__(self, config):
super().__init__(config)
block_layer = BLOCK_MAPPING[config.block_type]
self.model = ResNet(
block_layer,
config.layers,
num_classes=config.num_classes,
in_chans=config.input_channels,
cardinality=config.cardinality,
base_width=config.base_width,
stem_width=config.stem_width,
stem_type=config.stem_type,
avg_down=config.avg_down,
)
def forward(self, tensor):
return self.model.forward_features(tensor)
์ด๋ฏธ์ง ๋ถ๋ฅ ๋ชจ๋ธ์ ๋ง๋ค๊ธฐ ์ํด์๋ forward ๋ฉ์๋๋ง ๋ณ๊ฒฝํ๋ฉด ๋ฉ๋๋ค:
import torch
class ResnetModelForImageClassification(PreTrainedModel):
config_class = ResnetConfig
def __init__(self, config):
super().__init__(config)
block_layer = BLOCK_MAPPING[config.block_type]
self.model = ResNet(
block_layer,
config.layers,
num_classes=config.num_classes,
in_chans=config.input_channels,
cardinality=config.cardinality,
base_width=config.base_width,
stem_width=config.stem_width,
stem_type=config.stem_type,
avg_down=config.avg_down,
)
def forward(self, tensor, labels=None):
logits = self.model(tensor)
if labels is not None:
loss = torch.nn.cross_entropy(logits, labels)
return {"loss": loss, "logits": logits}
return {"logits": logits}
๋ ๊ฒฝ์ฐ ๋ชจ๋ PreTrainedModel
๋ฅผ ์์๋ฐ๊ณ , config
๋ฅผ ํตํด ์์ ํด๋์ค ์ด๊ธฐํ๋ฅผ ํธ์ถํ๋ค๋ ์ ์ ๊ธฐ์ตํ์ธ์ (์ผ๋ฐ์ ์ธ torch.nn.Module
์ ์์ฑํ ๋์ ๋น์ทํจ).
๋ชจ๋ธ์ auto ํด๋์ค์ ๋ฑ๋กํ๊ณ ์ถ์ ๊ฒฝ์ฐ์๋ config_class
๋ฅผ ์ค์ ํ๋ ๋ถ๋ถ์ด ํ์์
๋๋ค (๋ง์ง๋ง ์น์
์ฐธ์กฐ).
๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์กด์ฌํ๋ ๋ชจ๋ธ๊ณผ ๊ต์ฅํ ์ ์ฌํ๋ค๋ฉด, ๋ชจ๋ธ์ ์์ฑํ ๋ ๊ตฌ์ฑ์ ์ฐธ์กฐํด ์ฌ์ฌ์ฉํ ์ ์์ต๋๋ค.
์ํ๋ ๊ฒ์ ๋ชจ๋ธ์ด ๋ฐํํ๋๋ก ํ ์ ์์ง๋ง, ResnetModelForImageClassification
์์ ํ๋ ๊ฒ ์ฒ๋ผ
๋ ์ด๋ธ์ ํต๊ณผ์์ผฐ์ ๋ ์์ค๊ณผ ํจ๊ป ์ฌ์ ํํ๋ก ๋ฐํํ๋ ๊ฒ์ด [Trainer
] ํด๋์ค ๋ด์์ ์ง์ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ์ ์ ์ฉํฉ๋๋ค.
์์ ๋ง์ ํ์ต ๋ฃจํ ๋๋ ๋ค๋ฅธ ํ์ต ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ ๊ณํ์ด๋ผ๋ฉด ๋ค๋ฅธ ์ถ๋ ฅ ํ์์ ์ฌ์ฉํด๋ ์ข์ต๋๋ค.
์ด์ ๋ชจ๋ธ ํด๋์ค๊ฐ ์์ผ๋ฏ๋ก ํ๋ ์์ฑํด ๋ณด๊ฒ ์ต๋๋ค:
resnet50d = ResnetModelForImageClassification(resnet50d_config)
๋ค์ ๋งํ์ง๋ง, [~PreTrainedModel.save_pretrained
]๋๋ [~PreTrainedModel.push_to_hub
]์ฒ๋ผ [PreTrainedModel
]์ ์ํ๋ ๋ชจ๋ ๋ฉ์๋๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
๋ค์ ์น์
์์ ๋ ๋ฒ์งธ ๋ฉ์๋๋ฅผ ์ฌ์ฉํด ๋ชจ๋ธ ์ฝ๋์ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ์
๋ก๋ํ๋ ๋ฐฉ๋ฒ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
๋จผ์ , ๋ชจ๋ธ ๋ด๋ถ์ ์ฌ์ ํ๋ จ๋ ๊ฐ์ค์น๋ฅผ ๋ก๋ํด ๋ณด๊ฒ ์ต๋๋ค.
์ด ์์ ๋ฅผ ํ์ฉํ ๋๋, ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ์ ์์ ๋ง์ ๋ฐ์ดํฐ๋ก ํ์ต์ํฌ ๊ฒ์ ๋๋ค. ์ด ํํ ๋ฆฌ์ผ์์๋ ๋น ๋ฅด๊ฒ ์งํํ๊ธฐ ์ํด ์ฌ์ ํ๋ จ๋ resnet50d๋ฅผ ์ฌ์ฉํ๊ฒ ์ต๋๋ค. ์๋ ๋ชจ๋ธ์ resnet50d์ ๋ํผ์ด๊ธฐ ๋๋ฌธ์, ๊ฐ์ค์น๋ฅผ ์ฝ๊ฒ ๋ก๋ํ ์ ์์ต๋๋ค.
import timm
pretrained_model = timm.create_model("resnet50d", pretrained=True)
resnet50d.model.load_state_dict(pretrained_model.state_dict())
์ด์ [~PreTrainedModel.save_pretrained
] ๋๋ [~PreTrainedModel.push_to_hub
]๋ฅผ ์ฌ์ฉํ ๋ ๋ชจ๋ธ ์ฝ๋๊ฐ ์ ์ฅ๋๋์ง ํ์ธํด๋ด
์๋ค.
Hub๋ก ์ฝ๋ ์ ๋ก๋ํ๊ธฐ[[sending-the-code-to-the-hub]]
์ด API๋ ์คํ์ ์ด๋ฉฐ ๋ค์ ๋ฆด๋ฆฌ์ค์์ ์ฝ๊ฐ์ ๋ณ๊ฒฝ ์ฌํญ์ด ์์ ์ ์์ต๋๋ค.
๋จผ์ ๋ชจ๋ธ์ด .py
ํ์ผ์ ์์ ํ ์ ์๋์ด ์๋์ง ํ์ธํ์ธ์.
๋ชจ๋ ํ์ผ์ด ๋์ผํ ์์
๊ฒฝ๋ก์ ์๊ธฐ ๋๋ฌธ์ ์๋๊ฒฝ๋ก ์ํฌํธ(relative import)์ ์์กดํ ์ ์์ต๋๋ค (transformers์์๋ ์ด ๊ธฐ๋ฅ์ ๋ํ ํ์ ๋ชจ๋์ ์ง์ํ์ง ์์ต๋๋ค).
์ด ์์์์๋ ์์
๊ฒฝ๋ก ์์ resnet_model
์์ modeling_resnet.py
ํ์ผ๊ณผ configuration_resnet.py
ํ์ผ์ ์ ์ํฉ๋๋ค.
๊ตฌ์ฑ ํ์ผ์๋ ResnetConfig
์ ๋ํ ์ฝ๋๊ฐ ์๊ณ ๋ชจ๋ธ๋ง ํ์ผ์๋ ResnetModel
๋ฐ ResnetModelForImageClassification
์ ๋ํ ์ฝ๋๊ฐ ์์ต๋๋ค.
.
โโโ resnet_model
โโโ __init__.py
โโโ configuration_resnet.py
โโโ modeling_resnet.py
Python์ด resnet_model
์ ๋ชจ๋๋ก ์ฌ์ฉํ ์ ์๋๋ก ๊ฐ์งํ๋ ๋ชฉ์ ์ด๊ธฐ ๋๋ฌธ์ __init__.py
๋ ๋น์ด ์์ ์ ์์ต๋๋ค.
๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ๋ชจ๋ธ๋ง ํ์ผ์ ๋ณต์ฌํ๋ ๊ฒฝ์ฐ,
๋ชจ๋ ํ์ผ ์๋จ์ ์๋ ์๋ ๊ฒฝ๋ก ์ํฌํธ(relative import) ๋ถ๋ถ์ transformers
ํจํค์ง์์ ์ํฌํธ ํ๋๋ก ๋ณ๊ฒฝํด์ผ ํฉ๋๋ค.
๊ธฐ์กด ๊ตฌ์ฑ์ด๋ ๋ชจ๋ธ์ ์ฌ์ฌ์ฉ(๋๋ ์๋ธ ํด๋์คํ)ํ ์ ์์ต๋๋ค.
์ปค๋ฎค๋ํฐ์ ๋ชจ๋ธ์ ๊ณต์ ํ๊ธฐ ์ํด์๋ ๋ค์ ๋จ๊ณ๋ฅผ ๋ฐ๋ผ์ผ ํฉ๋๋ค: ๋จผ์ , ์๋ก ๋ง๋ ํ์ผ์ ResNet ๋ชจ๋ธ๊ณผ ๊ตฌ์ฑ์ ์ํฌํธํฉ๋๋ค:
from resnet_model.configuration_resnet import ResnetConfig
from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
๋ค์์ผ๋ก save_pretrained
๋ฉ์๋๋ฅผ ์ฌ์ฉํด ํด๋น ๊ฐ์ฒด์ ์ฝ๋ ํ์ผ์ ๋ณต์ฌํ๊ณ ,
๋ณต์ฌํ ํ์ผ์ Auto ํด๋์ค๋ก ๋ฑ๋กํ๊ณ (๋ชจ๋ธ์ธ ๊ฒฝ์ฐ) ์คํํฉ๋๋ค:
ResnetConfig.register_for_auto_class()
ResnetModel.register_for_auto_class("AutoModel")
ResnetModelForImageClassification.register_for_auto_class("AutoModelForImageClassification")
configuration
์ ๋ํ auto ํด๋์ค๋ฅผ ์ง์ ํ ํ์๋ ์์ง๋ง(configuration
๊ด๋ จ auto ํด๋์ค๋ AutoConfig ํด๋์ค ํ๋๋ง ์์), ๋ชจ๋ธ์ ๊ฒฝ์ฐ์๋ ์ง์ ํด์ผ ํฉ๋๋ค.
์ฌ์ฉ์ ์ง์ ๋ชจ๋ธ์ ๋ค์ํ ์์
์ ์ ํฉํ ์ ์์ผ๋ฏ๋ก, ๋ชจ๋ธ์ ๋ง๋ auto ํด๋์ค๋ฅผ ์ง์ ํด์ผ ํฉ๋๋ค.
๋ค์์ผ๋ก, ์ด์ ์ ์์ ํ๋ ๊ฒ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ๊ตฌ์ฑ๊ณผ ๋ชจ๋ธ์ ์์ฑํฉ๋๋ค:
resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
resnet50d = ResnetModelForImageClassification(resnet50d_config)
pretrained_model = timm.create_model("resnet50d", pretrained=True)
resnet50d.model.load_state_dict(pretrained_model.state_dict())
์ด์ ๋ชจ๋ธ์ Hub๋ก ์ ๋ก๋ํ๊ธฐ ์ํด ๋ก๊ทธ์ธ ์ํ์ธ์ง ํ์ธํ์ธ์. ํฐ๋ฏธ๋์์ ๋ค์ ์ฝ๋๋ฅผ ์คํํด ํ์ธํ ์ ์์ต๋๋ค:
huggingface-cli login
์ฃผํผํฐ ๋ ธํธ๋ถ์ ๊ฒฝ์ฐ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
from huggingface_hub import notebook_login
notebook_login()
๊ทธ๋ฐ ๋ค์ ์ด๋ ๊ฒ ์์ ์ ๋ค์์คํ์ด์ค(๋๋ ์์ ์ด ์ํ ์กฐ์ง)์ ์ ๋ก๋ํ ์ ์์ต๋๋ค:
resnet50d.push_to_hub("custom-resnet50d")
On top of the modeling weights and the configuration in json format, this also copied the modeling and
configuration .py
files in the folder custom-resnet50d
and uploaded the result to the Hub. You can check the result
in this model repo.
json ํ์์ ๋ชจ๋ธ๋ง ๊ฐ์ค์น์ ๊ตฌ์ฑ ์ธ์๋ custom-resnet50d
ํด๋ ์์ ๋ชจ๋ธ๋ง๊ณผ ๊ตฌ์ฑ .py
ํ์ผ์ ๋ณต์ฌํํด Hub์ ์
๋ก๋ํฉ๋๋ค.
๋ชจ๋ธ ์ ์ฅ์์์ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ ์ ์์ต๋๋ค.
sharing tutorial ๋ฌธ์์ push_to_hub
๋ฉ์๋์์ ์์ธํ ๋ด์ฉ์ ํ์ธํ ์ ์์ต๋๋ค.
์ฌ์ฉ์ ์ ์ ์ฝ๋๋ก ๋ชจ๋ธ ์ฌ์ฉํ๊ธฐ[[using-a-model-with-custom-code]]
auto ํด๋์ค์ from_pretrained
๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ์ฉ์ ์ง์ ์ฝ๋ ํ์ผ๊ณผ ํจ๊ป ๋ชจ๋ ๊ตฌ์ฑ, ๋ชจ๋ธ, ํ ํฌ๋์ด์ ๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
Hub์ ์
๋ก๋๋ ๋ชจ๋ ํ์ผ ๋ฐ ์ฝ๋๋ ๋ฉ์จ์ด๊ฐ ์๋์ง ๊ฒ์ฌ๋์ง๋ง (์์ธํ ๋ด์ฉ์ Hub ๋ณด์ ์ค๋ช
์ฐธ์กฐ),
์์ ์ ์ปดํจํฐ์์ ๋ชจ๋ธ ์ฝ๋์ ์์ฑ์๊ฐ ์
์ฑ ์ฝ๋๋ฅผ ์คํํ์ง ์๋์ง ํ์ธํด์ผ ํฉ๋๋ค.
์ฌ์ฉ์ ์ ์ ์ฝ๋๋ก ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ค๋ฉด trust_remote_code=True
๋ก ์ค์ ํ์ธ์:
from transformers import AutoModelForImageClassification
model = AutoModelForImageClassification.from_pretrained("sgugger/custom-resnet50d", trust_remote_code=True)
๋ชจ๋ธ ์์ฑ์๊ฐ ์
์์ ์ผ๋ก ์ฝ๋๋ฅผ ์
๋ฐ์ดํธํ์ง ์์๋ค๋ ์ ์ ํ์ธํ๊ธฐ ์ํด, ์ปค๋ฐ ํด์(commit hash)๋ฅผ revision
์ผ๋ก ์ ๋ฌํ๋ ๊ฒ๋ ๊ฐ๋ ฅํ ๊ถ์ฅ๋ฉ๋๋ค (๋ชจ๋ธ ์์ฑ์๋ฅผ ์์ ํ ์ ๋ขฐํ์ง ์๋ ๊ฒฝ์ฐ).
commit_hash = "ed94a7c6247d8aedce4647f00f20de6875b5b292"
model = AutoModelForImageClassification.from_pretrained(
"sgugger/custom-resnet50d", trust_remote_code=True, revision=commit_hash
)
Hub์์ ๋ชจ๋ธ ์ ์ฅ์์ ์ปค๋ฐ ๊ธฐ๋ก์ ์ฐพ์๋ณผ ๋, ๋ชจ๋ ์ปค๋ฐ์ ์ปค๋ฐ ํด์๋ฅผ ์ฝ๊ฒ ๋ณต์ฌํ ์ ์๋ ๋ฒํผ์ด ์์ต๋๋ค.
์ฌ์ฉ์ ์ ์ ์ฝ๋๋ก ๋ง๋ ๋ชจ๋ธ์ auto ํด๋์ค๋ก ๋ฑ๋กํ๊ธฐ[[registering-a-model-with-custom-code-to-the-auto-classes]]
๐ค Transformers๋ฅผ ์์ํ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์์ฑํ๋ ๊ฒฝ์ฐ ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ์ auto ํด๋์ค์ ์ถ๊ฐํ ์ ์์ต๋๋ค. ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ์ ์ฌ์ฉํ๊ธฐ ์ํด ํด๋น ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ํฌํธํด์ผ ํ๊ธฐ ๋๋ฌธ์, ์ด๋ Hub๋ก ์ฝ๋๋ฅผ ์ ๋ก๋ํ๋ ๊ฒ๊ณผ ๋ค๋ฆ ๋๋ค (Hub์์ ์๋์ ์ผ๋ก ๋ชจ๋ธ ์ฝ๋๋ฅผ ๋ค์ด๋ก๋ ํ๋ ๊ฒ๊ณผ ๋ฐ๋).
๊ตฌ์ฑ์ ๊ธฐ์กด ๋ชจ๋ธ ์ ํ๊ณผ ๋ค๋ฅธ model_type
์์ฑ์ด ์๊ณ ๋ชจ๋ธ ํด๋์ค์ ์ฌ๋ฐ๋ฅธ config_class
์์ฑ์ด ์๋ ํ,
๋ค์๊ณผ ๊ฐ์ด auto ํด๋์ค์ ์ถ๊ฐํ ์ ์์ต๋๋ค:
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
AutoConfig.register("resnet", ResnetConfig)
AutoModel.register(ResnetConfig, ResnetModel)
AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification)
์ฌ์ฉ์ ์ ์ ๊ตฌ์ฑ์ [AutoConfig
]์ ๋ฑ๋กํ ๋ ์ฌ์ฉ๋๋ ์ฒซ ๋ฒ์งธ ์ธ์๋ ์ฌ์ฉ์ ์ ์ ๊ตฌ์ฑ์ model_type
๊ณผ ์ผ์นํด์ผ ํฉ๋๋ค.
๋ํ, ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ์ auto ํด๋์ค์ ๋ฑ๋กํ ๋ ์ฌ์ฉ๋๋ ์ฒซ ๋ฒ์งธ ์ธ์๋ ํด๋น ๋ชจ๋ธ์ config_class
์ ์ผ์นํด์ผ ํฉ๋๋ค.