donghuna commited on
Commit
8efce6f
·
verified ·
1 Parent(s): 2990f6d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +26 -12
handler.py CHANGED
@@ -30,21 +30,35 @@ class EndpointHandler:
30
  A :obj:`list` | `dict`: will be serialized and returned
31
  """
32
 
 
 
 
 
 
 
 
 
33
  inputs = data.get("inputs")
34
  videos = read_video(inputs)
 
35
  with torch.no_grad():
36
  outputs = self.model(videos)
 
37
  logits = outputs.logits
38
- _, predicted = torch.max(logits, 1)
39
- return predicted.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # inputs = data.get("inputs")
42
- # if not inputs:
43
- # return {"error": "No video input provided"}
44
-
45
- # # 비디오 파일 경로
46
- # video_path = inputs.get("video_path")
47
- # if not video_path or not os.path.exists(video_path):
48
- # return {"error": "Invalid or missing video file"}
49
-
50
- # return {"predicted_class": 1}
 
30
  A :obj:`list` | `dict`: will be serialized and returned
31
  """
32
 
33
+ # inputs = data.get("inputs")
34
+ # videos = read_video(inputs)
35
+ # with torch.no_grad():
36
+ # outputs = self.model(videos)
37
+ # logits = outputs.logits
38
+ # _, predicted = torch.max(logits, 1)
39
+ # return predicted.tolist()
40
+
41
  inputs = data.get("inputs")
42
  videos = read_video(inputs)
43
+
44
  with torch.no_grad():
45
  outputs = self.model(videos)
46
+
47
  logits = outputs.logits
48
+ probabilities = torch.softmax(logits, dim=1)
49
+
50
+ # Top 3
51
+ top_probs, top_indices = torch.topk(probabilities, 3, dim=1)
52
+
53
+ top_probs_list = top_probs.tolist()
54
+ top_indices_list = top_indices.tolist()
55
+
56
+ top_results = []
57
+ for i in range(len(top_indices_list)):
58
+ top_results.append({
59
+ "class_indices": top_indices_list[i],
60
+ "probabilities": top_probs_list[i]
61
+ })
62
+
63
+ return top_results
64