|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os.path as osp |
|
import logging |
|
|
|
from dataclasses import dataclass |
|
import torch |
|
from torchvision import transforms |
|
|
|
from fairseq.dataclass import FairseqDataclass |
|
from fairseq.tasks import register_task |
|
from fairseq.logging import metrics |
|
|
|
try:
    from ..data import ImageDataset
except (ImportError, ValueError):
    # The relative import fails when this module is loaded outside its
    # package. Only import-resolution failures are caught: ImportError on
    # current Pythons, ValueError on Python < 3.9 for a relative import that
    # goes beyond the top-level package. A bare ``except:`` would also hide
    # KeyboardInterrupt/SystemExit and unrelated errors.
    import sys

    # Make the parent directory importable, then retry as a top-level module.
    sys.path.append("..")
    from data import ImageDataset
|
|
|
from .image_pretraining import ( |
|
ImagePretrainingConfig, |
|
ImagePretrainingTask, |
|
IMG_EXTENSIONS, |
|
) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass
class ImageClassificationConfig(ImagePretrainingConfig):
    """Config for the ``image_classification`` task.

    Declares no fields of its own; every option is inherited from
    ``ImagePretrainingConfig``.
    """
|
|
|
|
|
@register_task("image_classification", dataclass=ImageClassificationConfig)
class ImageClassificationTask(ImagePretrainingTask):
    """Image classification task built on top of ``ImagePretrainingTask``.

    Loads class-labelled image datasets with split-specific transforms and
    reports an ``accuracy`` metric derived from the ``correct`` counts found
    in the logging outputs.
    """

    cfg: ImageClassificationConfig

    @classmethod
    def setup_task(cls, cfg: ImageClassificationConfig, **kwargs):
        """Instantiate the task from ``cfg``; extra keyword args are ignored."""
        return cls(cfg)

    @staticmethod
    def _build_train_transform(cfg):
        """Build the augmenting transform pipeline used for the train split."""
        # Imported lazily so that only training requires timm to be installed.
        from timm.data import create_transform

        transform = create_transform(
            input_size=cfg.input_size,
            is_training=True,
            auto_augment="rand-m9-mstd0.5-inc1",
            interpolation="bicubic",
            re_prob=0.25,
            re_mode="pixel",
            re_count=1,
            mean=cfg.normalization_mean,
            std=cfg.normalization_std,
        )
        if cfg.input_size <= 32:
            # For small inputs, swap the first transform of the pipeline for
            # a padded random crop at the target size.
            transform.transforms[0] = transforms.RandomCrop(
                cfg.input_size, padding=4
            )
        return transform

    @staticmethod
    def _build_eval_transform(cfg):
        """Build the deterministic transform pipeline for non-train splits."""
        steps = []
        if cfg.input_size > 32:
            # Below 384 px, resize to input_size * 256 / 224 before the
            # center crop; at 384 px and above, resize straight to input_size.
            crop_pct = 224 / 256 if cfg.input_size < 384 else 1
            size = int(cfg.input_size / crop_pct)
            # interpolation=3 is PIL's integer code for bicubic resampling.
            steps.append(transforms.Resize(size, interpolation=3))
            steps.append(transforms.CenterCrop(cfg.input_size))

        steps.append(transforms.ToTensor())
        steps.append(
            transforms.Normalize(cfg.normalization_mean, cfg.normalization_std)
        )
        return transforms.Compose(steps)

    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
        """Load ``split`` into ``self.datasets[split]``.

        Args:
            split: name of the split. ``"train"`` selects the augmenting
                transform; any other value selects the evaluation transform.
            task_cfg: optional config supplying ``input_size`` and the
                normalization statistics; falls back to ``self.cfg``. The
                data root is always read from ``self.cfg.data``.
            **kwargs: accepted for interface compatibility and ignored.

        Raises:
            AssertionError: if the classes of the newly loaded dataset differ
                from those of an already loaded split.
        """
        data_path = self.cfg.data
        cfg = task_cfg or self.cfg

        # Use a per-split subdirectory when one exists; otherwise the data
        # root itself is used for every split.
        path_with_split = osp.join(data_path, split)
        if osp.exists(path_with_split):
            data_path = path_with_split

        if split == "train":
            transform = self._build_train_transform(cfg)
        else:
            transform = self._build_eval_transform(cfg)
        logger.info(transform)

        self.datasets[split] = ImageDataset(
            root=data_path,
            extensions=IMG_EXTENSIONS,
            load_classes=True,
            transform=transform,
        )

        # Every loaded split must expose the same classes as the new one.
        new_classes = self.datasets[split].classes
        for other_split in self.datasets:
            if other_split != split:
                assert self.datasets[other_split].classes == new_classes, (
                    f"classes of split '{other_split}' differ from "
                    f"those of split '{split}'"
                )

    def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
        """Build the model and copy back its ``pretrained_model_args``.

        If the built model has a ``cfg`` attribute carrying
        ``pretrained_model_args``, that value is written onto ``model_cfg``.
        """
        model = super().build_model(model_cfg, from_checkpoint)

        actualized_cfg = getattr(model, "cfg", None)
        if actualized_cfg is not None and hasattr(
            actualized_cfg, "pretrained_model_args"
        ):
            model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args

        return model

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs and register the ``accuracy`` metric.

        Accuracy is only logged when the first logging output contains a
        ``correct`` entry; entries missing from later outputs count as zero.
        """
        super().reduce_metrics(logging_outputs, criterion)

        # Guard against an empty list before peeking at the first entry.
        if logging_outputs and "correct" in logging_outputs[0]:
            zero = torch.scalar_tensor(0.0)
            correct = sum(log.get("correct", zero) for log in logging_outputs)
            metrics.log_scalar_sum("_correct", correct)

            def _accuracy(meters):
                # Report NaN rather than dividing by zero when no samples
                # have been accumulated.
                total = meters["sample_size"].sum
                if total > 0:
                    return 100 * meters["_correct"].sum / total
                return float("nan")

            metrics.log_derived("accuracy", _accuracy)
|
|