Atualli commited on
Commit
c36e146
1 Parent(s): 9b78f64

Upload 11 files

Browse files
yoloxdetect2/__pycache__/helpers.cpython-38.pyc ADDED
Binary file (2.69 kB). View file
 
yoloxdetect2/configs/__init__.py ADDED
File without changes
yoloxdetect2/configs/yolov3.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii, Inc. and its affiliates.
4
+
5
+ import os
6
+
7
+ import torch.nn as nn
8
+
9
+ from yolox.exp import Exp as MyExp
10
+
11
+
12
+ class Exp(MyExp):
13
+ def __init__(self):
14
+ super(Exp, self).__init__()
15
+ self.depth = 1.0
16
+ self.width = 1.0
17
+ self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
18
+
19
+ def get_model(self, sublinear=False):
20
+ def init_yolo(M):
21
+ for m in M.modules():
22
+ if isinstance(m, nn.BatchNorm2d):
23
+ m.eps = 1e-3
24
+ m.momentum = 0.03
25
+ if "model" not in self.__dict__:
26
+ from yolox.models import YOLOX, YOLOFPN, YOLOXHead
27
+ backbone = YOLOFPN()
28
+ head = YOLOXHead(self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu")
29
+ self.model = YOLOX(backbone, head)
30
+ self.model.apply(init_yolo)
31
+ self.model.head.initialize_biases(1e-2)
32
+
33
+ return self.model
yoloxdetect2/configs/yolox_l.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii, Inc. and its affiliates.
4
+
5
+ import os
6
+
7
+ from yolox.exp import Exp as MyExp
8
+
9
+
10
+ class Exp(MyExp):
11
+ def __init__(self):
12
+ super(Exp, self).__init__()
13
+ self.depth = 1.0
14
+ self.width = 1.0
15
+ self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
yoloxdetect2/configs/yolox_m.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii, Inc. and its affiliates.
4
+
5
+ import os
6
+
7
+ from yolox.exp import Exp as MyExp
8
+
9
+
10
+ class Exp(MyExp):
11
+ def __init__(self):
12
+ super(Exp, self).__init__()
13
+ self.depth = 0.67
14
+ self.width = 0.75
15
+ self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
yoloxdetect2/configs/yolox_nano.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii, Inc. and its affiliates.
4
+
5
+ import os
6
+
7
+ import torch.nn as nn
8
+
9
+ from yolox.exp import Exp as MyExp
10
+
11
+
12
+ class Exp(MyExp):
13
+ def __init__(self):
14
+ super(Exp, self).__init__()
15
+ self.depth = 0.33
16
+ self.width = 0.25
17
+ self.input_size = (416, 416)
18
+ self.random_size = (10, 20)
19
+ self.mosaic_scale = (0.5, 1.5)
20
+ self.test_size = (416, 416)
21
+ self.mosaic_prob = 0.5
22
+ self.enable_mixup = False
23
+ self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
24
+
25
+ def get_model(self, sublinear=False):
26
+
27
+ def init_yolo(M):
28
+ for m in M.modules():
29
+ if isinstance(m, nn.BatchNorm2d):
30
+ m.eps = 1e-3
31
+ m.momentum = 0.03
32
+ if "model" not in self.__dict__:
33
+ from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
34
+ in_channels = [256, 512, 1024]
35
+ # NANO model use depthwise = True, which is main difference.
36
+ backbone = YOLOPAFPN(
37
+ self.depth, self.width, in_channels=in_channels,
38
+ act=self.act, depthwise=True,
39
+ )
40
+ head = YOLOXHead(
41
+ self.num_classes, self.width, in_channels=in_channels,
42
+ act=self.act, depthwise=True
43
+ )
44
+ self.model = YOLOX(backbone, head)
45
+
46
+ self.model.apply(init_yolo)
47
+ self.model.head.initialize_biases(1e-2)
48
+ return self.model
yoloxdetect2/configs/yolox_s.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii, Inc. and its affiliates.
4
+
5
+ import os
6
+
7
+ from yolox.exp import Exp as MyExp
8
+
9
+
10
+ class Exp(MyExp):
11
+ def __init__(self):
12
+ super(Exp, self).__init__()
13
+ self.depth = 0.33
14
+ self.width = 0.50
15
+ self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
yoloxdetect2/configs/yolox_tiny.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii, Inc. and its affiliates.
4
+
5
+ import os
6
+
7
+ from yolox.exp import Exp as MyExp
8
+
9
+
10
+ class Exp(MyExp):
11
+ def __init__(self):
12
+ super(Exp, self).__init__()
13
+ self.depth = 0.33
14
+ self.width = 0.375
15
+ self.input_size = (416, 416)
16
+ self.mosaic_scale = (0.5, 1.5)
17
+ self.random_size = (10, 20)
18
+ self.test_size = (416, 416)
19
+ self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
20
+ self.enable_mixup = False
yoloxdetect2/configs/yolox_x.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # Copyright (c) Megvii, Inc. and its affiliates.
4
+
5
+ import os
6
+
7
+ from yolox.exp import Exp as MyExp
8
+
9
+
10
+ class Exp(MyExp):
11
+ def __init__(self):
12
+ super(Exp, self).__init__()
13
+ self.depth = 1.33
14
+ self.width = 1.25
15
+ self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
yoloxdetect2/helpers.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from yoloxdetect.utils.downloads import attempt_download_from_hub, attempt_download
2
+ from yolox.data.datasets import COCO_CLASSES
3
+ from yolox.data.data_augment import preproc
4
+ from yolox.utils import postprocess, vis
5
+ import importlib
6
+ import torch
7
+ import cv2
8
+ import os
9
+
10
+
11
+ class YoloxDetector2:
12
+ def __init__(
13
+ self,
14
+ model_path: str,
15
+ config_path: str,
16
+ device: str = "cpu",
17
+ hf_model: bool = False,
18
+ ):
19
+
20
+ self.device = device
21
+ self.config_path = config_path
22
+ self.classes = COCO_CLASSES
23
+ self.conf = 0.3
24
+ self.iou = 0.45
25
+ self.show = False
26
+ self.save = True
27
+ self.torchyolo = False
28
+
29
+ if self.save:
30
+ self.save_path = 'output/result.jpg'
31
+
32
+ if hf_model:
33
+ self.model_path = attempt_download_from_hub(model_path)
34
+
35
+ else:
36
+ self.model_path = attempt_download(model_path)
37
+
38
+ self.load_model()
39
+
40
+
41
+ def load_model(self):
42
+ current_exp = importlib.import_module(self.config_path)
43
+ exp = current_exp.Exp()
44
+
45
+ model = exp.get_model()
46
+ model.to(self.device)
47
+ model.eval()
48
+ ckpt = torch.load(self.model_path, map_location=self.device)
49
+ model.load_state_dict(ckpt["model"])
50
+ self.model = model
51
+
52
+
53
+ def predict(self, image_path, image_size):
54
+ image = cv2.imread(image_path)
55
+ if image_size is not None:
56
+ ratio = min(image_size / image.shape[0], image_size / image.shape[1])
57
+ img, _ = preproc(image, input_size=(image_size, image_size))
58
+ img = torch.from_numpy(img).to(self.device).unsqueeze(0).float()
59
+ else:
60
+ manuel_size = 640
61
+ ratio = min(manuel_size / image.shape[0], manuel_size / image.shape[1])
62
+ img, _ = preproc(image, input_size=(manuel_size, manuel_size))
63
+ img = torch.from_numpy(img).to(self.device).unsqueeze(0).float()
64
+
65
+ prediction_result = self.model(img)
66
+ original_predictions = postprocess(
67
+ prediction=prediction_result,
68
+ num_classes= len(COCO_CLASSES),
69
+ conf_thre=self.conf,
70
+ nms_thre=self.iou)[0]
71
+
72
+ if original_predictions is None :
73
+ return None
74
+ output = original_predictions.cpu()
75
+ bboxes = output[:, 0:4]
76
+ bboxes /= ratio
77
+ cls = output[:, 6]
78
+ scores = output[:, 4] * output[:, 5]
79
+ if self.torchyolo is False:
80
+ vis_res = vis(
81
+ image,
82
+ bboxes,
83
+ scores,
84
+ cls,
85
+ self.conf,
86
+ COCO_CLASSES,
87
+ )
88
+ if self.show:
89
+ cv2.imshow("result", vis_res)
90
+ cv2.waitKey(0)
91
+ cv2.destroyAllWindows()
92
+ elif self.save:
93
+ save_dir = self.save_path[:self.save_path.rfind('/')]
94
+ if not os.path.exists(save_dir):
95
+ os.makedirs(save_dir)
96
+ cv2.imwrite(self.save_path, vis_res)
97
+ return self.save_path
98
+
99
+ else:
100
+ return vis_res
101
+ else:
102
+ object_predictions_list = [bboxes, scores, cls, COCO_CLASSES]
103
+ return object_predictions_list
104
+
105
+
yoloxdetect2/utils/downloads.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch
3
+ import urllib
4
+ import requests
5
+ import subprocess
6
+
7
+ def attempt_download_from_hub(repo_id, hf_token=None):
8
+ # https://github.com/fcakyon/yolov5-pip/blob/main/yolov5/utils/downloads.py
9
+ from huggingface_hub import hf_hub_download, list_repo_files
10
+ from huggingface_hub.utils._errors import RepositoryNotFoundError
11
+ from huggingface_hub.utils._validators import HFValidationError
12
+ try:
13
+ repo_files = list_repo_files(repo_id=repo_id, repo_type='model', token=hf_token)
14
+ model_file = [f for f in repo_files if f.endswith('.pth')][0]
15
+ file = hf_hub_download(
16
+ repo_id=repo_id,
17
+ filename=model_file,
18
+ repo_type='model',
19
+ token=hf_token,
20
+ )
21
+ return file
22
+ except (RepositoryNotFoundError, HFValidationError):
23
+ return None
24
+
25
+
26
+ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
27
+ import os
28
+ # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
29
+
30
+ file = Path(file)
31
+ assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
32
+ try: # url1
33
+ torch.hub.download_url_to_file(url, str(file), progress=True) # pytorch download
34
+ assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
35
+ except Exception as e: # url2
36
+ file.unlink(missing_ok=True) # remove partial downloads
37
+ os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
38
+ finally:
39
+ if not file.exists() or file.stat().st_size < min_bytes: # check
40
+ file.unlink(missing_ok=True) # remove partial downloads
41
+ raise Exception(error_msg or assert_msg) # raise informative error
42
+
43
+ def attempt_download(file, repo='Megvii-BaseDetection/YOLOX', release='0.1.0'):
44
+ def github_assets(repository, version='latest'):
45
+ response = requests.get(f'https://api.github.com/repos/{repository}/releases/tags/{version}').json() # github api
46
+ return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
47
+
48
+ file = Path(str(file).strip().replace("'", ''))
49
+ if not file.exists():
50
+ # URL specified
51
+ name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
52
+ if str(file).startswith(('http:/', 'https:/')): # download
53
+ url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
54
+ file = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
55
+ if Path(file).is_file():
56
+ return file
57
+ else:
58
+ safe_download(file=file, url=url, min_bytes=1E5)
59
+ return file
60
+
61
+ # GitHub assets
62
+ assets = [
63
+ 'yolov6n.pt', 'yolov6s.pt', 'yolov6m.pt', 'yolov6l.pt',
64
+ 'yolov6n6.pt', 'yolov6s6.pt', 'yolov6m6.pt', 'yolov6l6.pt']
65
+ try:
66
+ tag, assets = github_assets(repo, release)
67
+ except Exception:
68
+ try:
69
+ tag, assets = github_assets(repo) # latest release
70
+ except Exception:
71
+ try:
72
+ tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
73
+ except Exception:
74
+ tag = release
75
+
76
+ file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
77
+ if name in assets:
78
+ safe_download(
79
+ file,
80
+ url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
81
+ url2=f'https://storage.googleapis.com/{repo}/{tag}/{name}', # backup url (optional)
82
+ min_bytes=1E5,
83
+ error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')
84
+
85
+ return str(file)