File size: 673 Bytes
1865436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import pytorch_lightning as pl
from det_models.model import POIDetection
from datasets_signboard_detection.datamodule import POIDataModule
from det_models.inference_signboard_detection import POIDetectionTask

def load_model(checkpoint_path):
    """Restore a ``POIDetection`` model from a Lightning checkpoint file.

    Args:
        checkpoint_path: Location of the checkpoint; forwarded unchanged as the
            ``checkpoint_path`` keyword of ``POIDetection.load_from_checkpoint``.

    Returns:
        Whatever ``POIDetection.load_from_checkpoint`` produces (the restored model).
    """
    return POIDetection.load_from_checkpoint(checkpoint_path=checkpoint_path)


def inference_signboard(image, checkpoint, *, accelerator="cpu", devices=1):
    """Run signboard detection on ``image`` with the model stored in ``checkpoint``.

    Args:
        image: Input handed to ``POIDataModule(data=...)``; its expected type
            (path, array, PIL image, ...) is defined by that data module —
            NOTE(review): confirm against ``POIDataModule``.
        checkpoint: Checkpoint path passed to ``load_model``.
        accelerator: Lightning accelerator name. Defaults to ``"cpu"``, which
            matches the previous ``gpus=0`` behaviour; pass ``"gpu"`` to run on
            a GPU.
        devices: Number of devices for the chosen accelerator.

    Returns:
        ``task.output`` after prediction; presumably populated by
        ``POIDetectionTask`` during ``predict`` — confirm in that class.
    """
    dm = POIDataModule(data=image, seed=42)
    dm.setup("predict")

    model = load_model(checkpoint)
    task = POIDetectionTask(model)

    # `gpus=` was deprecated in Lightning 1.7 and removed in 2.0; the
    # accelerator/devices pair is the supported way to select hardware.
    # `max_epochs` is irrelevant for `predict`, so it is not set.
    trainer = pl.Trainer(accelerator=accelerator, devices=devices)
    trainer.predict(task, datamodule=dm)
    return task.output