sino72 committed
Commit
bbde013
1 Parent(s): 61d77aa

add deepsort

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set.
Files changed (50):
  1. .gitattributes +1 -0
  2. deep_sort/configs/deep_sort.yaml +10 -0
  3. deep_sort/deep_sort/README.md +3 -0
  4. deep_sort/deep_sort/__init__.py +21 -0
  5. deep_sort/deep_sort/__pycache__/__init__.cpython-310.pyc +0 -0
  6. deep_sort/deep_sort/__pycache__/__init__.cpython-38.pyc +0 -0
  7. deep_sort/deep_sort/__pycache__/deep_sort.cpython-310.pyc +0 -0
  8. deep_sort/deep_sort/__pycache__/deep_sort.cpython-38.pyc +0 -0
  9. deep_sort/deep_sort/deep/__init__.py +0 -0
  10. deep_sort/deep_sort/deep/__pycache__/__init__.cpython-310.pyc +0 -0
  11. deep_sort/deep_sort/deep/__pycache__/__init__.cpython-38.pyc +0 -0
  12. deep_sort/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc +0 -0
  13. deep_sort/deep_sort/deep/__pycache__/feature_extractor.cpython-38.pyc +0 -0
  14. deep_sort/deep_sort/deep/__pycache__/model.cpython-310.pyc +0 -0
  15. deep_sort/deep_sort/deep/__pycache__/model.cpython-38.pyc +0 -0
  16. deep_sort/deep_sort/deep/checkpoint/ckpt.t7 +3 -0
  17. deep_sort/deep_sort/deep/evaluate.py +15 -0
  18. deep_sort/deep_sort/deep/feature_extractor.py +65 -0
  19. deep_sort/deep_sort/deep/model.py +105 -0
  20. deep_sort/deep_sort/deep/original_model.py +106 -0
  21. deep_sort/deep_sort/deep/prepare_car.py +129 -0
  22. deep_sort/deep_sort/deep/prepare_person.py +108 -0
  23. deep_sort/deep_sort/deep/test.py +77 -0
  24. deep_sort/deep_sort/deep/train.jpg +0 -0
  25. deep_sort/deep_sort/deep/train.py +192 -0
  26. deep_sort/deep_sort/deep_sort.py +125 -0
  27. deep_sort/deep_sort/sort/__init__.py +0 -0
  28. deep_sort/deep_sort/sort/__pycache__/__init__.cpython-310.pyc +0 -0
  29. deep_sort/deep_sort/sort/__pycache__/__init__.cpython-38.pyc +0 -0
  30. deep_sort/deep_sort/sort/__pycache__/detection.cpython-310.pyc +0 -0
  31. deep_sort/deep_sort/sort/__pycache__/detection.cpython-38.pyc +0 -0
  32. deep_sort/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc +0 -0
  33. deep_sort/deep_sort/sort/__pycache__/iou_matching.cpython-38.pyc +0 -0
  34. deep_sort/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc +0 -0
  35. deep_sort/deep_sort/sort/__pycache__/kalman_filter.cpython-38.pyc +0 -0
  36. deep_sort/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc +0 -0
  37. deep_sort/deep_sort/sort/__pycache__/linear_assignment.cpython-38.pyc +0 -0
  38. deep_sort/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc +0 -0
  39. deep_sort/deep_sort/sort/__pycache__/nn_matching.cpython-38.pyc +0 -0
  40. deep_sort/deep_sort/sort/__pycache__/preprocessing.cpython-310.pyc +0 -0
  41. deep_sort/deep_sort/sort/__pycache__/preprocessing.cpython-38.pyc +0 -0
  42. deep_sort/deep_sort/sort/__pycache__/track.cpython-310.pyc +0 -0
  43. deep_sort/deep_sort/sort/__pycache__/track.cpython-38.pyc +0 -0
  44. deep_sort/deep_sort/sort/__pycache__/tracker.cpython-310.pyc +0 -0
  45. deep_sort/deep_sort/sort/__pycache__/tracker.cpython-38.pyc +0 -0
  46. deep_sort/deep_sort/sort/detection.py +49 -0
  47. deep_sort/deep_sort/sort/iou_matching.py +84 -0
  48. deep_sort/deep_sort/sort/kalman_filter.py +286 -0
  49. deep_sort/deep_sort/sort/linear_assignment.py +240 -0
  50. deep_sort/deep_sort/sort/nn_matching.py +207 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+deep_sort/deep_sort/deep/checkpoint/ckpt.t7 filter=lfs diff=lfs merge=lfs -text
deep_sort/configs/deep_sort.yaml ADDED
@@ -0,0 +1,10 @@
+DEEPSORT:
+  REID_CKPT: "deep_sort/deep_sort/deep/checkpoint/ckpt.t7"
+  MAX_DIST: 0.2
+  MIN_CONFIDENCE: 0.3
+  NMS_MAX_OVERLAP: 0.5
+  MAX_IOU_DISTANCE: 0.7
+  MAX_AGE: 70
+  N_INIT: 3
+  NN_BUDGET: 100
+
deep_sort/deep_sort/README.md ADDED
@@ -0,0 +1,3 @@
+# Deep Sort
+
+This is an implementation of Deep SORT with PyTorch.
deep_sort/deep_sort/__init__.py ADDED
@@ -0,0 +1,21 @@
+from .deep_sort import DeepSort
+
+
+__all__ = ['DeepSort', 'build_tracker']
+
+
+def build_tracker(cfg, use_cuda):
+    return DeepSort(cfg.DEEPSORT.REID_CKPT,
+                    max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
+                    nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
+                    max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET, use_cuda=use_cuda)
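A usage sketch (not part of this commit): build_tracker expects an attribute-style config whose keys mirror deep_sort.yaml above. Loading the YAML with easydict is one assumption that satisfies that; the commit does not ship a config loader.

import sys
sys.path.append("deep_sort")  # assumption: run from the repo root so the inner package resolves

import yaml
from easydict import EasyDict  # assumption: any attribute-style dict wrapper works
from deep_sort import build_tracker

with open("deep_sort/configs/deep_sort.yaml") as f:
    cfg = EasyDict(yaml.safe_load(f))  # exposes cfg.DEEPSORT.REID_CKPT, cfg.DEEPSORT.MAX_DIST, ...

tracker = build_tracker(cfg, use_cuda=True)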
deep_sort/deep_sort/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (599 Bytes)
deep_sort/deep_sort/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (607 Bytes)
deep_sort/deep_sort/__pycache__/deep_sort.cpython-310.pyc ADDED
Binary file (4.13 kB)
deep_sort/deep_sort/__pycache__/deep_sort.cpython-38.pyc ADDED
Binary file (4.15 kB)
deep_sort/deep_sort/deep/__init__.py ADDED
File without changes
deep_sort/deep_sort/deep/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes)
deep_sort/deep_sort/deep/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (146 Bytes)
deep_sort/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc ADDED
Binary file (2.56 kB)
deep_sort/deep_sort/deep/__pycache__/feature_extractor.cpython-38.pyc ADDED
Binary file (2.52 kB)
deep_sort/deep_sort/deep/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.8 kB)
deep_sort/deep_sort/deep/__pycache__/model.cpython-38.pyc ADDED
Binary file (2.78 kB)
deep_sort/deep_sort/deep/checkpoint/ckpt.t7 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22628596f112dc7eb1fe7adfbfaf95bbc6ce8eb024205beafdc705232a646c29
+size 46061055
deep_sort/deep_sort/deep/evaluate.py ADDED
@@ -0,0 +1,15 @@
+import torch
+
+features = torch.load("features.pth")
+qf = features["qf"]
+ql = features["ql"]
+gf = features["gf"]
+gl = features["gl"]
+
+scores = qf.mm(gf.t())
+res = scores.topk(5, dim=1)[1][:, 0]
+top1correct = gl[res].eq(ql).sum().item()
+
+print("Acc top1:{:.3f}".format(top1correct / ql.size(0)))
+
deep_sort/deep_sort/deep/feature_extractor.py ADDED
@@ -0,0 +1,65 @@
+import torch
+import torchvision.transforms as transforms
+import numpy as np
+import cv2
+import logging
+
+from .model import Net
+
+'''
+Feature extractor:
+Extracts the features inside each bounding box and produces a fixed-dimensional
+embedding that represents the box, to be used when computing similarity.
+
+The model is trained in the conventional ReID fashion; at inference time the
+Extractor takes a list of image crops and returns their corresponding features.
+'''
+
+class Extractor(object):
+    def __init__(self, model_path, use_cuda=True):
+        self.net = Net(reid=True)
+        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
+        self.net.load_state_dict(state_dict)
+        logger = logging.getLogger("root.tracker")
+        logger.info("Loading weights from {}... Done!".format(model_path))
+        self.net.to(self.device)
+        self.size = (64, 128)
+        self.norm = transforms.Compose([
+            # RGB pixel values are in [0, 255]; ToTensor first divides by 255 to get [0, 1],
+            # then Normalize applies (x - mean) / std, mapping the data to roughly [-1, 1].
+            transforms.ToTensor(),
+            # mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] are the ImageNet training-set statistics
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ])
+
+    def _preprocess(self, im_crops):
+        """
+        TODO:
+        1. to float with scale from 0 to 1
+        2. resize to (64, 128) as the Market-1501 dataset does
+        3. concatenate to a numpy array
+        4. to torch Tensor
+        5. normalize
+        """
+        def _resize(im, size):
+            return cv2.resize(im.astype(np.float32) / 255., size)
+
+        im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
+        return im_batch
+
+    # __call__() makes the instance callable: the extractor can be invoked like
+    # a plain function, i.e. features = extractor(im_crops).
+    def __call__(self, im_crops):
+        im_batch = self._preprocess(im_crops)
+        with torch.no_grad():
+            im_batch = im_batch.to(self.device)
+            features = self.net(im_batch)
+        return features.cpu().numpy()
+
+
+if __name__ == '__main__':
+    img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
+    extr = Extractor("checkpoint/ckpt.t7")
+    feature = extr([img])  # the extractor expects a list of crops, not a single image
+    print(feature.shape)
deep_sort/deep_sort/deep/model.py ADDED
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BasicBlock(nn.Module):
+    def __init__(self, c_in, c_out, is_downsample=False):
+        super(BasicBlock, self).__init__()
+        self.is_downsample = is_downsample
+        if is_downsample:
+            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
+        else:
+            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(c_out)
+        self.relu = nn.ReLU(True)
+        self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(c_out)
+        if is_downsample:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
+                nn.BatchNorm2d(c_out)
+            )
+        elif c_in != c_out:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
+                nn.BatchNorm2d(c_out)
+            )
+            self.is_downsample = True
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.bn1(y)
+        y = self.relu(y)
+        y = self.conv2(y)
+        y = self.bn2(y)
+        if self.is_downsample:
+            x = self.downsample(x)
+        return F.relu(x.add(y), True)
+
+def make_layers(c_in, c_out, repeat_times, is_downsample=False):
+    blocks = []
+    for i in range(repeat_times):
+        if i == 0:
+            blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
+        else:
+            blocks += [BasicBlock(c_out, c_out), ]
+    return nn.Sequential(*blocks)
+
+class Net(nn.Module):
+    def __init__(self, num_classes=751, reid=False):
+        super(Net, self).__init__()
+        # 3 128 64
+        self.conv = nn.Sequential(
+            nn.Conv2d(3, 64, 3, stride=1, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            # nn.Conv2d(32,32,3,stride=1,padding=1),
+            # nn.BatchNorm2d(32),
+            # nn.ReLU(inplace=True),
+            nn.MaxPool2d(3, 2, padding=1),
+        )
+        # 64 64 32
+        self.layer1 = make_layers(64, 64, 2, False)
+        # 64 64 32
+        self.layer2 = make_layers(64, 128, 2, True)
+        # 128 32 16
+        self.layer3 = make_layers(128, 256, 2, True)
+        # 256 16 8
+        self.layer4 = make_layers(256, 512, 2, True)
+        # 512 8 4
+        self.avgpool = nn.AvgPool2d((8, 4), 1)
+        # 512 1 1
+        self.reid = reid
+
+        self.classifier = nn.Sequential(
+            nn.Linear(512, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(inplace=True),
+            nn.Dropout(),
+            nn.Linear(256, num_classes),
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        # B x 512
+        if self.reid:
+            x = x.div(x.norm(p=2, dim=1, keepdim=True))
+            return x
+        # classifier
+        x = self.classifier(x)
+        return x
+
+
+if __name__ == '__main__':
+    net = Net()
+    x = torch.randn(4, 3, 128, 64)
+    y = net(x)
+    import ipdb; ipdb.set_trace()
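A quick sanity-check sketch (not part of the commit): with reid=True the forward pass L2-normalizes the flattened 512-d feature, so every row of the output should have unit norm.

import torch
from model import Net  # assumes running next to model.py, as test.py does

net = Net(reid=True).eval()
with torch.no_grad():
    emb = net(torch.randn(4, 3, 128, 64))  # input is (B, 3, H=128, W=64)
print(emb.shape)        # torch.Size([4, 512])
print(emb.norm(dim=1))  # ~1.0 for every row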
deep_sort/deep_sort/deep/original_model.py ADDED
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BasicBlock(nn.Module):
+    def __init__(self, c_in, c_out, is_downsample=False):
+        super(BasicBlock, self).__init__()
+        self.is_downsample = is_downsample
+        if is_downsample:
+            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
+        else:
+            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(c_out)
+        self.relu = nn.ReLU(True)
+        self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(c_out)
+        if is_downsample:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
+                nn.BatchNorm2d(c_out)
+            )
+        elif c_in != c_out:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
+                nn.BatchNorm2d(c_out)
+            )
+            self.is_downsample = True
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.bn1(y)
+        y = self.relu(y)
+        y = self.conv2(y)
+        y = self.bn2(y)
+        if self.is_downsample:
+            x = self.downsample(x)
+        return F.relu(x.add(y), True)
+
+def make_layers(c_in, c_out, repeat_times, is_downsample=False):
+    blocks = []
+    for i in range(repeat_times):
+        if i == 0:
+            blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
+        else:
+            blocks += [BasicBlock(c_out, c_out), ]
+    return nn.Sequential(*blocks)
+
+class Net(nn.Module):
+    def __init__(self, num_classes=625, reid=False):
+        super(Net, self).__init__()
+        # 3 128 64
+        self.conv = nn.Sequential(
+            nn.Conv2d(3, 32, 3, stride=1, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ELU(inplace=True),
+            nn.Conv2d(32, 32, 3, stride=1, padding=1),
+            nn.BatchNorm2d(32),
+            nn.ELU(inplace=True),
+            nn.MaxPool2d(3, 2, padding=1),
+        )
+        # 32 64 32
+        self.layer1 = make_layers(32, 32, 2, False)
+        # 32 64 32
+        self.layer2 = make_layers(32, 64, 2, True)
+        # 64 32 16
+        self.layer3 = make_layers(64, 128, 2, True)
+        # 128 16 8
+        self.dense = nn.Sequential(
+            nn.Dropout(p=0.6),
+            nn.Linear(128*16*8, 128),
+            nn.BatchNorm1d(128),
+            nn.ELU(inplace=True)
+        )
+        # 128
+        self.reid = reid
+        self.batch_norm = nn.BatchNorm1d(128)
+        self.classifier = nn.Sequential(
+            nn.Linear(128, num_classes),
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = x.view(x.size(0), -1)
+        if self.reid:
+            x = self.dense[0](x)
+            x = self.dense[1](x)
+            x = x.div(x.norm(p=2, dim=1, keepdim=True))
+            return x
+        x = self.dense(x)
+        # B x 128
+        # classifier
+        x = self.classifier(x)
+        return x
+
+
+if __name__ == '__main__':
+    net = Net(reid=True)
+    x = torch.randn(4, 3, 128, 64)
+    y = net(x)
+    import ipdb; ipdb.set_trace()
deep_sort/deep_sort/deep/prepare_car.py ADDED
@@ -0,0 +1,129 @@
+# -*- coding:utf8 -*-
+
+import os
+from PIL import Image
+from shutil import copyfile, copytree, rmtree, move
+
+PATH_DATASET = './car-dataset'           # folder to process
+PATH_NEW_DATASET = './car-reid-dataset'  # output folder after processing
+PATH_ALL_IMAGES = PATH_NEW_DATASET + '/all_images'
+PATH_TRAIN = PATH_NEW_DATASET + '/train'
+PATH_TEST = PATH_NEW_DATASET + '/test'
+
+# helper for creating a directory
+def mymkdir(path):
+    path = path.strip()       # strip leading/trailing whitespace
+    path = path.rstrip("\\")  # strip a trailing backslash
+    isExists = os.path.exists(path)  # check whether the path exists
+    if not isExists:
+        os.makedirs(path)     # create the directory if it does not exist
+        print(path + ' created successfully')
+        return True
+    else:
+        # if the directory already exists, do not create it and report that
+        print(path + ' already exists')
+        return False
+
+class BatchRename():
+    '''
+    Batch-rename the image files in a folder.
+    '''
+
+    def __init__(self):
+        self.path = PATH_DATASET  # the folder whose files will be renamed
+
+    # resize the images
+    def resize(self):
+        for aroot, dirs, files in os.walk(self.path):
+            # aroot walks over self.path and every subdirectory below it; dirs lists the folders in each
+            filelist = files  # note: only the file list of the current directory
+            # print('list', list)
+
+            # filelist = os.listdir(self.path)  # alternative: list a single directory
+            total_num = len(filelist)  # number of files
+
+            for item in filelist:
+                if item.endswith('.jpg'):  # source images are jpg; if yours are png or another format, adjust the conversion below accordingly
+                    src = os.path.join(os.path.abspath(aroot), item)
+
+                    # resize to 128 (width) x 256 (height)
+                    im = Image.open(src)
+                    out = im.resize((128, 256), Image.ANTIALIAS)  # resize image with high quality
+                    out.save(src)  # save in place
+
+    def rename(self):
+
+        for aroot, dirs, files in os.walk(self.path):
+            # aroot walks over self.path and every subdirectory below it; dirs lists the folders in each
+            filelist = files  # note: only the file list of the current directory
+            # print('list', list)
+
+            # filelist = os.listdir(self.path)  # alternative: list a single directory
+            total_num = len(filelist)  # number of files
+
+            i = 1  # file numbering starts at 1
+            for item in filelist:
+                if item.endswith('.jpg'):  # source images are jpg; adjust the format as needed
+                    src = os.path.join(os.path.abspath(aroot), item)
+
+                    # create an image directory based on the file name
+                    dirname = str(item.split('_')[0])
+                    # one directory per vehicle
+                    #new_dir = os.path.join(self.path, '..', 'bbox_all', dirname)
+                    new_dir = os.path.join(PATH_ALL_IMAGES, dirname)
+                    if not os.path.isdir(new_dir):
+                        mymkdir(new_dir)
+
+                    # number of images already in new_dir
+                    num_pic = len(os.listdir(new_dir))
+
+                    dst = os.path.join(os.path.abspath(new_dir),
+                                       dirname + 'C1T0001F' + str(num_pic + 1) + '.jpg')
+                    # output stays jpg (could be png); 'C1T0001F' mimics the mars.py filename scheme: camera ID and tracklet index
+                    # dst = os.path.join(os.path.abspath(self.path), '0000' + format(str(i), '0>3s') + '.jpg')  # alternative: names like 0000000.jpg; define whatever scheme you prefer
+                    try:
+                        copyfile(src, dst)  # os.rename(src, dst)
+                        print('converting %s to %s ...' % (src, dst))
+                        i = i + 1
+                    except:
+                        continue
+            print('total %d to rename & converted %d jpgs' % (total_num, i))
+
+    def split(self):
+        #---------------------------------------
+        #train_test
+        images_path = PATH_ALL_IMAGES
+        train_save_path = PATH_TRAIN
+        test_save_path = PATH_TEST
+        if not os.path.isdir(train_save_path):
+            os.mkdir(train_save_path)
+            os.mkdir(test_save_path)
+
+        for _, dirs, _ in os.walk(images_path, topdown=True):
+            for i, dir in enumerate(dirs):
+                for root, _, files in os.walk(images_path + '/' + dir, topdown=True):
+                    for j, file in enumerate(files):
+                        if(j == 0):  # test set: the first image of each vehicle
+                            print("index: %s folder: %s image: %s moved to the test set" % (i + 1, root, file))
+                            src_path = root + '/' + file
+                            dst_dir = test_save_path + '/' + dir
+                            if not os.path.isdir(dst_dir):
+                                os.mkdir(dst_dir)
+                            dst_path = dst_dir + '/' + file
+                            move(src_path, dst_path)
+                        else:
+                            src_path = root + '/' + file
+                            dst_dir = train_save_path + '/' + dir
+                            if not os.path.isdir(dst_dir):
+                                os.mkdir(dst_dir)
+                            dst_path = dst_dir + '/' + file
+                            move(src_path, dst_path)
+        rmtree(PATH_ALL_IMAGES)
+
+if __name__ == '__main__':
+    demo = BatchRename()
+    demo.resize()
+    demo.rename()
+    demo.split()
+
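A note on the expected input (inferred from the split('_') logic above; no sample data ships with the commit): crops under car-dataset/ should be named <vehicleID>_<anything>.jpg, e.g.:

car-dataset/
    0001_000001.jpg
    0001_000002.jpg
    0002_000001.jpg

All images sharing a prefix are grouped into one identity folder; the first image of each identity is moved to test/ and the rest to train/.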
deep_sort/deep_sort/deep/prepare_person.py ADDED
@@ -0,0 +1,108 @@
+import os
+from shutil import copyfile
+
+# You only need to change this line to your dataset download path
+download_path = './Market-1501-v15.09.15'
+
+if not os.path.isdir(download_path):
+    print('please change the download_path')
+
+save_path = download_path + '/pytorch'
+if not os.path.isdir(save_path):
+    os.mkdir(save_path)
+#-----------------------------------------
+#query
+query_path = download_path + '/query'
+query_save_path = download_path + '/pytorch/query'
+if not os.path.isdir(query_save_path):
+    os.mkdir(query_save_path)
+
+for root, dirs, files in os.walk(query_path, topdown=True):
+    for name in files:
+        if not name[-3:]=='jpg':
+            continue
+        ID = name.split('_')
+        src_path = query_path + '/' + name
+        dst_path = query_save_path + '/' + ID[0]
+        if not os.path.isdir(dst_path):
+            os.mkdir(dst_path)
+        copyfile(src_path, dst_path + '/' + name)
+
+#-----------------------------------------
+#multi-query
+query_path = download_path + '/gt_bbox'
+# for dukemtmc-reid, we do not need multi-query
+if os.path.isdir(query_path):
+    query_save_path = download_path + '/pytorch/multi-query'
+    if not os.path.isdir(query_save_path):
+        os.mkdir(query_save_path)
+
+    for root, dirs, files in os.walk(query_path, topdown=True):
+        for name in files:
+            if not name[-3:]=='jpg':
+                continue
+            ID = name.split('_')
+            src_path = query_path + '/' + name
+            dst_path = query_save_path + '/' + ID[0]
+            if not os.path.isdir(dst_path):
+                os.mkdir(dst_path)
+            copyfile(src_path, dst_path + '/' + name)
+
+#-----------------------------------------
+#gallery
+gallery_path = download_path + '/bounding_box_test'
+gallery_save_path = download_path + '/pytorch/gallery'
+if not os.path.isdir(gallery_save_path):
+    os.mkdir(gallery_save_path)
+
+for root, dirs, files in os.walk(gallery_path, topdown=True):
+    for name in files:
+        if not name[-3:]=='jpg':
+            continue
+        ID = name.split('_')
+        src_path = gallery_path + '/' + name
+        dst_path = gallery_save_path + '/' + ID[0]
+        if not os.path.isdir(dst_path):
+            os.mkdir(dst_path)
+        copyfile(src_path, dst_path + '/' + name)
+
+#---------------------------------------
+#train_all
+train_path = download_path + '/bounding_box_train'
+train_save_path = download_path + '/pytorch/train_all'
+if not os.path.isdir(train_save_path):
+    os.mkdir(train_save_path)
+
+for root, dirs, files in os.walk(train_path, topdown=True):
+    for name in files:
+        if not name[-3:]=='jpg':
+            continue
+        ID = name.split('_')
+        src_path = train_path + '/' + name
+        dst_path = train_save_path + '/' + ID[0]
+        if not os.path.isdir(dst_path):
+            os.mkdir(dst_path)
+        copyfile(src_path, dst_path + '/' + name)
+
+
+#---------------------------------------
+#train_val
+train_path = download_path + '/bounding_box_train'
+train_save_path = download_path + '/pytorch/train'
+val_save_path = download_path + '/pytorch/test'
+if not os.path.isdir(train_save_path):
+    os.mkdir(train_save_path)
+    os.mkdir(val_save_path)
+
+for root, dirs, files in os.walk(train_path, topdown=True):
+    for name in files:
+        if not name[-3:]=='jpg':
+            continue
+        ID = name.split('_')
+        src_path = train_path + '/' + name
+        dst_path = train_save_path + '/' + ID[0]
+        if not os.path.isdir(dst_path):
+            os.mkdir(dst_path)
+            dst_path = val_save_path + '/' + ID[0]  #first image is used as val image
+            os.mkdir(dst_path)
+        copyfile(src_path, dst_path + '/' + name)
deep_sort/deep_sort/deep/test.py ADDED
@@ -0,0 +1,77 @@
+import torch
+import torch.backends.cudnn as cudnn
+import torchvision
+
+import argparse
+import os
+
+from model import Net
+
+parser = argparse.ArgumentParser(description="Train on market1501")
+parser.add_argument("--data-dir", default='data', type=str)
+parser.add_argument("--no-cuda", action="store_true")
+parser.add_argument("--gpu-id", default=0, type=int)
+args = parser.parse_args()
+
+# device
+device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
+if torch.cuda.is_available() and not args.no_cuda:
+    cudnn.benchmark = True
+
+# data loader
+root = args.data_dir
+query_dir = os.path.join(root, "query")
+gallery_dir = os.path.join(root, "gallery")
+transform = torchvision.transforms.Compose([
+    torchvision.transforms.Resize((128, 64)),
+    torchvision.transforms.ToTensor(),
+    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+queryloader = torch.utils.data.DataLoader(
+    torchvision.datasets.ImageFolder(query_dir, transform=transform),
+    batch_size=64, shuffle=False
+)
+galleryloader = torch.utils.data.DataLoader(
+    torchvision.datasets.ImageFolder(gallery_dir, transform=transform),
+    batch_size=64, shuffle=False
+)
+
+# net definition
+net = Net(reid=True)
+assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
+print('Loading from checkpoint/ckpt.t7')
+checkpoint = torch.load("./checkpoint/ckpt.t7")
+net_dict = checkpoint['net_dict']
+net.load_state_dict(net_dict, strict=False)
+net.eval()
+net.to(device)
+
+# compute features
+query_features = torch.tensor([]).float()
+query_labels = torch.tensor([]).long()
+gallery_features = torch.tensor([]).float()
+gallery_labels = torch.tensor([]).long()
+
+with torch.no_grad():
+    for idx, (inputs, labels) in enumerate(queryloader):
+        inputs = inputs.to(device)
+        features = net(inputs).cpu()
+        query_features = torch.cat((query_features, features), dim=0)
+        query_labels = torch.cat((query_labels, labels))
+
+    for idx, (inputs, labels) in enumerate(galleryloader):
+        inputs = inputs.to(device)
+        features = net(inputs).cpu()
+        gallery_features = torch.cat((gallery_features, features), dim=0)
+        gallery_labels = torch.cat((gallery_labels, labels))
+
+# the Market-1501 gallery has two extra classes that sort first (junk '-1' and distractors '0000'), so shift labels to align with the query
+gallery_labels -= 2
+
+# save features
+features = {
+    "qf": query_features,
+    "ql": query_labels,
+    "gf": gallery_features,
+    "gl": gallery_labels
+}
+torch.save(features, "features.pth")
deep_sort/deep_sort/deep/train.jpg ADDED
deep_sort/deep_sort/deep/train.py ADDED
@@ -0,0 +1,192 @@
+import argparse
+import os
+import time
+
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+import torch.backends.cudnn as cudnn
+import torchvision
+
+from model import Net
+
+parser = argparse.ArgumentParser(description="Train on market1501")
+parser.add_argument("--data-dir", default='data', type=str)
+parser.add_argument("--no-cuda", action="store_true")
+parser.add_argument("--gpu-id", default=0, type=int)
+parser.add_argument("--lr", default=0.1, type=float)
+parser.add_argument("--interval", '-i', default=20, type=int)
+parser.add_argument('--resume', '-r', action='store_true')
+args = parser.parse_args()
+
+# device
+device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
+if torch.cuda.is_available() and not args.no_cuda:
+    cudnn.benchmark = True
+
+# data loading
+root = args.data_dir
+train_dir = os.path.join(root, "train")
+test_dir = os.path.join(root, "test")
+
+transform_train = torchvision.transforms.Compose([
+    torchvision.transforms.RandomCrop((128, 64), padding=4),
+    torchvision.transforms.RandomHorizontalFlip(),
+    torchvision.transforms.ToTensor(),
+    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+transform_test = torchvision.transforms.Compose([
+    torchvision.transforms.Resize((128, 64)),
+    torchvision.transforms.ToTensor(),
+    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+trainloader = torch.utils.data.DataLoader(
+    torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
+    batch_size=64, shuffle=True
+)
+testloader = torch.utils.data.DataLoader(
+    torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
+    batch_size=64, shuffle=True
+)
+num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))
+print("num_classes = %s" % num_classes)
+
+# net definition
+start_epoch = 0
+net = Net(num_classes=num_classes)
+if args.resume:
+    assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
+    print('Loading from checkpoint/ckpt.t7')
+    checkpoint = torch.load("./checkpoint/ckpt.t7")
+    # import ipdb; ipdb.set_trace()
+    net_dict = checkpoint['net_dict']
+    net.load_state_dict(net_dict)
+    best_acc = checkpoint['acc']
+    start_epoch = checkpoint['epoch']
+net.to(device)
+
+# loss and optimizer
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
+best_acc = 0.
+
+# train function for each epoch
+def train(epoch):
+    print("\nEpoch : %d" % (epoch+1))
+    net.train()
+    training_loss = 0.
+    train_loss = 0.
+    correct = 0
+    total = 0
+    interval = args.interval
+    start = time.time()
+    for idx, (inputs, labels) in enumerate(trainloader):
+        # forward
+        inputs, labels = inputs.to(device), labels.to(device)
+        outputs = net(inputs)
+        loss = criterion(outputs, labels)
+
+        # backward
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # accumulating
+        training_loss += loss.item()
+        train_loss += loss.item()
+        correct += outputs.max(dim=1)[1].eq(labels).sum().item()
+        total += labels.size(0)
+
+        # print
+        if (idx+1) % interval == 0:
+            end = time.time()
+            print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
+                100.*(idx+1)/len(trainloader), end-start, training_loss/interval, correct, total, 100.*correct/total
+            ))
+            training_loss = 0.
+            start = time.time()
+
+    return train_loss/len(trainloader), 1. - correct/total
+
+def test(epoch):
+    global best_acc
+    net.eval()
+    test_loss = 0.
+    correct = 0
+    total = 0
+    start = time.time()
+    with torch.no_grad():
+        for idx, (inputs, labels) in enumerate(testloader):
+            inputs, labels = inputs.to(device), labels.to(device)
+            outputs = net(inputs)
+            loss = criterion(outputs, labels)
+
+            test_loss += loss.item()
+            correct += outputs.max(dim=1)[1].eq(labels).sum().item()
+            total += labels.size(0)
+
+        print("Testing ...")
+        end = time.time()
+        print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
+            100.*(idx+1)/len(testloader), end-start, test_loss/len(testloader), correct, total, 100.*correct/total
+        ))
+
+    # saving checkpoint
+    acc = 100.*correct/total
+    if acc > best_acc:
+        best_acc = acc
+        print("Saving parameters to checkpoint/ckpt.t7")
+        checkpoint = {
+            'net_dict': net.state_dict(),
+            'acc': acc,
+            'epoch': epoch,
+        }
+        if not os.path.isdir('checkpoint'):
+            os.mkdir('checkpoint')
+        torch.save(checkpoint, './checkpoint/ckpt.t7')
+
+    return test_loss/len(testloader), 1. - correct/total
+
+# plot figure
+x_epoch = []
+record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
+fig = plt.figure()
+ax0 = fig.add_subplot(121, title="loss")
+ax1 = fig.add_subplot(122, title="top1err")
+def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
+    global record
+    record['train_loss'].append(train_loss)
+    record['train_err'].append(train_err)
+    record['test_loss'].append(test_loss)
+    record['test_err'].append(test_err)
+
+    x_epoch.append(epoch)
+    ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
+    ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
+    ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
+    ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
+    if epoch == 0:
+        ax0.legend()
+        ax1.legend()
+    fig.savefig("train.jpg")
+
+# lr decay
+def lr_decay():
+    global optimizer
+    for params in optimizer.param_groups:
+        params['lr'] *= 0.1
+        lr = params['lr']
+    print("Learning rate adjusted to {}".format(lr))
+
+def main():
+    total_epoches = 40
+    for epoch in range(start_epoch, start_epoch+total_epoches):
+        train_loss, train_err = train(epoch)
+        test_loss, test_err = test(epoch)
+        draw_curve(epoch, train_loss, train_err, test_loss, test_err)
+        if (epoch+1) % (total_epoches//2) == 0:
+            lr_decay()
+
+
+if __name__ == '__main__':
+    main()
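A note on the schedule, derived from the code above: with the default total_epoches = 40, lr_decay() fires after the 20th and 40th epochs, shrinking the SGD learning rate tenfold each time (0.1 → 0.01 → 0.001 with the default --lr).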
deep_sort/deep_sort/deep_sort.py ADDED
@@ -0,0 +1,125 @@
+import numpy as np
+import torch
+
+from .deep.feature_extractor import Extractor
+from .sort.nn_matching import NearestNeighborDistanceMetric
+from .sort.preprocessing import non_max_suppression
+from .sort.detection import Detection
+from .sort.tracker import Tracker
+
+
+__all__ = ['DeepSort']  # __all__ acts as the whitelist of names this module exports
+
+
+class DeepSort(object):
+    def __init__(self, model_path, max_dist=0.2, min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7, max_age=70, n_init=3, nn_budget=100, use_cuda=True):
+        self.min_confidence = min_confidence  # confidence threshold for detections
+        self.nms_max_overlap = nms_max_overlap  # NMS threshold; 1.0 means no suppression
+
+        self.extractor = Extractor(model_path, use_cuda=use_cuda)  # extracts the features of a batch of image crops
+
+        max_cosine_distance = max_dist  # maximum cosine distance for cascade matching; larger distances are ignored
+        nn_budget = 100  # maximum number of appearance descriptors kept per identity gallery; the oldest are dropped first
+        # NearestNeighborDistanceMetric: for each target, returns the closest
+        # distance (Euclidean or cosine) to any sample observed so far.
+        # A Tracker is constructed from this distance metric.
+        # The first argument is either 'cosine' or 'euclidean'.
+        metric = NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
+        self.tracker = Tracker(metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init)
+
+    def update(self, bbox_xywh, confidences, ori_img):
+        self.height, self.width = ori_img.shape[:2]
+        # generate detections
+        # crop each bbox out of the original image and compute its feature
+        features = self._get_features(bbox_xywh, ori_img)
+        bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
+        # drop detections below min_confidence and build a list of Detection objects
+        detections = [Detection(bbox_tlwh[i], conf, features[i]) for i, conf in enumerate(confidences) if conf > self.min_confidence]
+
+        # run non-maximum suppression
+        boxes = np.array([d.tlwh for d in detections])
+        scores = np.array([d.confidence for d in detections])
+        indices = non_max_suppression(boxes, self.nms_max_overlap, scores)
+        detections = [detections[i] for i in indices]
+
+        # update tracker
+        self.tracker.predict()  # propagate the track state distributions one step forward
+        self.tracker.update(detections)  # run the measurement update and track management
+
+        # output bbox identities
+        outputs = []
+        for track in self.tracker.tracks:
+            if not track.is_confirmed() or track.time_since_update > 1:
+                continue
+            box = track.to_tlwh()
+            x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
+            track_id = track.track_id
+            outputs.append(np.array([x1, y1, x2, y2, track_id], dtype=np.int16))
+        if len(outputs) > 0:
+            outputs = np.stack(outputs, axis=0)
+        return outputs
+
+
+    """
+    TODO:
+    Convert bbox from xc_yc_w_h to xtl_ytl_w_h
+    Thanks JieChen91@github.com for reporting this bug!
+    """
+    # convert a bbox from [xc, yc, w, h] to [top, left, w, h]
+    @staticmethod
+    def _xywh_to_tlwh(bbox_xywh):
+        if isinstance(bbox_xywh, np.ndarray):
+            bbox_tlwh = bbox_xywh.copy()
+        elif isinstance(bbox_xywh, torch.Tensor):
+            bbox_tlwh = bbox_xywh.clone()
+        bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2.
+        bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2.
+        return bbox_tlwh
+
+    # convert a bbox from [xc, yc, w, h] to [x1, y1, x2, y2]
+    # some datasets, e.g. pascal_voc, annotate boxes as [x, y, w, h]
+    """Convert [x y w h] box format to [x1 y1 x2 y2] format."""
+    def _xywh_to_xyxy(self, bbox_xywh):
+        x, y, w, h = bbox_xywh
+        x1 = max(int(x - w / 2), 0)
+        x2 = min(int(x + w / 2), self.width - 1)
+        y1 = max(int(y - h / 2), 0)
+        y2 = min(int(y + h / 2), self.height - 1)
+        return x1, y1, x2, y2
+
+    def _tlwh_to_xyxy(self, bbox_tlwh):
+        """
+        TODO:
+        Convert bbox from xtl_ytl_w_h to x1_y1_x2_y2
+        Thanks JieChen91@github.com for reporting this bug!
+        """
+        x, y, w, h = bbox_tlwh
+        x1 = max(int(x), 0)
+        x2 = min(int(x + w), self.width - 1)
+        y1 = max(int(y), 0)
+        y2 = min(int(y + h), self.height - 1)
+        return x1, y1, x2, y2
+
+    def _xyxy_to_tlwh(self, bbox_xyxy):
+        x1, y1, x2, y2 = bbox_xyxy
+
+        t = x1
+        l = y1
+        w = int(x2 - x1)
+        h = int(y2 - y1)
+        return t, l, w, h
+
+    # extract features for the cropped regions
+    def _get_features(self, bbox_xywh, ori_img):
+        im_crops = []
+        for box in bbox_xywh:
+            x1, y1, x2, y2 = self._xywh_to_xyxy(box)
+            im = ori_img[y1:y2, x1:x2]  # the cropped region
+            im_crops.append(im)
+        if im_crops:
+            features = self.extractor(im_crops)  # extract features for the crops
+        else:
+            features = np.array([])
+        return features
+
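A minimal usage sketch (not part of the commit; the frame path and the single hand-made detection are hypothetical). update() takes center-format boxes, confidences, and the original BGR frame, and returns rows of [x1, y1, x2, y2, track_id] once a track has survived n_init consecutive frames:

import sys
sys.path.append("deep_sort")  # assumption: run from the repo root

import cv2
import numpy as np
from deep_sort import DeepSort

tracker = DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7", use_cuda=False)
frame = cv2.imread("frame.jpg")                  # hypothetical input frame
bbox_xywh = np.array([[320., 240., 80., 160.]])  # (center x, center y, w, h)
confidences = np.array([0.9])
outputs = tracker.update(bbox_xywh, confidences, frame)  # empty until a track is confirmed
for x1, y1, x2, y2, track_id in outputs:
    print(track_id, (x1, y1, x2, y2))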
deep_sort/deep_sort/sort/__init__.py ADDED
File without changes
deep_sort/deep_sort/sort/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes)
deep_sort/deep_sort/sort/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (146 Bytes)
deep_sort/deep_sort/sort/__pycache__/detection.cpython-310.pyc ADDED
Binary file (1.89 kB)
deep_sort/deep_sort/sort/__pycache__/detection.cpython-38.pyc ADDED
Binary file (1.88 kB)
deep_sort/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc ADDED
Binary file (2.93 kB)
deep_sort/deep_sort/sort/__pycache__/iou_matching.cpython-38.pyc ADDED
Binary file (2.92 kB)
deep_sort/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc ADDED
Binary file (7.93 kB)
deep_sort/deep_sort/sort/__pycache__/kalman_filter.cpython-38.pyc ADDED
Binary file (7.93 kB)
deep_sort/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc ADDED
Binary file (8.17 kB)
deep_sort/deep_sort/sort/__pycache__/linear_assignment.cpython-38.pyc ADDED
Binary file (8.17 kB)
deep_sort/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc ADDED
Binary file (7.43 kB)
deep_sort/deep_sort/sort/__pycache__/nn_matching.cpython-38.pyc ADDED
Binary file (7.44 kB)
deep_sort/deep_sort/sort/__pycache__/preprocessing.cpython-310.pyc ADDED
Binary file (1.9 kB)
deep_sort/deep_sort/sort/__pycache__/preprocessing.cpython-38.pyc ADDED
Binary file (1.88 kB)
deep_sort/deep_sort/sort/__pycache__/track.cpython-310.pyc ADDED
Binary file (6.87 kB)
deep_sort/deep_sort/sort/__pycache__/track.cpython-38.pyc ADDED
Binary file (6.87 kB)
deep_sort/deep_sort/sort/__pycache__/tracker.cpython-310.pyc ADDED
Binary file (5.69 kB)
deep_sort/deep_sort/sort/__pycache__/tracker.cpython-38.pyc ADDED
Binary file (5.76 kB)
deep_sort/deep_sort/sort/detection.py ADDED
@@ -0,0 +1,49 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+
+
+class Detection(object):
+    """
+    This class represents a bounding box detection in a single image.
+
+    Parameters
+    ----------
+    tlwh : array_like
+        Bounding box in format `(top left x, top left y, width, height)`.
+    confidence : float
+        Detector confidence score.
+    feature : array_like
+        A feature vector that describes the object contained in this image.
+
+    Attributes
+    ----------
+    tlwh : ndarray
+        Bounding box in format `(top left x, top left y, width, height)`.
+    confidence : ndarray
+        Detector confidence score.
+    feature : ndarray | NoneType
+        A feature vector that describes the object contained in this image.
+
+    """
+
+    def __init__(self, tlwh, confidence, feature):
+        self.tlwh = np.asarray(tlwh, dtype=np.float32)
+        self.confidence = float(confidence)
+        self.feature = np.asarray(feature, dtype=np.float32)
+
+    def to_tlbr(self):
+        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
+        `(top left, bottom right)`.
+        """
+        ret = self.tlwh.copy()
+        ret[2:] += ret[:2]
+        return ret
+
+    def to_xyah(self):
+        """Convert bounding box to format `(center x, center y, aspect ratio,
+        height)`, where the aspect ratio is `width / height`.
+        """
+        ret = self.tlwh.copy()
+        ret[:2] += ret[2:] / 2
+        ret[2] /= ret[3]
+        return ret
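A small worked example (a sketch run from inside deep_sort/deep_sort/sort/; the zero feature vector is a placeholder):

import numpy as np
from detection import Detection  # detection.py has no package-relative imports

d = Detection([0., 0., 10., 20.], 0.9, np.zeros(512))  # placeholder feature
print(d.to_tlbr())  # [ 0.  0. 10. 20.]  -> (x1, y1, x2, y2)
print(d.to_xyah())  # [ 5. 10.  0.5 20.] -> center (5, 10), aspect 10/20, height 20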
deep_sort/deep_sort/sort/iou_matching.py ADDED
@@ -0,0 +1,84 @@
+# vim: expandtab:ts=4:sw=4
+from __future__ import absolute_import
+import numpy as np
+from . import linear_assignment
+
+# compute the IoU between two boxes
+def iou(bbox, candidates):
+    """Compute intersection over union.
+
+    Parameters
+    ----------
+    bbox : ndarray
+        A bounding box in format `(top left x, top left y, width, height)`.
+    candidates : ndarray
+        A matrix of candidate bounding boxes (one per row) in the same format
+        as `bbox`.
+
+    Returns
+    -------
+    ndarray
+        The intersection over union in [0, 1] between the `bbox` and each
+        candidate. A higher score means a larger fraction of the `bbox` is
+        occluded by the candidate.
+
+    """
+    bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
+    candidates_tl = candidates[:, :2]
+    candidates_br = candidates[:, :2] + candidates[:, 2:]
+
+    # np.c_ translates slice objects to concatenation along the second axis
+    tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
+               np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
+    br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
+               np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
+    wh = np.maximum(0., br - tl)
+
+    area_intersection = wh.prod(axis=1)
+    area_bbox = bbox[2:].prod()
+    area_candidates = candidates[:, 2:].prod(axis=1)
+    return area_intersection / (area_bbox + area_candidates - area_intersection)
+
+# compute the IoU distance cost matrix between tracks and detections
+def iou_cost(tracks, detections, track_indices=None,
+             detection_indices=None):
+    """An intersection over union distance metric.
+
+    Computes the IoU distance matrix between tracks and detections.
+
+    Parameters
+    ----------
+    tracks : List[deep_sort.track.Track]
+        A list of tracks.
+    detections : List[deep_sort.detection.Detection]
+        A list of detections.
+    track_indices : Optional[List[int]]
+        A list of indices to tracks that should be matched. Defaults to
+        all `tracks`.
+    detection_indices : Optional[List[int]]
+        A list of indices to detections that should be matched. Defaults
+        to all `detections`.
+
+    Returns
+    -------
+    ndarray
+        Returns a cost matrix of shape
+        len(track_indices), len(detection_indices) where entry (i, j) is
+        `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
+
+    """
+    if track_indices is None:
+        track_indices = np.arange(len(tracks))
+    if detection_indices is None:
+        detection_indices = np.arange(len(detections))
+
+    cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
+    for row, track_idx in enumerate(track_indices):
+        if tracks[track_idx].time_since_update > 1:
+            cost_matrix[row, :] = linear_assignment.INFTY_COST
+            continue
+
+        bbox = tracks[track_idx].to_tlwh()
+        candidates = np.asarray([detections[i].tlwh for i in detection_indices])
+        cost_matrix[row, :] = 1. - iou(bbox, candidates)
+    return cost_matrix
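For intuition, a standalone recomputation of the formula above (hand-checkable numbers; no package import needed): two 10×10 boxes offset by (5, 5) intersect in a 5×5 region.

import numpy as np

# same arithmetic as iou() above, for bbox (0, 0, 10, 10) vs candidate (5, 5, 10, 10)
tl = np.maximum([0., 0.], [5., 5.])      # intersection top-left
br = np.minimum([10., 10.], [15., 15.])  # intersection bottom-right
inter = np.maximum(0., br - tl).prod()   # 25.0
union = 10 * 10 + 10 * 10 - inter        # 175.0
print(inter / union)                     # 0.142857... -> cost entry 1 - IoU = 0.857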
deep_sort/deep_sort/sort/kalman_filter.py ADDED
@@ -0,0 +1,286 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+import scipy.linalg
+
+
+"""
+Table for the 0.95 quantile of the chi-square distribution with N degrees of
+freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
+function and used as Mahalanobis gating threshold.
+"""
+chi2inv95 = {
+    1: 3.8415,
+    2: 5.9915,
+    3: 7.8147,
+    4: 9.4877,
+    5: 11.070,
+    6: 12.592,
+    7: 14.067,
+    8: 15.507,
+    9: 16.919}
+
+'''
+The Kalman filter runs in two phases:
+(1) predict the track's position at the next time step,
+(2) update that prediction using the matched detection.
+'''
+class KalmanFilter(object):
+    """
+    A simple Kalman filter for tracking bounding boxes in image space.
+
+    The 8-dimensional state space
+
+        x, y, a, h, vx, vy, va, vh
+
+    contains the bounding box center position (x, y), aspect ratio a, height h,
+    and their respective velocities.
+
+    Object motion follows a constant velocity model. The bounding box location
+    (x, y, a, h) is taken as direct observation of the state space (linear
+    observation model).
+
+    Each track keeps its own mean and covariance, which a KalmanFilter instance
+    takes as input to predict the state distribution.
+
+    """
+
+    def __init__(self):
+        ndim, dt = 4, 1.
+
+        # Create Kalman filter model matrices.
+        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
+        for i in range(ndim):
+            self._motion_mat[i, ndim + i] = dt
+        self._update_mat = np.eye(ndim, 2 * ndim)
+
+        # Motion and observation uncertainty are chosen relative to the current
+        # state estimate (the height). These weights control the amount of
+        # uncertainty in the model. This is a bit hacky.
+        self._std_weight_position = 1. / 20
+        self._std_weight_velocity = 1. / 160
+
+    def initiate(self, measurement):
+        """Create track from unassociated measurement.
+
+        Parameters
+        ----------
+        measurement : ndarray
+            Bounding box coordinates (x, y, a, h) with center position (x, y),
+            aspect ratio a, and height h.
+
+        Returns
+        -------
+        (ndarray, ndarray)
+            Returns the mean vector (8 dimensional) and covariance matrix (8x8
+            dimensional) of the new track. Unobserved velocities are initialized
+            to 0 mean.
+
+        """
+        mean_pos = measurement
+        mean_vel = np.zeros_like(mean_pos)
+        # np.r_ translates slice objects to concatenation along the first axis
+        mean = np.r_[mean_pos, mean_vel]
+
+        # initialize the mean vector (8-d) and covariance matrix (8x8) from the measurement
+        std = [
+            2 * self._std_weight_position * measurement[3],
+            2 * self._std_weight_position * measurement[3],
+            1e-2,
+            2 * self._std_weight_position * measurement[3],
+            10 * self._std_weight_velocity * measurement[3],
+            10 * self._std_weight_velocity * measurement[3],
+            1e-5,
+            10 * self._std_weight_velocity * measurement[3]]
+        covariance = np.diag(np.square(std))
+        return mean, covariance
+
+    def predict(self, mean, covariance):
+        """Run Kalman filter prediction step.
+
+        Parameters
+        ----------
+        mean : ndarray
+            The 8 dimensional mean vector of the object state at the previous
+            time step.
+        covariance : ndarray
+            The 8x8 dimensional covariance matrix of the object state at the
+            previous time step.
+
+        Returns
+        -------
+        (ndarray, ndarray)
+            Returns the mean vector and covariance matrix of the predicted
+            state. Unobserved velocities are initialized to 0 mean.
+
+        """
+        # the filter predicts from the track's mean and covariance at the previous time step
+        std_pos = [
+            self._std_weight_position * mean[3],
+            self._std_weight_position * mean[3],
+            1e-2,
+            self._std_weight_position * mean[3]]
+        std_vel = [
+            self._std_weight_velocity * mean[3],
+            self._std_weight_velocity * mean[3],
+            1e-5,
+            self._std_weight_velocity * mean[3]]
+
+        # initialize the noise matrix Q; np.r_ concatenates the two lists
+        # motion_cov is Q_k, the covariance matrix of the process noise w_k
+        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
+
+        # Update time state x' = Fx (1)
+        # x is the track's mean at time t-1 and F is the state transition matrix;
+        # this predicts x' at time t. self._motion_mat is F_k, acting on x_{k-1}.
+        mean = np.dot(self._motion_mat, mean)
+        # Calculate error covariance P' = FPF^T+Q (2)
+        # P is the track's covariance at t-1 and Q is the process noise matrix,
+        # which encodes how much the model itself is trusted (usually initialized
+        # small); this predicts P' at time t.
+        # covariance is P_{k|k}, the posterior estimate-error covariance matrix,
+        # a measure of how accurate the estimate is.
+        covariance = np.linalg.multi_dot((
+            self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
+
+        return mean, covariance
+
+    def project(self, mean, covariance):
+        """Project state distribution to measurement space.
+
+        Parameters
+        ----------
+        mean : ndarray
+            The state's mean vector (8 dimensional array).
+        covariance : ndarray
+            The state's covariance matrix (8x8 dimensional).
+
+        Returns
+        -------
+        (ndarray, ndarray)
+            Returns the projected mean and covariance matrix of the given state
+            estimate.
+
+        """
+        # In equation (4), R is the detector's noise matrix: a 4x4 diagonal matrix
+        # whose entries are the noise of the two center coordinates, the aspect
+        # ratio, and the height. It can be initialized with arbitrary values,
+        # typically with larger noise on width/height than on the center.
+        # The covariance P' is first projected into measurement space, then R is added.
+        std = [
+            self._std_weight_position * mean[3],
+            self._std_weight_position * mean[3],
+            1e-1,
+            self._std_weight_position * mean[3]]
+
+        # R is the covariance of the measurement noise; initialize the noise matrix R
+        innovation_cov = np.diag(np.square(std))
+
+        # project the mean vector into measurement space, i.e. Hx'
+        mean = np.dot(self._update_mat, mean)
+        # project the covariance matrix into measurement space, i.e. HP'H^T
+        covariance = np.linalg.multi_dot((
+            self._update_mat, covariance, self._update_mat.T))
+        return mean, covariance + innovation_cov  # equation (4)
+
+    def update(self, mean, covariance, measurement):
+        """Run Kalman filter correction step.
+
+        Combines the prediction and the observation into the new estimate.
+
+        Parameters
+        ----------
+        mean : ndarray
+            The predicted state's mean vector (8 dimensional).
+        covariance : ndarray
+            The state's covariance matrix (8x8 dimensional).
+        measurement : ndarray
+            The 4 dimensional measurement vector (x, y, a, h), where (x, y)
+            is the center position, a the aspect ratio, and h the height of the
+            bounding box.
+
+        Returns
+        -------
+        (ndarray, ndarray)
+            Returns the measurement-corrected state distribution.
+
+        """
+        # project the mean and covariance into measurement space, giving Hx' and S
+        projected_mean, projected_cov = self.project(mean, covariance)
+
+        # matrix factorization
+        chol_factor, lower = scipy.linalg.cho_factor(
+            projected_cov, lower=True, check_finite=False)
+        # compute the Kalman gain K, i.e. solve equation (5);
+        # the gain weighs how much the estimation error matters.
+        # Equation (5) contains S^{-1}; inverting a large S directly is expensive,
+        # so both sides are multiplied by S, turning S * S^{-1} into the identity
+        # and the computation into an AX = B solve via the Cholesky factorization.
+        kalman_gain = scipy.linalg.cho_solve(
+            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
+            check_finite=False).T
+        # y = z - Hx' (3)
+        # In equation (3), z is the detection's mean vector, without velocity
+        # components, i.e. z = [cx, cy, r, h]. H is the measurement matrix that
+        # maps the track's mean x' into measurement space; this computes the
+        # innovation between detection and track.
+        innovation = measurement - projected_mean
+
+        # updated mean vector x = x' + Ky (6)
+        new_mean = mean + np.dot(innovation, kalman_gain.T)
+        # updated covariance matrix P = (I - KH)P' (7)
+        new_covariance = covariance - np.linalg.multi_dot((
+            kalman_gain, projected_cov, kalman_gain.T))
+        return new_mean, new_covariance
+
+    def gating_distance(self, mean, covariance, measurements,
+                        only_position=False):
+        """Compute gating distance between state distribution and measurements.
+
+        A suitable distance threshold can be obtained from `chi2inv95`. If
+        `only_position` is False, the chi-square distribution has 4 degrees of
+        freedom, otherwise 2.
+
+        Parameters
+        ----------
+        mean : ndarray
+            Mean vector over the state distribution (8 dimensional).
+        covariance : ndarray
+            Covariance of the state distribution (8x8 dimensional).
+        measurements : ndarray
+            An Nx4 dimensional matrix of N measurements, each in
+            format (x, y, a, h) where (x, y) is the bounding box center
+            position, a the aspect ratio, and h the height.
+        only_position : Optional[bool]
+            If True, distance computation is done with respect to the bounding
+            box center position only.
+
+        Returns
+        -------
+        ndarray
+            Returns an array of length N, where the i-th element contains the
+            squared Mahalanobis distance between (mean, covariance) and
+            `measurements[i]`.
+
+        """
+        mean, covariance = self.project(mean, covariance)
+        if only_position:
+            mean, covariance = mean[:2], covariance[:2, :2]
+            measurements = measurements[:, :2]
+
+        cholesky_factor = np.linalg.cholesky(covariance)
+        d = measurements - mean
+        z = scipy.linalg.solve_triangular(
+            cholesky_factor, d.T, lower=True, check_finite=False,
+            overwrite_b=True)
+        squared_maha = np.sum(z * z, axis=0)
+        return squared_maha
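For reference, the numbered equations the comments above cite, written out (x is the 8-d state, z the 4-d measurement, F = _motion_mat, H = _update_mat, Q = motion_cov, R = innovation_cov, S = projected_cov):

$$
\begin{aligned}
x' &= F x && (1) \\
P' &= F P F^\top + Q && (2) \\
y  &= z - H x' && (3) \\
S  &= H P' H^\top + R && (4) \\
K  &= P' H^\top S^{-1} && (5) \\
x  &= x' + K y && (6) \\
P  &= (I - K H)\,P' && (7)
\end{aligned}
$$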
deep_sort/deep_sort/sort/linear_assignment.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # vim: expandtab:ts=4:sw=4
+ from __future__ import absolute_import
+ import numpy as np
+ # The linear sum assignment problem is also known as minimum weight matching
+ # in bipartite graphs.
+ from scipy.optimize import linear_sum_assignment as linear_assignment
+ from . import kalman_filter
+
+
+ INFTY_COST = 1e+5
+
+
+ # min_cost_matching solves the linear assignment problem with the Hungarian
+ # algorithm; the metric passed in is either the gated cosine-distance cost
+ # or the IoU cost.
+ def min_cost_matching(
+         distance_metric, max_distance, tracks, detections, track_indices=None,
+         detection_indices=None):
+     """Solve linear assignment problem.
+
+     Parameters
+     ----------
+     distance_metric : Callable[List[Track], List[Detection], List[int], List[int]] -> ndarray
+         The distance metric is given a list of tracks and detections as well as
+         a list of N track indices and M detection indices. The metric should
+         return the NxM dimensional cost matrix, where element (i, j) is the
+         association cost between the i-th track in the given track indices and
+         the j-th detection in the given detection indices.
+     max_distance : float
+         Gating threshold. Associations with cost larger than this value are
+         disregarded.
+     tracks : List[track.Track]
+         A list of predicted tracks at the current time step.
+     detections : List[detection.Detection]
+         A list of detections at the current time step.
+     track_indices : List[int]
+         List of track indices that maps rows in `cost_matrix` to tracks in
+         `tracks` (see description above).
+     detection_indices : List[int]
+         List of detection indices that maps columns in `cost_matrix` to
+         detections in `detections` (see description above).
+
+     Returns
+     -------
+     (List[(int, int)], List[int], List[int])
+         Returns a tuple with the following three entries:
+         * A list of matched track and detection indices.
+         * A list of unmatched track indices.
+         * A list of unmatched detection indices.
+
+     """
+     if track_indices is None:
+         track_indices = np.arange(len(tracks))
+     if detection_indices is None:
+         detection_indices = np.arange(len(detections))
+
+     if len(detection_indices) == 0 or len(track_indices) == 0:
+         return [], track_indices, detection_indices  # Nothing to match.
+
+     # Compute the cost matrix and clip every entry above the gating threshold.
+     cost_matrix = distance_metric(
+         tracks, detections, track_indices, detection_indices)
+     cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
+
+     # Run the Hungarian algorithm; row indices refer to tracks, column
+     # indices to detections.
+     row_indices, col_indices = linear_assignment(cost_matrix)
+
+     matches, unmatched_tracks, unmatched_detections = [], [], []
+     # Collect the unmatched detections.
+     for col, detection_idx in enumerate(detection_indices):
+         if col not in col_indices:
+             unmatched_detections.append(detection_idx)
+     # Collect the unmatched tracks.
+     for row, track_idx in enumerate(track_indices):
+         if row not in row_indices:
+             unmatched_tracks.append(track_idx)
+     # Walk the matched (track, detection) index pairs.
+     for row, col in zip(row_indices, col_indices):
+         track_idx = track_indices[row]
+         detection_idx = detection_indices[col]
+         # A pair whose cost exceeds max_distance also counts as unmatched.
+         if cost_matrix[row, col] > max_distance:
+             unmatched_tracks.append(track_idx)
+             unmatched_detections.append(detection_idx)
+         else:
+             matches.append((track_idx, detection_idx))
+     return matches, unmatched_tracks, unmatched_detections
+
+
+ def matching_cascade(
+         distance_metric, max_distance, cascade_depth, tracks, detections,
+         track_indices=None, detection_indices=None):
+     """Run matching cascade.
+
+     Parameters
+     ----------
+     distance_metric : Callable[List[Track], List[Detection], List[int], List[int]] -> ndarray
+         The distance metric is given a list of tracks and detections as well as
+         a list of N track indices and M detection indices. The metric should
+         return the NxM dimensional cost matrix, where element (i, j) is the
+         association cost between the i-th track in the given track indices and
+         the j-th detection in the given detection indices.
+     max_distance : float
+         Gating threshold. Associations with cost larger than this value are
+         disregarded.
+     cascade_depth : int
+         The cascade depth; should be set to the maximum track age.
+     tracks : List[track.Track]
+         A list of predicted tracks at the current time step.
+     detections : List[detection.Detection]
+         A list of detections at the current time step.
+     track_indices : Optional[List[int]]
+         List of track indices that maps rows in `cost_matrix` to tracks in
+         `tracks` (see description above). Defaults to all tracks.
+     detection_indices : Optional[List[int]]
+         List of detection indices that maps columns in `cost_matrix` to
+         detections in `detections` (see description above). Defaults to all
+         detections.
+
+     Returns
+     -------
+     (List[(int, int)], List[int], List[int])
+         Returns a tuple with the following three entries:
+         * A list of matched track and detection indices.
+         * A list of unmatched track indices.
+         * A list of unmatched detection indices.
+
+     """
+     if track_indices is None:
+         track_indices = list(range(len(tracks)))
+     if detection_indices is None:
+         detection_indices = list(range(len(detections)))
+
+     # Initialize the match set M ← ∅ and the unmatched detection set U ← D.
+     unmatched_detections = detection_indices
+     matches = []
+     # Match the tracks of each level in turn, youngest tracks first.
+     for level in range(cascade_depth):
+         if len(unmatched_detections) == 0:  # No detections left
+             break
+
+         # Step 6 of the cascade: select the tracks of this level by age.
+         track_indices_l = [
+             k for k in track_indices
+             if tracks[k].time_since_update == 1 + level
+         ]
+         if len(track_indices_l) == 0:  # Nothing to match at this level
+             continue
+
+         # Step 7: run min_cost_matching on this level.
+         matches_l, _, unmatched_detections = \
+             min_cost_matching(
+                 distance_metric, max_distance, tracks, detections,
+                 track_indices_l, unmatched_detections)
+         matches += matches_l  # Step 8
+     unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))  # Step 9
+     return matches, unmatched_tracks, unmatched_detections
+
+
+ # Gated cost matrix: the entries of the cost matrix are appearance
+ # similarities between tracks and detections, and they are constrained by
+ # the distance between the Kalman-filtered state distribution and the
+ # measurements. When one track could match two detections with very similar
+ # appearance features, mistakes are easy to make; computing the Mahalanobis
+ # distance from each detection to the track and gating it with
+ # gating_threshold separates out the detection that is far away in state
+ # space, reducing false matches.
+ def gate_cost_matrix(
+         kf, cost_matrix, tracks, detections, track_indices, detection_indices,
+         gated_cost=INFTY_COST, only_position=False):
+     """Invalidate infeasible entries in cost matrix based on the state
+     distributions obtained by Kalman filtering.
+
+     Parameters
+     ----------
+     kf : The Kalman filter.
+     cost_matrix : ndarray
+         The NxM dimensional cost matrix, where N is the number of track indices
+         and M is the number of detection indices, such that entry (i, j) is the
+         association cost between `tracks[track_indices[i]]` and
+         `detections[detection_indices[j]]`.
+     tracks : List[track.Track]
+         A list of predicted tracks at the current time step.
+     detections : List[detection.Detection]
+         A list of detections at the current time step.
+     track_indices : List[int]
+         List of track indices that maps rows in `cost_matrix` to tracks in
+         `tracks` (see description above).
+     detection_indices : List[int]
+         List of detection indices that maps columns in `cost_matrix` to
+         detections in `detections` (see description above).
+     gated_cost : Optional[float]
+         Entries in the cost matrix corresponding to infeasible associations are
+         set to this value. Defaults to a very large value.
+     only_position : Optional[bool]
+         If True, only the x, y position of the state distribution is considered
+         during gating. Defaults to False.
+
+     Returns
+     -------
+     ndarray
+         Returns the modified cost matrix.
+
+     """
+     gating_dim = 2 if only_position else 4  # Measurement-space dimensionality.
+     # The Mahalanobis distance accounts for the uncertainty of the state
+     # estimate by measuring how many standard deviations a detection lies
+     # from the mean track position. Unlikely associations are excluded by
+     # thresholding at the 95% confidence interval of the inverse chi-square
+     # distribution; for a 4-dimensional measurement space the threshold
+     # is 9.4877.
+     gating_threshold = kalman_filter.chi2inv95[gating_dim]
+     measurements = np.asarray(
+         [detections[i].to_xyah() for i in detection_indices])
+     for row, track_idx in enumerate(track_indices):
+         track = tracks[track_idx]
+         # KalmanFilter.gating_distance computes the gating distance between
+         # the state distribution and the measurements.
+         gating_distance = kf.gating_distance(
+             track.mean, track.covariance, measurements, only_position)
+         cost_matrix[row, gating_distance > gating_threshold] = gated_cost
+     return cost_matrix
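`min_cost_matching` only forwards `tracks` and `detections` to the metric callable, so its thresholding behaviour can be smoke-tested with placeholder objects and a toy cost matrix. A hedged sketch (the import path assumes the repository root is on `PYTHONPATH`; `toy_metric` and the cost values are invented for illustration):

```python
import numpy as np
from deep_sort.deep_sort.sort.linear_assignment import min_cost_matching

# Placeholder tracks/detections; min_cost_matching never inspects them.
tracks = [object(), object()]
detections = [object(), object(), object()]

# Hypothetical 2x3 cost matrix: track 0 pairs cheaply with detection 1,
# track 1 with detection 0; detection 2 is too costly for either track.
costs = np.array([[0.9, 0.1, 0.8],
                  [0.2, 0.7, 0.9]])

def toy_metric(tracks, detections, track_indices, detection_indices):
    # Return the sub-matrix for the requested track/detection indices.
    return costs[np.ix_(track_indices, detection_indices)]

matches, unmatched_tracks, unmatched_detections = min_cost_matching(
    toy_metric, 0.5, tracks, detections,
    track_indices=[0, 1], detection_indices=[0, 1, 2])
print(matches)               # [(0, 1), (1, 0)]
print(unmatched_tracks)      # []
print(unmatched_detections)  # [2]
```

`matching_cascade` drives this same routine once per track age level, feeding the `unmatched_detections` left over from one level into the next, which is what gives recently updated tracks first claim on detections.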
deep_sort/deep_sort/sort/nn_matching.py ADDED
@@ -0,0 +1,207 @@
+ # vim: expandtab:ts=4:sw=4
+ import numpy as np
+
+
+ def _pdist(a, b):
+     """Compute pair-wise squared distance between points in `a` and `b`.
+
+     Parameters
+     ----------
+     a : array_like
+         An NxM matrix of N samples of dimensionality M.
+     b : array_like
+         An LxM matrix of L samples of dimensionality M.
+
+     Returns
+     -------
+     ndarray
+         Returns a matrix of size len(a), len(b) such that element (i, j)
+         contains the squared distance between `a[i]` and `b[j]`.
+
+     Reference: https://blog.csdn.net/frankzd/article/details/80251042
+
+     """
+     a, b = np.asarray(a), np.asarray(b)
+     if len(a) == 0 or len(b) == 0:
+         return np.zeros((len(a), len(b)))
+     a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
+     # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, clipped so floating-point
+     # round-off cannot produce small negative distances.
+     r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
+     r2 = np.clip(r2, 0., float(np.inf))
+     return r2
+
+
+ def _cosine_distance(a, b, data_is_normalized=False):
+     """Compute pair-wise cosine distance between points in `a` and `b`.
+
+     Parameters
+     ----------
+     a : array_like
+         An NxM matrix of N samples of dimensionality M.
+     b : array_like
+         An LxM matrix of L samples of dimensionality M.
+     data_is_normalized : Optional[bool]
+         If True, assumes rows in a and b are unit length vectors.
+         Otherwise, a and b are explicitly normalized to length 1.
+
+     Returns
+     -------
+     ndarray
+         Returns a matrix of size len(a), len(b) such that element (i, j)
+         contains the cosine distance between `a[i]` and `b[j]`.
+
+     Reference: https://blog.csdn.net/u013749540/article/details/51813922
+
+     """
+     if not data_is_normalized:
+         # np.linalg.norm computes the vector norm (L2 by default).
+         a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
+         b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
+     return 1. - np.dot(a, b.T)  # cosine distance = 1 - cosine similarity
+
+
+ def _nn_euclidean_distance(x, y):
+     """ Helper function for nearest neighbor distance metric (Euclidean).
+
+     Parameters
+     ----------
+     x : ndarray
+         A matrix of N row-vectors (sample points).
+     y : ndarray
+         A matrix of M row-vectors (query points).
+
+     Returns
+     -------
+     ndarray
+         A vector of length M that contains for each entry in `y` the
+         smallest (squared) Euclidean distance to a sample in `x`.
+
+     """
+     distances = _pdist(x, y)
+     return np.maximum(0.0, distances.min(axis=0))
+
+
+ def _nn_cosine_distance(x, y):
+     """ Helper function for nearest neighbor distance metric (cosine).
+
+     Parameters
+     ----------
+     x : ndarray
+         A matrix of N row-vectors (sample points).
+     y : ndarray
+         A matrix of M row-vectors (query points).
+
+     Returns
+     -------
+     ndarray
+         A vector of length M that contains for each entry in `y` the
+         smallest cosine distance to a sample in `x`.
+
+     """
+     distances = _cosine_distance(x, y)
+     return distances.min(axis=0)
+
+
+ class NearestNeighborDistanceMetric(object):
+     """
+     A nearest neighbor distance metric that, for each target, returns
+     the closest distance to any sample that has been observed so far.
+
+     Parameters
+     ----------
+     metric : str
+         Either "euclidean" or "cosine".
+     matching_threshold : float
+         The matching threshold. Samples with larger distance are considered an
+         invalid match.
+     budget : Optional[int]
+         If not None, fix samples per class to at most this number. Removes
+         the oldest samples when the budget is reached.
+
+     Attributes
+     ----------
+     samples : Dict[int -> List[ndarray]]
+         A dictionary that maps from target identities to the list of samples
+         that have been observed so far.
+
+     """
+
+     def __init__(self, metric, matching_threshold, budget=None):
+         if metric == "euclidean":
+             self._metric = _nn_euclidean_distance
+         elif metric == "cosine":
+             self._metric = _nn_cosine_distance
+         else:
+             raise ValueError(
+                 "Invalid metric; must be either 'euclidean' or 'cosine'")
+         self.matching_threshold = matching_threshold
+         self.budget = budget  # Caps the number of features kept per target.
+         self.samples = {}
+
+     def partial_fit(self, features, targets, active_targets):
+         """Update the distance metric with new data.
+
+         Given a list of features and their corresponding target ids, this
+         builds the feature dictionary for the currently active targets.
+
+         Parameters
+         ----------
+         features : ndarray
+             An NxM matrix of N features of dimensionality M.
+         targets : ndarray
+             An integer array of associated target identities.
+         active_targets : List[int]
+             A list of targets that are currently present in the scene.
+
+         """
+         for feature, target in zip(features, targets):
+             # Append the new feature to this target's sample list;
+             # `samples` is a dictionary {target id -> feature list}.
+             self.samples.setdefault(target, []).append(feature)
+             if self.budget is not None:
+                 # Keep only the most recent `budget` features per target.
+                 self.samples[target] = self.samples[target][-self.budget:]
+
+         # Keep samples only for the targets that are still active.
+         self.samples = {k: self.samples[k] for k in active_targets}
+
+     def distance(self, features, targets):
+         """Compute distance between features and targets, returning a cost
+         matrix.
+
+         Parameters
+         ----------
+         features : ndarray
+             An NxM matrix of N features of dimensionality M.
+         targets : List[int]
+             A list of targets to match the given `features` against.
+
+         Returns
+         -------
+         ndarray
+             Returns a cost matrix of shape len(targets), len(features), where
+             element (i, j) contains the closest squared distance between
+             `targets[i]` and `features[j]`.
+
+         """
+         cost_matrix = np.zeros((len(targets), len(features)))
+         for i, target in enumerate(targets):
+             cost_matrix[i, :] = self._metric(self.samples[target], features)
+         return cost_matrix
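A quick way to exercise `NearestNeighborDistanceMetric` is to feed it a few hand-written feature vectors; real DeepSORT features are the re-ID embeddings produced by the feature extractor, so the 3-d vectors here are purely illustrative (as is the import path, which assumes the repository root is on `PYTHONPATH`):

```python
import numpy as np
from deep_sort.deep_sort.sort.nn_matching import NearestNeighborDistanceMetric

metric = NearestNeighborDistanceMetric(
    "cosine", matching_threshold=0.2, budget=100)

# Two observations of target 1 and one of target 2.
features = np.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0],
                     [0.9, 0.1, 0.0]])
targets = np.array([1, 2, 1])
metric.partial_fit(features, targets, active_targets=[1, 2])

# Query with one new detection feature; rows of the returned cost matrix
# correspond to targets, columns to query features.
query = np.array([[1.0, 0.05, 0.0]])
cost = metric.distance(query, targets=[1, 2])
print(cost.shape)         # (2, 1)
print(cost[0] < cost[1])  # [ True] -- target 1 is the closer match
```

Because each row of the cost matrix is the minimum distance over all stored samples of that target, keeping a per-target `budget` of recent features lets the metric follow gradual appearance change without unbounded memory growth.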