Juartaurus's picture
Upload folder using huggingface_hub
1865436
raw
history blame contribute delete
673 Bytes
import pytorch_lightning as pl
from det_models.model import POIDetection
from datasets_signboard_detection.datamodule import POIDataModule
from det_models.inference_signboard_detection import POIDetectionTask
def load_model(checkpoint_path):
    """Restore a ``POIDetection`` model from a saved checkpoint.

    Args:
        checkpoint_path: Location of the checkpoint; forwarded unchanged as
            the ``checkpoint_path`` keyword of
            ``POIDetection.load_from_checkpoint``.

    Returns:
        Whatever ``POIDetection.load_from_checkpoint`` returns for that
        checkpoint.
    """
    return POIDetection.load_from_checkpoint(checkpoint_path=checkpoint_path)
def inference_signboard(image, checkpoint):
    """Run signboard-detection prediction and return the task's output.

    Args:
        image: Input handed to ``POIDataModule`` as its ``data`` argument;
            the accepted type(s) are defined by that datamodule.
        checkpoint: Checkpoint location passed to ``load_model``.

    Returns:
        The ``output`` attribute of the ``POIDetectionTask`` after
        ``Trainer.predict`` has finished (presumably filled in by the task
        during prediction -- confirm in ``POIDetectionTask``).
    """
    # Prepare the data first, with a fixed seed, for the "predict" stage.
    datamodule = POIDataModule(data=image, seed=42)
    datamodule.setup("predict")

    # Wrap the restored model in the prediction task that collects results.
    predict_task = POIDetectionTask(load_model(checkpoint))

    # NOTE(review): gpus=0 requests no GPU. The GPU alternative left here by
    # the original author was `accelerator='gpu', devices=1`; the `gpus`
    # argument is deprecated in newer pytorch_lightning releases -- confirm
    # against the installed version before upgrading.
    runner = pl.Trainer(gpus=0, max_epochs=-1)
    runner.predict(predict_task, datamodule=datamodule)

    return predict_task.output