Spaces:
Build error
Build error
import pytorch_lightning as pl | |
from src.ss.det_models.backbone import initialize_model | |
class POIDetection(pl.LightningModule):
    """Lightning wrapper around the detection model built by ``initialize_model``.

    The wrapped network is stored on ``self.model``; ``forward`` simply
    delegates to it, with or without targets.
    """

    def __init__(self,
                 n_classes,
                 **kwargs):
        """Build the wrapped model and record the constructor arguments.

        Args:
            n_classes: Forwarded unchanged to ``initialize_model``
                (presumably the number of output classes -- confirm against
                ``initialize_model``).
            **kwargs: Must contain the keys ``"backbone"`` and
                ``"tune_fc_only"``; both are forwarded to
                ``initialize_model``. Any other entries are not read here.

        Raises:
            KeyError: If ``"backbone"`` or ``"tune_fc_only"`` is missing
                from ``kwargs``.
        """
        super().__init__()
        # Called before anything else is set up, as in the original code.
        self.save_hyperparameters()
        # initialize_model returns a pair; only the first element is kept.
        self.model, _ = initialize_model(kwargs["backbone"],
                                         n_classes,
                                         tune_only=kwargs["tune_fc_only"])

    def forward(self, images, targets=None):
        """Run the wrapped model on a batch.

        Args:
            images: Iterable of images; it is materialised into a new list
                before being handed to the model.
            targets: Optional iterable of per-image mappings. When given,
                each mapping is shallow-copied and the copies are passed to
                the model as a second positional argument.

        Returns:
            Whatever ``self.model`` returns for the given call.
            NOTE(review): presumably losses when ``targets`` is supplied and
            predictions otherwise -- confirm against the model returned by
            ``initialize_model``.
        """
        images = list(images)
        if targets is None:
            return self.model(images)

        # Shallow-copy every target mapping so the model receives fresh
        # dict objects rather than the caller's own containers.
        targets = [dict(t) for t in targets]
        return self.model(images, targets)