|
import torch |
|
from PIL import Image |
|
import cv2 |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import torchvision.transforms as T |
|
|
|
def _class_display_name(class_names, label_id):
    """Return the display name for integer ``label_id``.

    ``class_names`` may be a mapping or a sequence indexed by label id. When it
    is empty/None, or does not contain ``label_id``, fall back to
    ``"Class <id>"`` instead of raising.
    """
    if class_names:
        try:
            return class_names[label_id]
        except (KeyError, IndexError):
            pass  # unknown label: use the generic name below
    return f"Class {label_id}"


def inference_and_save(model, input_image_path, output_image_path, threshold=0.3, class_names=None,
                       image_size=(128, 128)):
    """Run a detection model on one image and save the image with its detections drawn.

    The input is loaded as single-channel grayscale ("L") and resized to
    ``image_size`` before inference. Boxes are drawn on that resized image, so
    box coordinates are interpreted in the resized frame.

    Args:
        model: Callable taking a ``(1, 1, H, W)`` float tensor (values in
            [0, 1], as produced by ``ToTensor``). It is assumed to return a
            sequence whose first element is a dict with ``'boxes'``,
            ``'scores'`` and ``'labels'`` tensors, i.e. torchvision detection
            output format; confirm against the trained model. The input
            tensor is placed on the CPU.
        input_image_path: Path of the image to read.
        output_image_path: Path the annotated figure is written to.
        threshold: Only detections with score strictly greater than this are drawn.
        class_names: Optional mapping or sequence from integer label to display
            name. Labels it does not contain are shown as ``"Class <id>"``.
        image_size: ``(width, height)`` the image is resized to (PIL order).
            Defaults to the previously hard-coded 128x128.
    """
    # Load as grayscale; the context manager releases the underlying file.
    with Image.open(input_image_path) as src:
        img = src.convert("L").resize(image_size)

    # ToTensor gives (1, H, W) for an "L" image; add the batch dimension.
    img_tensor = T.ToTensor()(img).unsqueeze(0)

    with torch.no_grad():
        predictions = model(img_tensor.to(torch.device('cpu')))

    # One result dict per input image; only one image was passed.
    detections = predictions[0]

    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        ax.imshow(np.array(img), cmap='gray')

        for box, score, label in zip(detections['boxes'], detections['scores'], detections['labels']):
            score = float(score)
            if score > threshold:
                # Keep sub-pixel precision: truncating to int can shift a box
                # by up to one pixel, which is visible on a small image drawn large.
                x1, y1, x2, y2 = box.tolist()
                ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                           fill=False, edgecolor='red', linewidth=2))

                class_name = _class_display_name(class_names, int(label))
                ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='white', alpha=0.5))

        ax.axis('off')
        fig.tight_layout()
        fig.savefig(output_image_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    print(f'Result saved at {output_image_path}')
|
|
|
|
|
import inspect

# Script entry: run the detector on one image and save the annotated result.
# NOTE(review): these statements execute at import time. If this module is ever
# imported for `inference_and_save`, wrap them in an
# `if __name__ == "__main__":` guard.

input_image_path = '1.png'
output_image_path = 'result_1.png'

model_path = 'road_best_model.pt'

# The checkpoint is a whole pickled nn.Module (we call .eval() on it), not a
# state_dict. PyTorch >= 2.6 defaults to torch.load(weights_only=True), which
# refuses to unpickle arbitrary classes, so opt out explicitly where supported.
# Older torch (< 1.13) has no such keyword and always unpickles fully.
# SECURITY: this runs pickle; only load checkpoints from a trusted source.
_load_kwargs = {'map_location': torch.device('cpu')}
if 'weights_only' in inspect.signature(torch.load).parameters:
    _load_kwargs['weights_only'] = False
model = torch.load(model_path, **_load_kwargs)
model.eval()

# NOTE(review): ids start at 0 here. If the model is a torchvision detector
# that reserves label 0 for background, these names are off by one.
# Confirm against the training label map.
class_names = {0: 'trafficlight', 1: 'speedlimit', 2: 'crosswalk', 3: 'stop'}

inference_and_save(model, input_image_path, output_image_path, threshold=0.3, class_names=class_names)