donghuna committed on
Commit a5cd004 · verified · 1 Parent(s): 1f30fb0

Update handler.py

Files changed (1)
  1. handler.py +4 -60
handler.py CHANGED
@@ -1,15 +1,9 @@
-import os
-import json
-import io
 import numpy as np
 import torch
-from torchvision import transforms
 from transformers import TimesformerForVideoClassification
-from ftplib import FTP
-import av
 
 class EndpointHandler:
-    def __init__(self, model_dir=""):
+    def __init__(self, model_dir):
         self.model = TimesformerForVideoClassification.from_pretrained(
             'donghuna/timesformer-base-finetuned-k400-diving48',
             ignore_mismatched_sizes=True
@@ -17,64 +11,14 @@ class EndpointHandler:
         self.model.classifier = torch.nn.Linear(self.model.classifier.in_features, 48) # 48 output classes
         self.model.eval()
 
-        # Target size and number of frames
-        self.target_size = (224, 224)
-        self.num_frames = 24
-
     def __call__(self, data):
-        video_path = data.get("video_path")
-        ftp_host = data.get("ftp_host")
-        ftp_user = data.get("ftp_user")
-        ftp_password = data.get("ftp_password")
-
-        if not all([video_path, ftp_host, ftp_user, ftp_password]):
-            return {"error": "Missing required parameters"}
-
-        # Connect to FTP and read video
-        with FTP(self.ftp_host) as ftp:
-            ftp.login(self.ftp_user, self.ftp_password)
-            video_tensor = self.read_and_process_video(ftp, video_path, self.target_size, self.num_frames)
+        frames = np.array(data['frames'])
+        frames = torch.tensor(frames).float() # Ensure the data is in the correct format
 
         # Perform inference
         with torch.no_grad():
-            outputs = self.model(video_tensor.unsqueeze(0)) # Add batch dimension
+            outputs = self.model(frames.unsqueeze(0)) # Add batch dimension
             predictions = torch.softmax(outputs.logits, dim=-1)
             predicted_class = torch.argmax(predictions, dim=-1).item()
 
         return {"predicted_class": predicted_class, "predictions": predictions.tolist()}
-
-    def read_video_from_ftp(self, ftp, file_path, start_frame, end_frame):
-        video_data = io.BytesIO()
-        ftp.retrbinary(f'RETR {file_path}', video_data.write)
-        video_data.seek(0)
-        container = av.open(video_data, format='mp4')
-        frames = [frame.to_ndarray(format="rgb24").astype(np.uint8) for frame in container.decode(video=0)]
-        return np.stack(frames, axis=0)
-
-    def sample_frames(self, frames, num_frames):
-        total_frames = len(frames)
-        sampled_frames = list(frames)
-        if total_frames <= num_frames:
-            if total_frames < num_frames:
-                padding = [np.zeros_like(frames[0]) for _ in range(num_frames - total_frames)]
-                sampled_frames.extend(padding)
-        else:
-            indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
-            sampled_frames = [frames[i] for i in indices]
-        return np.array(sampled_frames)
-
-    def pad_and_resize(self, frames, target_size):
-        transform = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize(target_size),
-            transforms.ToTensor()
-        ])
-        processed_frames = [transform(frame) for frame in frames]
-        return torch.stack(processed_frames)
-
-    def read_and_process_video(self, ftp, file_path, start_frame, end_frame, target_size, num_frames):
-        frames = self.read_video_from_ftp(ftp, file_path, start_frame, end_frame)
-        frames = self.sample_frames(frames, num_frames=num_frames)
-        processed_frames = self.pad_and_resize(frames, target_size=target_size)
-        processed_frames = processed_frames.permute(1, 0, 2, 3) # (T, C, H, W) -> (C, T, H, W)
-        return processed_frames
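With this commit the handler no longer pulls video over FTP; the request body is expected to carry already-decoded frames under a 'frames' key, which the handler converts to a float tensor and batches with unsqueeze(0). A minimal shape check of that contract (a sketch: the 24-frame count and 224x224 size are taken from the defaults removed above, and the (frames, channels, height, width) ordering is assumed so that the batched tensor matches the (batch, frames, channels, height, width) layout TimesformerForVideoClassification expects):

import numpy as np
import torch

# Hypothetical payload; 24 and 224 come from the removed self.num_frames /
# self.target_size defaults, not from this commit itself.
data = {"frames": np.random.rand(24, 3, 224, 224).tolist()}

frames = np.array(data["frames"])
frames = torch.tensor(frames).float()   # same conversion the handler performs
pixel_values = frames.unsqueeze(0)      # add batch dimension
print(pixel_values.shape)               # torch.Size([1, 24, 3, 224, 224])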
 
 
 
 
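Because sample_frames and pad_and_resize were deleted, frame sampling and resizing now have to happen on the client before calling the endpoint. One way to reproduce the removed preprocessing (a sketch: prepare_frames and the dummy input are illustrative, not part of the handler; it keeps the (frames, channels, height, width) order and deliberately skips the old permute to (C, T, H, W), since the new handler only adds a batch dimension):

import numpy as np
import torch
from torchvision import transforms

def prepare_frames(raw_frames, num_frames=24, target_size=(224, 224)):
    # raw_frames: H x W x 3 uint8 RGB arrays decoded by the caller
    total = len(raw_frames)
    if total < num_frames:
        # pad with black frames, as the removed sample_frames() did
        raw_frames = list(raw_frames) + [np.zeros_like(raw_frames[0])] * (num_frames - total)
    else:
        indices = np.linspace(0, total - 1, num=num_frames, dtype=int)
        raw_frames = [raw_frames[i] for i in indices]
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(target_size),
        transforms.ToTensor(),   # C x H x W floats in [0, 1]
    ])
    return torch.stack([transform(f) for f in raw_frames])   # (num_frames, 3, H, W)

# Usage: serialize the tensor as a plain list for the request body.
dummy_video = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(100)]
payload = {"frames": prepare_frames(dummy_video).tolist()}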