from typing import Dict, List, Optional, Tuple | |
import timm | |
import torch | |
from pytorch_lightning import LightningModule | |
class TemplateClassifier(LightningModule):
    """Classifier template whose network is a ``timm`` backbone.

    The backbone is built from ``config``. The keys read below are an assumed
    schema (NOTE(review): align with the project's real config):

    - ``"model_name"`` (str, required): any name accepted by
      ``timm.create_model``, e.g. ``"resnet18"``.
    - ``"num_classes"`` (int, optional): size of the classification head.
      When absent, timm's default head for that model is kept.
    - ``"pretrained"`` (bool, optional, default ``False``): whether timm
      should load pretrained weights.
    """

    def __init__(self, config: dict):
        """Build the timm backbone described by ``config``.

        Raises:
            ValueError: if ``config`` does not provide a ``"model_name"``.
        """
        super().__init__()

        model_name = config.get("model_name")
        if not model_name:
            # timm.create_model requires a model name; fail early with a
            # clearer message than timm's own TypeError/RuntimeError.
            raise ValueError(
                "config['model_name'] must name a timm model, e.g. 'resnet18'"
            )

        create_kwargs = {"pretrained": bool(config.get("pretrained", False))}
        # Only override the head size when the caller asks for it; otherwise
        # keep the model's default classifier.
        if config.get("num_classes") is not None:
            create_kwargs["num_classes"] = int(config["num_classes"])

        # NN architecture
        self.backbone = timm.create_model(model_name, **create_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return its raw output.

        For a timm model with a classification head this is the unnormalised
        per-class scores. NOTE(review): ``x`` is presumably an image batch
        shaped as the chosen backbone expects (typically (B, C, H, W)).
        """
        return self.backbone(x)