File size: 2,503 Bytes
bdd9241
 
 
 
 
 
 
 
 
 
 
 
bb3e852
 
e69d1ae
 
bb3e852
 
bdd9241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e69d1ae
bdd9241
 
 
5d8ed5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdd9241
 
 
 
5d8ed5e
 
 
bdd9241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import numpy as np

def _rect_anchor(x, y, dx, dy, yaw):
    """Return the ``xy`` anchor for a ``dx`` x ``dy`` Rectangle rotated by ``yaw``
    radians so that its centre ends up at ``(x, y)``.

    ``matplotlib.patches.Rectangle`` rotates about its ``xy`` anchor (the
    un-rotated lower-left corner) by default, so the anchor must be the centre
    minus the *rotated* half-extent vector, not the axis-aligned one.
    """
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    half_dx, half_dy = dx / 2.0, dy / 2.0
    return (
        x - (half_dx * cos_yaw - half_dy * sin_yaw),
        y - (half_dx * sin_yaw + half_dy * cos_yaw),
    )


def plot_bev_detections(
    boxes_3d: torch.Tensor,
    scores: torch.Tensor,
    labels: torch.Tensor,
    score_thresh: float = 0.1,
    save_path: Optional[str] = None,
    skip_labels: Sequence[int] = (1,),
    car_length_scale: float = 1.2,
):
    """Plot 3D detections as rotated rectangles in a bird's-eye-view (BEV) figure.

    Args:
        boxes_3d: One box per row. The first seven columns are read as
            ``(x, y, z, dx, dy, dz, yaw)``; extra columns are ignored.
            ``yaw`` is converted with ``np.degrees``, i.e. assumed to be in
            radians — confirm against the producing detector's convention.
        scores: One confidence score per box.
        labels: One integer class index per box, indexing ``class_names`` below.
        score_thresh: Boxes with ``score < score_thresh`` are not drawn.
        save_path: If given, the figure is saved there and closed; otherwise
            ``plt.show()`` is called.
        skip_labels: Class indices that are never drawn. The default ``(1,)``
            (``'construction_vehicle'``) reproduces the previous hard-coded filter.
        car_length_scale: Factor applied to ``dx`` of ``'car'`` boxes before
            drawing; pass ``1.0`` to draw cars at their true length.
    """
    class_names = [
        'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'pedestrian',
        'motorcycle', 'bicycle', 'traffic_cone', 'car'
    ]
    cmap = plt.get_cmap('tab10')  # 10 distinct colours; class indices wrap modulo 10
    skip = {int(idx) for idx in skip_labels}

    # 1) Create figure & axes
    fig, ax = plt.subplots(figsize=(12, 12))

    # 2) Ego vehicle: a fixed 2 x 2 square centred on the origin.
    ego_patch = patches.Rectangle(
        (-1, -1), 2, 2,
        linewidth=1,
        edgecolor='black',
        facecolor='gray',
        label='Ego Vehicle'
    )
    ax.add_patch(ego_patch)

    # 3) Filter by score, then copy to host memory once (not once per box).
    #    detach() allows inputs that still require grad.
    keep = scores >= score_thresh
    boxes_np = boxes_3d[keep].detach().cpu().numpy()
    labels_np = labels[keep].detach().cpu().numpy()

    # 4) Draw each box; remember which classes were drawn for the legend.
    seen = {}  # class index -> display name
    for box, label in zip(boxes_np, labels_np):
        cls_idx = int(label)
        if cls_idx in skip:
            continue

        x, y, _z, dx, dy, _dz, yaw, *_ = box
        if 0 <= cls_idx < len(class_names):
            cls_name = class_names[cls_idx]
        else:
            # Unknown index: draw it anyway instead of raising IndexError.
            cls_name = f'class_{cls_idx}'

        if cls_name == 'car':
            dx *= car_length_scale

        ax.add_patch(patches.Rectangle(
            _rect_anchor(x, y, dx, dy, yaw),  # rotate about centre, not corner
            dx, dy,
            angle=np.degrees(yaw),
            linewidth=1.5,
            edgecolor=cmap(cls_idx % 10),
            facecolor='none'
        ))
        seen[cls_idx] = cls_name

    # 5) Legend: ego vehicle plus one entry per class actually drawn.
    legend_handles = [ego_patch] + [
        patches.Patch(color=cmap(idx % 10), label=seen[idx]) for idx in sorted(seen)
    ]
    ax.legend(handles=legend_handles, loc='upper right')

    # 6) Axes limits and labels
    ax.set_xlim(-50, 50)
    ax.set_ylim(-50, 50)
    ax.set_xlabel('X (meters)')
    ax.set_ylabel('Y (meters)')
    ax.set_title('BEV Detections')

    # 7) Save or show
    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()