import io

import av
import numpy as np
import torch
from ftplib import FTP
from torchvision import transforms
from transformers import TimesformerForVideoClassification


class EndpointHandler:
    def __init__(self, model_dir=""):
        self.model = TimesformerForVideoClassification.from_pretrained(
            'donghuna/timesformer-base-finetuned-k400-diving48',
            ignore_mismatched_sizes=True
        )
        # Replace the classification head with a 48-way head for Diving48.
        self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)
        self.model.eval()

        # Target spatial size and number of frames fed to the model.
        self.target_size = (224, 224)
        self.num_frames = 24

    def __call__(self, data):
        video_path = data.get("video_path")
        ftp_host = data.get("ftp_host")
        ftp_user = data.get("ftp_user")
        ftp_password = data.get("ftp_password")

        if not all([video_path, ftp_host, ftp_user, ftp_password]):
            return {"error": "Missing required parameters"}

        # Connect to FTP and read the video. The credentials are local
        # variables from the request payload, not instance attributes.
        with FTP(ftp_host) as ftp:
            ftp.login(ftp_user, ftp_password)
            video_tensor = self.read_and_process_video(
                ftp, video_path, self.target_size, self.num_frames
            )

        # Perform inference.
        with torch.no_grad():
            outputs = self.model(video_tensor.unsqueeze(0))  # add batch dimension
            predictions = torch.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(predictions, dim=-1).item()

        return {"predicted_class": predicted_class, "predictions": predictions.tolist()}

    def read_video_from_ftp(self, ftp, file_path):
        # Download the file into memory and decode every frame as RGB.
        video_data = io.BytesIO()
        ftp.retrbinary(f'RETR {file_path}', video_data.write)
        video_data.seek(0)
        container = av.open(video_data, format='mp4')
        frames = [
            frame.to_ndarray(format="rgb24").astype(np.uint8)
            for frame in container.decode(video=0)
        ]
        return np.stack(frames, axis=0)

    def sample_frames(self, frames, num_frames):
        # Uniformly sample num_frames frames; zero-pad clips that are too short.
        total_frames = len(frames)
        sampled_frames = list(frames)
        if total_frames <= num_frames:
            if total_frames < num_frames:
                padding = [np.zeros_like(frames[0]) for _ in range(num_frames - total_frames)]
                sampled_frames.extend(padding)
        else:
            indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
            sampled_frames = [frames[i] for i in indices]
        return np.array(sampled_frames)

    def pad_and_resize(self, frames, target_size):
        # Resize each frame and convert it to a float tensor in [0, 1].
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(target_size),
            transforms.ToTensor()
        ])
        processed_frames = [transform(frame) for frame in frames]
        return torch.stack(processed_frames)  # (T, C, H, W)

    def read_and_process_video(self, ftp, file_path, target_size, num_frames):
        frames = self.read_video_from_ftp(ftp, file_path)
        frames = self.sample_frames(frames, num_frames=num_frames)
        processed_frames = self.pad_and_resize(frames, target_size=target_size)
        # TimesformerForVideoClassification expects pixel_values of shape
        # (batch, num_frames, channels, height, width), so keep the
        # (T, C, H, W) layout; the batch dimension is added in __call__.
        return processed_frames
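
# --- Usage sketch (illustrative, not part of the handler above) ---
# A minimal smoke test for the handler, assuming an FTP server is reachable;
# the host, path, and credentials below are hypothetical placeholders and
# must be replaced with real values.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({
        "video_path": "/videos/sample_dive.mp4",  # hypothetical remote path
        "ftp_host": "ftp.example.com",            # hypothetical host
        "ftp_user": "user",                       # hypothetical credentials
        "ftp_password": "password",
    })
    print(result)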