# Residue from the web page this file was copied from ("Spaces: Sleeping"); not part of the program.
from datetime import datetime
from pathlib import Path

import lightning as L
import numpy as np
import polars as pl
import torch
from torch.utils.data import DataLoader

from data_utils.frame_dataset import FrameDataset
from models.lightning_wrapper import LightningWrapper
def run_inference(
    model_path: Path,
    image_folder: Path,
    aggregate_duration: int = 30,
    fps: int = 3,
    start_time: datetime = datetime(2023, 12, 10),
) -> pl.DataFrame:
    """Classify every frame in ``image_folder`` and sum predictions per time window.

    Frames are discovered recursively as ``*.jpg`` files whose stem is ``img<N>``;
    ``N`` is parsed as the frame index and used for ordering. The predicted class
    (argmax over dim 1 of the model output) of every frame is summed within
    consecutive windows of ``aggregate_duration`` seconds, converting frame
    indices to seconds with ``fps`` and treating frame 0 as time zero.

    Args:
        model_path: Path to a ``LightningWrapper`` checkpoint.
        image_folder: Folder searched recursively for ``*.jpg`` frames.
        aggregate_duration: Window length in seconds.
        fps: Frames per second used to map frame indices to seconds.
        start_time: Wall-clock time of frame 0; each window's ``timestamp`` is
            this value plus the window's start offset. Defaults to
            2023-12-10 00:00:00, the date previously hard-coded here.

    Returns:
        One row per window, sorted by ``timestamp``, with columns ``frame``
        (window start in seconds from ``start_time``), ``preds`` (sum of the
        predicted class indices in the window), ``hour`` / ``minute`` /
        ``second`` (the window start split into clock components; ``hour`` is
        not wrapped at 24), and ``timestamp``.

    Raises:
        ValueError: If no ``*.jpg`` files are found under ``image_folder``.
    """
    # Discover frames before loading the checkpoint so an empty folder fails fast
    # with a clear message instead of an opaque error from torch.cat([]).
    paths = list(image_folder.rglob("*.jpg"))
    if not paths:
        raise ValueError(f"No *.jpg frames found under {image_folder}")

    model = LightningWrapper.load_from_checkpoint(model_path)
    trainer = L.Trainer()

    df = pl.DataFrame(
        {"path": paths, "frame": [int(p.stem.removeprefix("img")) for p in paths]}
    ).sort("frame")
    ds = FrameDataset(df, model.get_transforms(is_training=False), 1, is_train=False)
    dls = DataLoader(ds, batch_size=32, num_workers=2, pin_memory=True)

    preds_list: list[torch.Tensor] = trainer.predict(model, dataloaders=dls)  # type: ignore
    preds = torch.cat(preds_list)
    # .numpy() requires host memory; .cpu() is a no-op if predictions are already there.
    pred_class = torch.argmax(preds, dim=1).cpu()
    # Expand per-clip predictions back to one value per frame row.
    # NOTE(review): assumes the dataset yields clips in frame order covering every
    # row of df exactly once — confirm against FrameDataset.
    preds_class = np.repeat(pred_class.numpy(), ds.frames_per_clip)
    df = df.with_columns(preds=pl.Series(preds_class))

    frames_per_window = aggregate_duration * fps
    df_g = df.group_by(pl.col("frame") // frames_per_window).agg(pl.sum("preds"))

    seconds = pl.col("frame")
    df_g = (
        # Window index -> window start offset in seconds.
        df_g.with_columns(pl.col("frame") * aggregate_duration)
        .with_columns(
            hour=seconds // (60 * 60),
            minute=(seconds // 60) % 60,
            second=seconds % 60,
        )
        # Offset from start_time instead of assembling a datetime from clock
        # components, so recordings running past midnight roll over to the next
        # day rather than producing an invalid hour (>= 24).
        .with_columns(timestamp=pl.lit(start_time) + pl.duration(seconds=seconds))
        .sort("timestamp")
    )
    return df_g