nikkar committed on
Commit
d6d3990
1 Parent(s): c60875a

cotracker 2.0

Browse files
Files changed (2) hide show
  1. app.py +31 -31
  2. requirements.txt +1 -5
app.py CHANGED
@@ -10,6 +10,7 @@ import gradio as gr
10
 
11
  from cotracker.utils.visualizer import Visualizer
12
 
 
13
  def parse_video(video_file):
14
  vs = cv2.VideoCapture(video_file)
15
 
@@ -26,33 +27,33 @@ def parse_video(video_file):
26
 
27
 
28
  def cotracker_demo(
29
- input_video,
30
- grid_size: int = 10,
31
- grid_query_frame: int = 0,
32
- backward_tracking: bool = False,
33
- tracks_leave_trace: bool = False
34
- ):
35
  load_video = parse_video(input_video)
36
- grid_query_frame = min(len(load_video)-1, grid_query_frame)
37
  load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float()
38
 
39
- model = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
40
 
41
  if torch.cuda.is_available():
42
  model = model.cuda()
43
  load_video = load_video.cuda()
44
- pred_tracks, pred_visibility = model(
45
- load_video,
46
- grid_size=grid_size,
47
- grid_query_frame=grid_query_frame,
48
- backward_tracking=backward_tracking
49
- )
 
50
  linewidth = 2
51
  if grid_size < 10:
52
  linewidth = 4
53
  elif grid_size < 20:
54
  linewidth = 3
