| from dataclasses import dataclass, field |
| from utils import parse_structure |
| from typing import Any, Dict, Mapping |
| from .base import BaseSystemConfig, BaseSystem |
| from torch import nn, Tensor |
|
|
| import os |
| import torch |
| import numpy as np |
| import models |
|
|
|
|
@dataclass
class SimpleClassificationConfig(BaseSystemConfig):
    """Configuration schema for ``SimpleClassificationSystem``.

    Declares no fields of its own; everything is inherited from
    ``BaseSystemConfig``.
    """
|
|
|
|
class SimpleClassificationSystem(BaseSystem):
    """Classification system wrapping a model that is looked up by name.

    The model class is resolved from the ``models`` module via
    ``cfg.model_type`` and instantiated with ``cfg.model``. The train,
    validation and test steps share one code path and differ only in the
    stage prefix used for logging.

    ``self.criterion``, ``self.metrics_func``, ``self.log_metrics`` and
    ``self.log`` are not defined in this class; presumably they are provided
    by ``BaseSystem`` -- TODO confirm.
    """

    def __init__(self, cfg: Dict, *args: Any, **kwargs: Any) -> None:
        """Parse the configuration and build the underlying model.

        Args:
            cfg: Raw configuration; it is forwarded to ``BaseSystem`` and
                parsed into a ``SimpleClassificationConfig``.
            *args: Extra positional arguments forwarded to ``BaseSystem``.
            **kwargs: Extra keyword arguments forwarded to ``BaseSystem``.
        """
        super().__init__(cfg, *args, **kwargs)
        self.cfg: SimpleClassificationConfig = parse_structure(
            SimpleClassificationConfig, cfg
        )
        # Resolve the model class by name from ``models`` and build it from
        # the model sub-config.
        self.model: nn.Module = getattr(models, self.cfg.model_type)(self.cfg.model)

    def forward(self, x: Tensor) -> Tensor:
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _classification_step(self, batch: Any, stage: str) -> Tensor:
        """Compute, log and return the loss for one batch.

        Args:
            batch: Indexable batch; element 0 is the model input and
                element 1 is the target.
            stage: Prefix for the logged loss (``"<stage>/loss"``); it is
                also passed to ``self.metrics_func`` as its third argument.

        Returns:
            The value returned by ``self.criterion(y_hat, y)``.
        """
        x: Tensor = batch[0]
        # Targets are cast to floating point before being handed to the
        # criterion.
        y: Tensor = batch[1].float()

        # squeeze(-1) drops the last dimension of the model output when that
        # dimension has size 1.
        y_hat: Tensor = self.model(x).squeeze(-1)
        loss = self.criterion(y_hat, y)

        self.log(
            f"{stage}/loss",
            loss,
            on_step=self.cfg.log_on_step,
            on_epoch=self.cfg.log_on_epoch,
            prog_bar=self.cfg.log_prog_bar,
            logger=self.cfg.log_logger,
        )
        metrics_dict = self.metrics_func(y_hat, y, stage)
        self.log_metrics(metrics_dict)

        return loss

    # NOTE(review): ``batch`` is annotated as ``Mapping[str, Tensor]`` in the
    # three step methods below but is indexed positionally (``batch[0]``,
    # ``batch[1]``); it looks like a sequence/tuple -- confirm against the
    # data loader before tightening the annotation.

    def training_step(self, batch: Mapping[str, Tensor], batch_idx: int) -> Tensor:
        """Run one training batch; log ``train/loss`` and metrics.

        Args:
            batch: Batch whose elements 0 and 1 are input and target.
            batch_idx: Index of the batch (unused here).

        Returns:
            The training loss for the batch.
        """
        return self._classification_step(batch, "train")

    def validation_step(self, batch: Mapping[str, Tensor], batch_idx: int) -> Tensor:
        """Run one validation batch; log ``valid/loss`` and metrics.

        Args:
            batch: Batch whose elements 0 and 1 are input and target.
            batch_idx: Index of the batch (unused here).

        Returns:
            The validation loss for the batch.
        """
        return self._classification_step(batch, "valid")

    def test_step(self, batch: Mapping[str, Tensor], batch_idx: int) -> Tensor:
        """Run one test batch; log ``test/loss`` and metrics.

        Args:
            batch: Batch whose elements 0 and 1 are input and target.
            batch_idx: Index of the batch (unused here).

        Returns:
            The test loss for the batch, matching the declared return type
            and the training/validation steps.
        """
        return self._classification_step(batch, "test")