File size: 5,501 Bytes
0f395b3
a814720
 
 
2bcc89d
a814720
 
1c8464e
0dec8ff
a814720
 
 
2bcc89d
d5b6416
9b11f0c
a814720
2aaaec3
8ace1b1
7f16406
 
a814720
14b0abe
a814720
 
 
 
 
 
 
 
 
 
 
6e3c3b6
0f395b3
 
6e3c3b6
7f16406
 
 
c03d590
866bc7d
 
 
7f16406
 
14b0abe
 
9b11f0c
 
84a9466
6e3c3b6
 
0fda95a
84a9466
a814720
 
f5b8f7d
 
a814720
f5b8f7d
 
 
 
a814720
 
 
 
f5b8f7d
 
 
a814720
f5b8f7d
 
 
 
 
 
 
a814720
 
f5b8f7d
 
 
a814720
f5b8f7d
 
 
14b0abe
a814720
 
569757e
f5b8f7d
14b0abe
a814720
 
569757e
f5b8f7d
 
a814720
 
f5b8f7d
 
 
7f16406
 
f5b8f7d
 
 
87ca35b
 
f5b8f7d
 
 
87ca35b
f5b8f7d
 
 
 
87ca35b
 
7f16406
14b0abe
e92245a
c8c5ef3
 
db28ce6
311de00
 
db28ce6
 
 
 
 
 
f542926
9a3da79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""YOLOv5 PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5/

Usage:
    import torch
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
"""

from pathlib import Path

import torch

from models.yolo import Model
from utils.general import check_requirements, set_logging
from utils.google_utils import attempt_download
from utils.torch_utils import select_device

# Packages required by this hub entry point (PyTorch Hub `hubconf.py` `dependencies` convention)
dependencies = ['torch', 'yaml']
# Check requirements.txt located next to this file, skipping pycocotools and thop
# (exact behaviour is defined by utils.general.check_requirements)
check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('pycocotools', 'thop'))
set_logging()  # project logging setup from utils.general; runs once at import time


def create(name, pretrained, channels, classes, autoshape):
    """Creates a specified YOLOv5 model

    Arguments:
        name (str): name of model, i.e. 'yolov5s'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): wrap the model with .autoshape() (file/URI/PIL/cv2/np inputs and NMS),
            applied whether or not pretrained weights are loaded

    Returns:
        pytorch model

    Raises:
        Exception: wraps any error raised while building or loading the model (original error chained as __cause__)
    """
    try:
        # Search the whole models/ tree so configs in sub-folders are found too
        cfg = next((Path(__file__).parent / 'models').rglob(f'{name}.yaml'), None)  # model.yaml path
        if cfg is None:
            raise FileNotFoundError(f"'{name}.yaml' not found under {Path(__file__).parent / 'models'}")
        model = Model(cfg, channels, classes)
        if pretrained:
            fname = f'{name}.pt'  # checkpoint filename
            attempt_download(fname)  # download if not found locally
            ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
            msd = model.state_dict()  # model state_dict
            csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
            # Keep only tensors that exist in this model with an identical shape; e.g. a different
            # `classes` count changes the detection-head shapes, so those weights are skipped.
            csd = {k: v for k, v in csd.items() if k in msd and msd[k].shape == v.shape}  # filter
            model.load_state_dict(csd, strict=False)  # load (skipped tensors keep their init values)
            if len(ckpt['model'].names) == classes:
                model.names = ckpt['model'].names  # set class names attribute
        # Independent of `pretrained`: previously nested, so autoshape was ignored for untrained models
        if autoshape:
            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
        return model.to(device)

    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = f'Cache may be out of date, try force_reload=True. See {help_url} for help.'
        raise Exception(s) from e


def custom(path_or_model='path/to/model.pt', autoshape=True):
    """YOLOv5-custom model https://github.com/ultralytics/yolov5

    Arguments (3 options):
        path_or_model (str | Path): 'path/to/model.pt'
        path_or_model (dict): torch.load('path/to/model.pt')
        path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
        autoshape (bool): wrap the model with .autoshape() (file/URI/PIL/cv2/np inputs and NMS)

    Returns:
        pytorch model
    """
    if isinstance(path_or_model, (str, Path)):
        # Load onto CPU so checkpoints saved from a GPU also open on CPU-only hosts;
        # the finished model is moved to the selected device below.
        model = torch.load(path_or_model, map_location=torch.device('cpu'))  # load checkpoint
    else:
        model = path_or_model
    if isinstance(model, dict):
        model = model['ema' if model.get('ema') else 'model']  # prefer 'ema' weights when present and truthy

    hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create on the source model's device
    hub_model.load_state_dict(model.float().state_dict())  # load state_dict as FP32
    hub_model.names = model.names  # class names
    if autoshape:
        hub_model = hub_model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
    device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
    return hub_model.to(device)


def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-small model https://github.com/ultralytics/yolov5

    Forwards all arguments unchanged to create() with name 'yolov5s'.
    """
    return create(name='yolov5s', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)


def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-medium model https://github.com/ultralytics/yolov5

    Forwards all arguments unchanged to create() with name 'yolov5m'.
    """
    return create(name='yolov5m', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)


def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-large model https://github.com/ultralytics/yolov5

    Forwards all arguments unchanged to create() with name 'yolov5l'.
    """
    return create(name='yolov5l', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)


def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-xlarge model https://github.com/ultralytics/yolov5

    Forwards all arguments unchanged to create() with name 'yolov5x'.
    """
    return create(name='yolov5x', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)


def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-small model, '6' variant https://github.com/ultralytics/yolov5

    Distinct from yolov5s: builds from yolov5s6.yaml and loads yolov5s6.pt.
    Forwards all arguments unchanged to create() with name 'yolov5s6'.
    """
    return create(name='yolov5s6', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)


def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-medium model, '6' variant https://github.com/ultralytics/yolov5

    Distinct from yolov5m: builds from yolov5m6.yaml and loads yolov5m6.pt.
    Forwards all arguments unchanged to create() with name 'yolov5m6'.
    """
    return create(name='yolov5m6', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)


def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-large model, '6' variant https://github.com/ultralytics/yolov5

    Distinct from yolov5l: builds from yolov5l6.yaml and loads yolov5l6.pt.
    Forwards all arguments unchanged to create() with name 'yolov5l6'.
    """
    return create(name='yolov5l6', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)


def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-xlarge model, '6' variant https://github.com/ultralytics/yolov5

    Distinct from yolov5x: builds from yolov5x6.yaml and loads yolov5x6.pt.
    Forwards all arguments unchanged to create() with name 'yolov5x6'.
    """
    return create(name='yolov5x6', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)


if __name__ == '__main__':
    # Pretrained example. For a custom checkpoint use instead:
    #   model = custom(path_or_model='path/to/model.pt')
    model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True)

    # Smoke-test inference on one input of each supported kind
    import numpy as np
    from PIL import Image

    pil_img = Image.open('data/images/bus.jpg')
    file_img = 'data/images/zidane.jpg'
    uri_img = 'https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg'
    np_img = np.zeros((640, 480, 3))
    imgs = [pil_img, file_img, uri_img, np_img]

    results = model(imgs)  # one batched forward pass over all four inputs
    results.print()
    results.save()