File size: 5,337 Bytes
0f395b3
a814720
 
 
2bcc89d
a814720
 
1c8464e
0dec8ff
a814720
 
d08575e
2bcc89d
d5b6416
9b11f0c
a814720
2aaaec3
a833ee2
7f16406
a814720
7c89c82
a814720
 
 
 
 
 
 
d08575e
 
a814720
 
d08575e
a814720
d08575e
c1a44ed
6e3c3b6
d08575e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b11f0c
 
84a9466
6e3c3b6
 
d08575e
84a9466
a814720
 
c1a44ed
 
31ee54c
a814720
 
c0d3f80
f5b8f7d
c0d3f80
a814720
f5b8f7d
c0d3f80
f5b8f7d
c0d3f80
a814720
 
c0d3f80
f5b8f7d
c0d3f80
a814720
 
c0d3f80
f5b8f7d
c0d3f80
a814720
 
c0d3f80
 
 
7f16406
 
c0d3f80
 
 
87ca35b
 
c0d3f80
 
 
87ca35b
f5b8f7d
c0d3f80
 
 
87ca35b
 
7f16406
7c89c82
 
c8c5ef3
 
c15e25c
db28ce6
311de00
 
c15e25c
 
 
 
 
db28ce6
 
f542926
9a3da79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""YOLOv5 PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5/

Usage:
    import torch
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
"""

from pathlib import Path

import torch

from models.yolo import Model, attempt_load
from utils.general import check_requirements, set_logging
from utils.google_utils import attempt_download
from utils.torch_utils import select_device

# Packages torch.hub checks for before importing this hubconf.
dependencies = ['torch', 'yaml']

# Verify the repo's requirements.txt, skipping packages not needed to load/run models.
check_requirements(
    Path(__file__).parent / 'requirements.txt',
    exclude=('tensorboard', 'pycocotools', 'thop'),
)


def create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """Creates a specified YOLOv5 model

    Arguments:
        name (str): name of model, i.e. 'yolov5s', or a path to a local checkpoint
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
        verbose (bool): print all information to screen

    Returns:
        YOLOv5 pytorch model, moved to GPU '0' if CUDA is available, otherwise to the CPU

    Raises:
        Exception: wraps any underlying error (chained as __cause__) with a hint to refresh the hub cache
    """
    set_logging(verbose=verbose)
    fname = Path(name).with_suffix('.pt')  # checkpoint filename
    try:
        if pretrained and channels == 3 and classes == 80:
            # Default configuration: the checkpoint can be loaded as-is
            model = attempt_load(fname, map_location=torch.device('cpu'))  # download/load FP32 model
        else:
            # Non-default channels/classes: build a fresh model from its yaml definition
            cfg = next((Path(__file__).parent / 'models').rglob(f'{name}.yaml'), None)  # model.yaml path
            if cfg is None:
                raise FileNotFoundError(f'{name}.yaml not found under models/')
            model = Model(cfg, channels, classes)  # create model
            if pretrained:
                attempt_download(fname)  # download if not found locally
                # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources
                ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
                msd = model.state_dict()  # model state_dict
                csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
                # Transfer only tensors present in the new model with an identical shape; the
                # rest (e.g. layers sized by `channels`/`classes`) keep the model's own init.
                csd = {k: v for k, v in csd.items() if k in msd and msd[k].shape == v.shape}  # filter
                model.load_state_dict(csd, strict=False)  # load
                if len(ckpt['model'].names) == classes:
                    model.names = ckpt['model'].names  # set class names attribute
        if autoshape:
            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
        return model.to(device)

    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = f'Cache may be out of date, try `force_reload=True`. See {help_url} for help.'
        raise Exception(s) from e


def custom(path='path/to/model.pt', autoshape=True, verbose=True):
    """YOLOv5 custom or local model.

    Delegates to create() with the default 3-channel / 80-class settings, so the checkpoint at
    `path` (its suffix replaced by '.pt') is loaded as-is rather than rebuilt from a yaml config.
    """
    return create(name=path, pretrained=True, channels=3, classes=80,
                  autoshape=autoshape, verbose=verbose)


def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """YOLOv5-small model https://github.com/ultralytics/yolov5

    Shortcut for create('yolov5s', ...); see create() for argument details.
    """
    return create(name='yolov5s', pretrained=pretrained, channels=channels, classes=classes,
                  autoshape=autoshape, verbose=verbose)


def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """YOLOv5-medium model https://github.com/ultralytics/yolov5

    Shortcut for create('yolov5m', ...); see create() for argument details.
    """
    return create(name='yolov5m', pretrained=pretrained, channels=channels, classes=classes,
                  autoshape=autoshape, verbose=verbose)


def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """YOLOv5-large model https://github.com/ultralytics/yolov5

    Shortcut for create('yolov5l', ...); see create() for argument details.
    """
    return create(name='yolov5l', pretrained=pretrained, channels=channels, classes=classes,
                  autoshape=autoshape, verbose=verbose)


def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """YOLOv5-xlarge model https://github.com/ultralytics/yolov5

    Shortcut for create('yolov5x', ...); see create() for argument details.
    """
    return create(name='yolov5x', pretrained=pretrained, channels=channels, classes=classes,
                  autoshape=autoshape, verbose=verbose)


def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """YOLOv5-small-P6 model https://github.com/ultralytics/yolov5

    Shortcut for create('yolov5s6', ...); see create() for argument details.
    """
    return create(name='yolov5s6', pretrained=pretrained, channels=channels, classes=classes,
                  autoshape=autoshape, verbose=verbose)


def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5

    Shortcut for create('yolov5m6', ...); see create() for argument details.
    """
    return create(name='yolov5m6', pretrained=pretrained, channels=channels, classes=classes,
                  autoshape=autoshape, verbose=verbose)


def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """YOLOv5-large-P6 model https://github.com/ultralytics/yolov5

    Shortcut for create('yolov5l6', ...); see create() for argument details.
    """
    return create(name='yolov5l6', pretrained=pretrained, channels=channels, classes=classes,
                  autoshape=autoshape, verbose=verbose)


def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
    """YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5

    Shortcut for create('yolov5x6', ...); see create() for argument details.
    """
    return create(name='yolov5x6', pretrained=pretrained, channels=channels, classes=classes,
                  autoshape=autoshape, verbose=verbose)


if __name__ == '__main__':
    model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)  # pretrained
    # model = custom(path='path/to/model.pt')  # custom

    # Verify inference
    import cv2
    import numpy as np
    from PIL import Image

    # Resolve sample images relative to this file so the demo works from any working directory.
    # Filenames are passed as str because the autoshape wrapper is fed plain path strings.
    images = Path(__file__).parent / 'data' / 'images'
    bus = str(images / 'bus.jpg')
    imgs = [str(images / 'zidane.jpg'),  # filename
            'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg',  # URI
            cv2.imread(bus)[:, :, ::-1],  # OpenCV (BGR -> RGB)
            Image.open(bus),  # PIL
            np.zeros((320, 640, 3))]  # numpy

    results = model(imgs)  # batched inference
    results.print()
    results.save()