thinh-huynh-re commited on
Commit
70b3e07
1 Parent(s): 254ea49

Load id2label from json

Browse files
Files changed (1) hide show
  1. run_opencv.py +10 -0
run_opencv.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import List, Optional, Tuple
2
 
3
  import cv2
@@ -32,6 +33,8 @@ class ArgParser(Tap):
32
 
33
  top_k: Optional[int] = 5
34
 
 
 
35
 
36
  class ActivityModel:
37
  def __init__(self, args: ArgParser):
@@ -41,6 +44,13 @@ class ActivityModel:
41
  self.frames_per_video = self.get_frames_per_video(args.model_name)
42
  print(f"Frames per video: {self.frames_per_video}")
43
 
 
 
 
 
 
 
 
44
  def load_model(
45
  self, model_name: str
46
  ) -> Tuple[VideoMAEFeatureExtractor, TimesformerForVideoClassification]:
 
1
+ import json
2
  from typing import List, Optional, Tuple
3
 
4
  import cv2
 
33
 
34
  top_k: Optional[int] = 5
35
 
36
+ id2label: Optional[str] = "labels/kinetics_400.json"
37
+
38
 
39
  class ActivityModel:
40
  def __init__(self, args: ArgParser):
 
44
  self.frames_per_video = self.get_frames_per_video(args.model_name)
45
  print(f"Frames per video: {self.frames_per_video}")
46
 
47
+ self.load_json()
48
+
49
+ def load_json(self):
50
+ if args.id2label is not None:
51
+ with open(args.id2label, encoding="utf-8") as f:
52
+ self.model.config.id2label = json.load(f)
53
+
54
  def load_model(
55
  self, model_name: str
56
  ) -> Tuple[VideoMAEFeatureExtractor, TimesformerForVideoClassification]: