# Copyright (c) Facebook, Inc. and its affiliates.
import unittest

import torch

from detectron2.structures.keypoints import Keypoints
class TestKeypoints(unittest.TestCase):
    """Unit tests for the ``Keypoints`` structure."""

    def test_cat_keypoints(self):
        """Check that ``cat`` stacks ``Keypoints`` along dim 0 in list order.

        Two random instances with leading sizes 2 and 4 are concatenated;
        the first 2 rows of the result must equal the first input and the
        remaining rows must equal the second input.
        """
        # Shapes differ only in the leading dimension (2 vs. 4); presumably
        # the layout is (instances, keypoints, 3) -- confirm against Keypoints.
        keypoints1 = Keypoints(torch.rand(2, 21, 3))
        keypoints2 = Keypoints(torch.rand(4, 21, 3))

        cat_keypoints = keypoints1.cat([keypoints1, keypoints2])

        # torch.equal requires matching shape as well as matching values, so
        # a wrongly shaped result cannot pass via broadcasting.
        self.assertTrue(torch.equal(cat_keypoints.tensor[:2], keypoints1.tensor))
        self.assertTrue(torch.equal(cat_keypoints.tensor[2:], keypoints2.tensor))
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()