Sirus1's picture
Duplicate from TencentARC/VLog
6f6830f
raw
history blame
No virus
610 Bytes
# Copyright (c) Facebook, Inc. and its affiliates.
import unittest
import torch
from detectron2.structures.keypoints import Keypoints
class TestKeypoints(unittest.TestCase):
    """Unit tests for the ``Keypoints`` structure."""

    def test_cat_keypoints(self):
        """Check that ``cat`` keeps each input's entries, in input order."""
        first = Keypoints(torch.rand(2, 21, 3))
        second = Keypoints(torch.rand(4, 21, 3))
        # Number of leading entries contributed by `first`.
        split = 2

        merged = first.cat([first, second])

        # The leading slice must reproduce the first input exactly.
        head_matches = (merged.tensor[:split] == first.tensor).all().item()
        self.assertTrue(head_matches)

        # Everything after the split must reproduce the second input exactly.
        tail_matches = (merged.tensor[split:] == second.tensor).all().item()
        self.assertTrue(tail_matches)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()