xmrt committed
Commit
22e21ad
1 Parent(s): 3bc0fc7
Files changed (2)
  1. Dockerfile +2 -1
  2. main.py +6 -9
Dockerfile CHANGED
@@ -4,7 +4,8 @@ WORKDIR /code
 COPY ./requirements.txt /code/requirements.txt
 
 #RUN pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.ht
-RUN pip install torch==1.9.1 torchvision==0.10.1 -f https://download.pytorch.org/whl/torch_stable.ht
+RUN pip install torch==1.9.1 torchvision==0.10.1 -f https://download.pytorch.org/whl/cu113
+#https://download.pytorch.org/whl/torch_stable.ht
 RUN pip install --no-cache-dir --upgrade -U openmim
 RUN mim install --no-cache-dir --upgrade mmengine
 RUN mim install "mmcv>=2.0.1"
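A quick sanity check (not part of the commit) for the torch/torchvision install the new find-links line resolves to: a minimal sketch that can be run inside the built image, assuming the container is started with GPU access.

```python
# Minimal sketch (not from this commit): verify the torch build inside the image.
import torch

print(torch.__version__)          # expected: 1.9.1 (with a CUDA suffix if a CUDA wheel was resolved)
print(torch.cuda.is_available())  # True only when the container can see a GPU
```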
main.py CHANGED
@@ -29,6 +29,8 @@ human3d = MMPoseInferencer(pose3d="human3d")
 track_model = YOLO('yolov8n.pt') # Load an official Detect model
 
 # ultraltics
+if torch.cuda.is_available():
+    device = "cuda"
 
 print("[INFO]: Downloaded models!")
 
@@ -38,7 +40,6 @@ def check_extension(video):
     # extract the file name and extension
     file_name = split_tup[0]
     file_extension = split_tup[1]
-    print(file_extension)
 
     if file_extension != ".mp4":
         print("Converting to mp4")
@@ -50,9 +51,6 @@ def check_extension(video):
     return video
 
 
-
-
-
 def tracking(video, model, boxes=True):
     print("[INFO] Is cuda available? ", torch.cuda.is_available())
     print("[INFO] Loading model...")
@@ -61,12 +59,11 @@ def tracking(video, model, boxes=True):
     # Perform tracking with the model
     print("[INFO] Starting tracking!")
     # https://docs.ultralytics.com/modes/predict/
-    annotated_frame = model(video, boxes=boxes)
+    annotated_frame = model(video, boxes=boxes, device=device)
 
     return annotated_frame
 
 def show_tracking(video_content):
-    print()
 
     # https://docs.ultralytics.com/datasets/detect/coco/
     video = cv2.VideoCapture(video_content)
@@ -116,7 +113,7 @@ def pose3d(video):
         thickness=2,
         return_vis=True,
         rebase_keypoint_height=True,
-        device="cuda")
+        device=device)
 
     result = [result for result in result_generator] #next(result_generator)
 
@@ -139,7 +136,7 @@ def pose2d(video, kpt_threshold):
         thickness=2,
         rebase_keypoint_height=True,
         kpt_thr=kpt_threshold,
-        device="cuda"
+        device=device
         )
 
     result = [result for result in result_generator] #next(result_generator)
@@ -163,7 +160,7 @@ def pose2dhand(video, kpt_threshold):
         thickness=2,
         rebase_keypoint_height=True,
         kpt_thr=kpt_threshold,
-        device="cuda")
+        device=device)
 
     result = [result for result in result_generator] #next(result_generator)
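The main.py change replaces the hard-coded `device="cuda"` arguments with a `device` variable that is set once at import time and then passed to the YOLO tracking call and the MMPose inferencers. A minimal standalone sketch of that selection pattern is below; note the `"cpu"` fallback is an assumption added here for illustration, since the commit itself only assigns `device` inside the CUDA branch.

```python
# Minimal sketch of the device-selection pattern introduced by this commit.
# The "cpu" fallback is an assumption; the commit only sets device when CUDA is available.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# The selected device is then forwarded to the inference calls, e.g.:
#   annotated_frame = track_model(video, boxes=True, device=device)
#   result_generator = human3d(video, ..., device=device)
print("[INFO] Using device:", device)
```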