import timm
import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.core.mixins import HyperparametersMixin

# timm model identifiers for the two backbones used in this module.
_CONVNEXT_LARGE_NAME = 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384'
_MOBILENETV3_LARGE_NAME = 'timm/tf_mobilenetv3_large_100.in1k'

# Input width of the classification head for each backbone. These are the
# values the heads were originally built with; they must match the size of
# the feature vector each backbone returns when created with num_classes=0.
_CONVNEXT_LARGE_FEATURES = 1536
_MOBILENETV3_LARGE_FEATURES = 1280


def _make_head(in_features: int, hidden: int = 128, num_outputs: int = 2) -> nn.Sequential:
    """Build the small MLP classification head shared by every model here.

    Args:
        in_features: Width of the backbone feature vector fed into the head.
        hidden: Width of the single hidden layer.
        num_outputs: Number of output logits.

    Returns:
        ``nn.Sequential(Linear, ReLU, Linear)``. The layer order is fixed so
        the parameter keys stay ``clf.0.*`` / ``clf.2.*`` when assigned to
        ``self.clf``, matching checkpoints saved from the original classes.
    """
    return nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.ReLU(inplace=True),
        nn.Linear(hidden, num_outputs),
    )


class Model200M(torch.nn.Module):
    """ConvNeXt-Large (CLIP LAION-2B soup, 384px) backbone with a 2-logit MLP head."""

    def __init__(self):
        super().__init__()
        # num_classes=0 asks timm for the backbone without its classifier, so
        # forward() returns a feature vector rather than class logits.
        # pretrained=False: no weights are downloaded here; presumably they are
        # loaded from a checkpoint elsewhere -- confirm against callers.
        self.model = timm.create_model(_CONVNEXT_LARGE_NAME, pretrained=False, num_classes=0)
        self.clf = _make_head(_CONVNEXT_LARGE_FEATURES)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return the head's logits for a batch of images.

        Assumes ``image`` is a (batch, channels, height, width) tensor in the
        format the timm backbone expects -- TODO confirm preprocessing.
        """
        image_features = self.model(image)
        return self.clf(image_features)


class Model5M(torch.nn.Module):
    """MobileNetV3-Large (TF port, ImageNet-1k) backbone with a 2-logit MLP head."""

    def __init__(self):
        super().__init__()
        # Same construction as Model200M, with a smaller backbone: no
        # classifier (num_classes=0) and no pretrained download.
        self.model = timm.create_model(_MOBILENETV3_LARGE_NAME, pretrained=False, num_classes=0)
        self.clf = _make_head(_MOBILENETV3_LARGE_FEATURES)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return the head's logits for a batch of images."""
        image_features = self.model(image)
        return self.clf(image_features)


# NOTE(review): in recent Lightning versions pl.LightningModule already derives
# from HyperparametersMixin, so listing it again looks redundant -- confirm
# before removing. It is kept so the class bases are unchanged.
class SyntheticV2(pl.LightningModule, HyperparametersMixin):
    """LightningModule wrapping the same architecture as ``Model200M``.

    Attribute names and layer layout match ``Model200M`` exactly, so the two
    classes produce identically keyed state dicts.
    """

    def __init__(self):
        super().__init__()
        self.model = timm.create_model(_CONVNEXT_LARGE_NAME, pretrained=False, num_classes=0)
        self.clf = _make_head(_CONVNEXT_LARGE_FEATURES)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return the head's logits for a batch of images."""
        image_features = self.model(image)
        return self.clf(image_features)