File size: 1,904 Bytes
c09670c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69a1e34
c09670c
 
 
69a1e34
c09670c
 
69a1e34
c09670c
 
 
 
 
 
69a1e34
c09670c
 
1272aeb
69a1e34
 
 
c09670c
 
69a1e34
1272aeb
 
69a1e34
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os, timeit
from openvino.runtime import Core, CompiledModel
from openvino.tools.mo import convert_model
from PIL import Image
from super_gradients.common.object_names import Models
from super_gradients.training import models as Sg_models
from torch import nn
from torchvision import transforms as T

if __name__ == '__main__':
    # Benchmark a YOLO-NAS-S checkpoint twice: natively through super-gradients
    # and after conversion to an OpenVINO compiled model. Per-image latency is
    # printed for each backend, followed by the total number of detections.
    models = []
    sg_model = Sg_models.get(
        Models.YOLO_NAS_S,
        checkpoint_path='checkpoints/sg_yolonas/yolo-nas-plastic/RUN_20231025_154945_419512/ckpt_latest.pth',
        num_classes=2,
        checkpoint_num_classes=2)
    models.append(sg_model)

    # Collect every .bmp (any case) in base_dir. Sorting makes the image order,
    # and therefore the printed timings, reproducible between runs; the order
    # returned by os.listdir is arbitrary.
    images = []
    base_dir = 'test_images'
    for f in sorted(os.listdir(base_dir)):
        if not f.lower().endswith('.bmp'):
            continue
        with Image.open(os.path.join(base_dir, f)) as raw_image:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the file is closed; resize((width, height)) -> 640x480.
            image = raw_image.convert('RGB').resize((640, 480))
        images.append(image)
    if not images:
        # The OpenVINO conversion below needs one real image as example input.
        raise SystemExit(f'No .bmp images found in {base_dir!r}')

    # Convert to OpenVINO
    to_tensor = T.ToTensor()
    print('Converting to OpenVINO...')
    core = Core()
    # Trace in inference mode so normalization/dropout layers are exported with
    # their inference behaviour.
    # NOTE(review): YOLO-NAS may also need sg_model.prep_model_for_conversion()
    # before export — confirm against the super-gradients export docs.
    sg_model.eval()
    # Example input is a batch of one CHW float tensor (ToTensor output).
    ov_model = convert_model(sg_model, example_input=to_tensor(images[0]).unsqueeze(0))
    compiled_model = core.compile_model(ov_model, 'AUTO')
    models.append(compiled_model)
    print('Converted to OpenVINO.')

    for model in models:
        print(type(model))
        count = 0
        # NOTE: the first iteration per backend includes warm-up cost, so its
        # timing is typically higher than the steady-state ones.
        for image in images:
            start_time = timeit.default_timer()
            if isinstance(model, CompiledModel):
                preds = model(to_tensor(image).unsqueeze(0))
                # TODO: Decode model output (count stays 0 for OpenVINO until then)
                # refer super_gradients.training.pipelines.pipelines -> DetectionPipeline._decode_model_output
            elif isinstance(model, Sg_models.SgModule):
                # predict() runs the full pipeline (preprocess, forward, decode).
                preds = model.predict(image)
                count += len(preds[0].prediction)
            elif isinstance(model, nn.Module):
                # A plain torch module expects a batched tensor, not a PIL image.
                preds = model(to_tensor(image).unsqueeze(0))
            # default_timer() returns seconds; scale by 1000 for milliseconds.
            elapsed_ms = (timeit.default_timer() - start_time) * 1000
            print(f'Time: {elapsed_ms:.3f}ms')
        print(f'Count: {count}')