Update handler.py

handler.py CHANGED  +4 -60
@@ -1,15 +1,9 @@
-import os
-import json
-import io
 import numpy as np
 import torch
-from torchvision import transforms
 from transformers import TimesformerForVideoClassification
-from ftplib import FTP
-import av
 
 class EndpointHandler:
-    def __init__(self, model_dir
+    def __init__(self, model_dir):
         self.model = TimesformerForVideoClassification.from_pretrained(
             'donghuna/timesformer-base-finetuned-k400-diving48',
             ignore_mismatched_sizes=True
@@ -17,64 +11,14 @@ class EndpointHandler:
         self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48)  # 48 output classes
         self.model.eval()
 
-        # Target size and number of frames
-        self.target_size = (224, 224)
-        self.num_frames = 24
-
     def __call__(self, data):
-
-
-        ftp_user = data.get("ftp_user")
-        ftp_password = data.get("ftp_password")
-
-        if not all([video_path, ftp_host, ftp_user, ftp_password]):
-            return {"error": "Missing required parameters"}
-
-        # Connect to FTP and read video
-        with FTP(self.ftp_host) as ftp:
-            ftp.login(self.ftp_user, self.ftp_password)
-            video_tensor = self.read_and_process_video(ftp, video_path, self.target_size, self.num_frames)
+        frames = np.array(data['frames'])
+        frames = torch.tensor(frames).float()  # Ensure the data is in the correct format
 
         # Perform inference
         with torch.no_grad():
-            outputs = self.model(
+            outputs = self.model(frames.unsqueeze(0))  # Add batch dimension
             predictions = torch.softmax(outputs.logits, dim=-1)
             predicted_class = torch.argmax(predictions, dim=-1).item()
 
         return {"predicted_class": predicted_class, "predictions": predictions.tolist()}
-
-    def read_video_from_ftp(self, ftp, file_path, start_frame, end_frame):
-        video_data = io.BytesIO()
-        ftp.retrbinary(f'RETR {file_path}', video_data.write)
-        video_data.seek(0)
-        container = av.open(video_data, format='mp4')
-        frames = [frame.to_ndarray(format="rgb24").astype(np.uint8) for frame in container.decode(video=0)]
-        return np.stack(frames, axis=0)
-
-    def sample_frames(self, frames, num_frames):
-        total_frames = len(frames)
-        sampled_frames = list(frames)
-        if total_frames <= num_frames:
-            if total_frames < num_frames:
-                padding = [np.zeros_like(frames[0]) for _ in range(num_frames - total_frames)]
-                sampled_frames.extend(padding)
-        else:
-            indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
-            sampled_frames = [frames[i] for i in indices]
-        return np.array(sampled_frames)
-
-    def pad_and_resize(self, frames, target_size):
-        transform = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize(target_size),
-            transforms.ToTensor()
-        ])
-        processed_frames = [transform(frame) for frame in frames]
-        return torch.stack(processed_frames)
-
-    def read_and_process_video(self, ftp, file_path, start_frame, end_frame, target_size, num_frames):
-        frames = self.read_video_from_ftp(ftp, file_path, start_frame, end_frame)
-        frames = self.sample_frames(frames, num_frames=num_frames)
-        processed_frames = self.pad_and_resize(frames, target_size=target_size)
-        processed_frames = processed_frames.permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
-        return processed_frames
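With this change the handler no longer downloads or preprocesses video itself; it expects the request body to carry a "frames" array that is already sampled and resized. A minimal client-side sketch of that preparation follows. The helper name prepare_frames, the 24-frame budget, and the 224x224 target size are assumptions carried over from the removed handler code; since the handler calls frames.unsqueeze(0) before passing the tensor to Timesformer (which takes pixel_values of shape (batch, frames, channels, height, width)), the array is built as (num_frames, 3, 224, 224) and the handler adds the batch dimension.

# Hypothetical client-side preprocessing; prepare_frames and its defaults are
# assumptions mirroring the sample_frames / pad_and_resize code removed above.
import numpy as np
from torchvision import transforms

def prepare_frames(frames, num_frames=24, target_size=(224, 224)):
    """Sample or pad decoded RGB frames (H x W x 3 uint8) into a fixed-length clip."""
    frames = list(frames)
    total = len(frames)
    if total < num_frames:
        # Pad short clips with black frames, as the removed sample_frames did
        frames += [np.zeros_like(frames[0]) for _ in range(num_frames - total)]
    elif total > num_frames:
        # Uniformly sample num_frames indices across the clip
        indices = np.linspace(0, total - 1, num=num_frames, dtype=int)
        frames = [frames[i] for i in indices]

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(target_size),
        transforms.ToTensor(),  # -> (3, H, W) float in [0, 1]
    ])
    # (num_frames, 3, H, W); the handler adds the batch dimension itself
    return np.stack([transform(f).numpy() for f in frames], axis=0)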
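A local smoke test of the updated handler might look like the sketch below. It reuses the hypothetical prepare_frames helper above with synthetic frames, and the model_dir argument is accepted but unused by this __init__. How an Inference Endpoint wraps the JSON body (for example under an "inputs" key) is not shown in the diff, so the sketch calls the handler directly.

# Hypothetical local smoke test: build a dummy clip, preprocess it, call the handler.
import numpy as np
from handler import EndpointHandler

handler = EndpointHandler(model_dir=".")  # model_dir is not used by __init__

dummy_clip = [np.zeros((360, 640, 3), dtype=np.uint8) for _ in range(40)]  # 40 fake RGB frames
payload = {"frames": prepare_frames(dummy_clip).tolist()}

result = handler(payload)
print(result["predicted_class"])      # index of the top class (0..47)
print(len(result["predictions"][0]))  # 48 softmax scores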