55
-
56
  vis = Visualizer(
57
  save_dir=os.path.join(os.path.dirname(__file__), "results"),
58
  grayscale=False,
@@ -60,7 +61,7 @@ def cotracker_demo(
60
  fps=10,
61
  linewidth=linewidth,
62
  show_first_frame=5,
63
- tracks_leave_trace= -1 if tracks_leave_trace else 0,
64
  )
65
  import time
66
 
@@ -70,53 +71,52 @@ def cotracker_demo(
70
  filename = str(current_milli_time())
71
  vis.visualize(
72
  load_video.cpu(),
73
- tracks=pred_tracks.cpu(),
74
  visibility=pred_visibility.cpu(),
75
- filename=filename,
76
  query_frame=grid_query_frame,
77
- )
78
  return os.path.join(
79
  os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4"
80
  )
81
 
 
82
  apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
83
  bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
84
- paragliding_launch = os.path.join(os.path.dirname(__file__), "videos", "paragliding-launch.mp4")
 
 
85
  paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4")
86
 
87
  app = gr.Interface(
88
- title = "🎨 CoTracker: It is Better to Track Together",
89
- description = "<div style='text-align: left;'> \
90
  <p>Welcome to <a href='http://co-tracker.github.io' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
91
  Points are sampled on a regular grid and are tracked jointly. </p> \
92
  <p> To get started, simply upload your <b>.mp4</b> video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
93
  <ul style='display: inline-block; text-align: left;'> \
94
  <li>The total number of grid points is the square of <b>Grid Size</b>.</li> \
95
  <li>To specify the starting frame for tracking, adjust <b>Grid Query Frame</b>. Tracks will be visualized only after the selected frame.</li> \
96
- <li>Use <b>Backward Tracking</b> to track points from the selected frame in both directions.</li> \
97
  <li>Check <b>Visualize Track Traces</b> to visualize traces of all the tracked points. </li> \
98
  </ul> \
99
  <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐</p> \
100
  </div>",
101
-
102
  fn=cotracker_demo,
103
  inputs=[
104
  gr.Video(type="file", label="Input video", interactive=True),
105
  gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Grid Size"),
106
  gr.Slider(minimum=0, maximum=30, step=1, default=0, label="Grid Query Frame"),
107
- gr.Checkbox(label="Backward Tracking"),
108
  gr.Checkbox(label="Visualize Track Traces"),
109
  ],
110
  outputs=gr.Video(label="Video with predicted tracks"),
111
  examples=[
112
- [ apple, 10, 0, False, False ],
113
- [ apple, 20, 30, True, False ],
114
- [ bear, 10, 0, False, False ],
115
- [ paragliding, 10, 0, False, False ],
116
- [ paragliding_launch, 10, 0, False, False ],
117
  ],
118
  cache_examples=True,
119
  allow_flagging=False,
120
-
121
  )
122
  app.queue(max_size=20, concurrency_count=2).launch(debug=True)
 
10
 
11
  from cotracker.utils.visualizer import Visualizer
12
 
13
+
14
  def parse_video(video_file):
15
  vs = cv2.VideoCapture(video_file)
16
 
 
27
 
28
 
29
  def cotracker_demo(
30
+ input_video,
31
+ grid_size: int = 10,
32
+ grid_query_frame: int = 0,
33
+ tracks_leave_trace: bool = False,
34
+ ):
 
35
  load_video = parse_video(input_video)
36
+ grid_query_frame = min(len(load_video) - 1, grid_query_frame)
37
  load_video = torch.from_numpy(load_video).permute(0, 3, 1, 2)[None].float()
38
 
39
+ model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
40
 
41
  if torch.cuda.is_available():
42
  model = model.cuda()
43
  load_video = load_video.cuda()
44
+
45
+ model(video_chunk=load_video, is_first_step=True, grid_size=grid_size)
46
+ for ind in range(0, load_video.shape[1] - model.step, model.step):
47
+ pred_tracks, pred_visibility = model(
48
+ video_chunk=load_video[:, ind : ind + model.step * 2]
49
+ ) # B T N 2, B T N 1
50
+
51
  linewidth = 2
52
  if grid_size < 10:
53
  linewidth = 4
54
  elif grid_size < 20:
55
  linewidth = 3
56
+
57
  vis = Visualizer(
58
  save_dir=os.path.join(os.path.dirname(__file__), "results"),
59
  grayscale=False,
 
61
  fps=10,
62
  linewidth=linewidth,
63
  show_first_frame=5,
64
+ tracks_leave_trace=-1 if tracks_leave_trace else 0,
65
  )
66
  import time
67
 
 
71
  filename = str(current_milli_time())
72
  vis.visualize(
73
  load_video.cpu(),
74
+ tracks=pred_tracks.cpu(),
75
  visibility=pred_visibility.cpu(),
76
+ filename=f"{filename}_pred_track",
77
  query_frame=grid_query_frame,
78
+ )
79
  return os.path.join(
80
  os.path.dirname(__file__), "results", f"{filename}_pred_track.mp4"
81
  )
82
 
83
+
84
  apple = os.path.join(os.path.dirname(__file__), "videos", "apple.mp4")
85
  bear = os.path.join(os.path.dirname(__file__), "videos", "bear.mp4")
86
+ paragliding_launch = os.path.join(
87
+ os.path.dirname(__file__), "videos", "paragliding-launch.mp4"
88
+ )
89
  paragliding = os.path.join(os.path.dirname(__file__), "videos", "paragliding.mp4")
90
 
91
  app = gr.Interface(
92
+ title="🎨 CoTracker: It is Better to Track Together",
93
+ description="<div style='text-align: left;'> \
94
  <p>Welcome to <a href='http://co-tracker.github.io' target='_blank'>CoTracker</a>! This space demonstrates point (pixel) tracking in videos. \
95
  Points are sampled on a regular grid and are tracked jointly. </p> \
96
  <p> To get started, simply upload your <b>.mp4</b> video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length <b>2-7 seconds</b>.</p> \
97
  <ul style='display: inline-block; text-align: left;'> \
98
  <li>The total number of grid points is the square of <b>Grid Size</b>.</li> \
99
  <li>To specify the starting frame for tracking, adjust <b>Grid Query Frame</b>. Tracks will be visualized only after the selected frame.</li> \
 
100
  <li>Check <b>Visualize Track Traces</b> to visualize traces of all the tracked points. </li> \
101
  </ul> \
102
  <p style='text-align: left'>For more details, check out our <a href='https://github.com/facebookresearch/co-tracker' target='_blank'>GitHub Repo</a> ⭐</p> \
103
  </div>",
 
104
  fn=cotracker_demo,
105
  inputs=[
106
  gr.Video(type="file", label="Input video", interactive=True),
107
  gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Grid Size"),
108
  gr.Slider(minimum=0, maximum=30, step=1, default=0, label="Grid Query Frame"),
 
109
  gr.Checkbox(label="Visualize Track Traces"),
110
  ],
111
  outputs=gr.Video(label="Video with predicted tracks"),
112
  examples=[
113
+ [apple, 10, 0, False, False],
114
+ [apple, 20, 30, True, False],
115
+ [bear, 10, 0, False, False],
116
+ [paragliding, 10, 0, False, False],
117
+ [paragliding_launch, 10, 0, False, False],
118
  ],
119
  cache_examples=True,
120
  allow_flagging=False,
 
121
  )
122
  app.queue(max_size=20, concurrency_count=2).launch(debug=True)
requirements.txt CHANGED
@@ -1,11 +1,7 @@
1
- einops
2
- timm
3
- tqdm
4
- opencv-python
5
  matplotlib
6
- moviepy
7
  flow_vis
8
  imutils
9
  numpy
 
10
  gradio
11
  git+https://github.com/facebookresearch/co-tracker.git
 
 
 
 
 
1
  matplotlib
 
2
  flow_vis
3
  imutils
4
  numpy
5
+ imageio
6
  gradio
7
  git+https://github.com/facebookresearch/co-tracker.git