| | import flash |
| | from flash.core.data.utils import download_data |
| | from flash.video import VideoClassificationData, VideoClassifier |
| | import torch |
| | from flash.video.classification.input_transform import VideoClassificationInputTransform |
| | from pytorchvideo.transforms import ( |
| | ApplyTransformToKey, |
| | ShortSideScale, |
| | UniformTemporalSubsample, |
| | UniformCropVideo, |
| | ) |
| | from dataclasses import dataclass |
| | from typing import Callable |
| |
|
| | import torch |
| | from torch import Tensor |
| |
|
| | from flash.core.data.io.input import DataKeys |
| | from flash.core.data.io.input_transform import InputTransform |
| | from flash.core.data.transforms import ApplyToKeys |
| | from flash.core.utilities.imports import ( |
| | _KORNIA_AVAILABLE, |
| | _PYTORCHVIDEO_AVAILABLE, |
| | requires, |
| | ) |
| | from torchvision.transforms import Compose, CenterCrop |
| | from torchvision.transforms import RandomCrop |
| | from torch import nn |
| | import kornia.augmentation as K |
| | from torchvision import transforms as T |
# Allow TF32 matmul kernels on supported GPUs ("high" trades a little float32
# precision for substantially faster matrix multiplies during inference).
torch.set_float32_matmul_precision('high')
| |
|
def normalize(x: Tensor) -> Tensor:
    """Rescale uint8-range pixel values (0..255) into the [0, 1] float range."""
    pixel_max = 255.0
    return x.div(pixel_max)
| |
|
class TransformDataModule(InputTransform):
    """Input transforms for video classification.

    Per-sample (CPU): uniform temporal subsampling, 0-255 -> 0-1 scaling, and a
    crop (center crop for eval/predict, random crop for training). Per-batch
    (device): Kornia mean/std normalization over BCTHW video batches.
    """

    # Class-level defaults; Flash reads these as the transform's configuration.
    image_size: int = 256
    temporal_sub_sample: int = 16
    mean: Tensor = torch.tensor([0.45, 0.45, 0.45])
    std: Tensor = torch.tensor([0.225, 0.225, 0.225])
    data_format: str = "BCTHW"
    same_on_frame: bool = False

    def _sample_pipeline(self, crop: Callable) -> Callable:
        # Shared per-sample pipeline; the two public hooks below differ only in
        # which crop transform they plug in (previously the whole Compose was
        # duplicated in both methods).
        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [
                            UniformTemporalSubsample(self.temporal_sub_sample),
                            normalize,
                            crop,
                        ]
                    ),
                ),
                # Targets arrive as plain Python values; make them tensors.
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def per_sample_transform(self) -> Callable:
        """Eval/predict transform: deterministic center crop."""
        return self._sample_pipeline(CenterCrop(self.image_size))

    def train_per_sample_transform(self) -> Callable:
        """Training transform: random crop (padded if the clip is too small)."""
        return self._sample_pipeline(RandomCrop(self.image_size, pad_if_needed=True))

    def per_batch_transform_on_device(self) -> Callable:
        """Device-side mean/std normalization applied to the whole video batch."""
        return ApplyToKeys(
            DataKeys.INPUT,
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )
| |
|
| |
|
| |
|
# Restore the trained classifier from a Lightning checkpoint.
# NOTE(review): the path spells "video_classfication" (sic) — presumably it
# matches the directory the training run actually wrote; verify on disk.
model = VideoClassifier.load_from_checkpoint("video_classfication/checkpoints/epoch=99-step=1000.ckpt")


# Prediction-only datamodule: reads clips from the "videos" folder one clip per
# batch, using the eval-time pipeline defined by TransformDataModule.
datamodule_p = VideoClassificationData.from_folders(
    predict_folder="videos",
    batch_size=1,
    transform=TransformDataModule()
)
# max_epochs has no effect on trainer.predict(); kept as written.
trainer = flash.Trainer(
    max_epochs=5,
)
def classfication():
    """Predict labels for the clips in the predict folder and return the first.

    Returns:
        The predicted label of the first video, or ``None`` when the predict
        folder produced no batches (previously this raised ``IndexError``).
    """
    # NOTE: the (misspelled) public name is kept so existing callers still work.
    predictions = trainer.predict(model, datamodule=datamodule_p, output="labels")
    # trainer.predict returns one list of outputs per dataloader batch; with
    # batch_size=1 the first batch's first entry is the first clip's label.
    if not predictions or not predictions[0]:
        return None
    return predictions[0][0]
| |
|