| |
| |
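"""
AXEngine Bird Classification Inference Script (Top-K)
Loads a compiled bird-classification model (default: bird_rec.axmodel) through
axengine's ONNX Runtime-style InferenceSession and reports the Top-K
predictions for a single image or for every image in a directory.
Runs on the AxEngineExecutionProvider.
"""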
| import os |
| import argparse |
| import numpy as np |
| import cv2 |
| from PIL import Image |
| import axengine as axe |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| from tqdm import tqdm |
| |
# Font fallback chain for figure text, and a plain ASCII hyphen (instead of the
# Unicode minus sign) for negative tick labels.
plt.rcParams.update({
    'font.sans-serif': ['DejaVu Sans', 'Arial', 'sans-serif'],
    'axes.unicode_minus': False,
})
|
|
class BirdPredictorONNX:
    """Bird classification predictor running a compiled model through axengine.

    axengine exposes an ONNX Runtime-style ``InferenceSession`` API
    (``get_inputs`` / ``run``); the class name is kept for compatibility.
    """

    # Lower-case file extensions picked up by predict_batch_topk.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')

    def __init__(self, class_name_file, model_file, mean, std, image_size=224):
        """Load the class names and create the inference session.

        Always uses the AxEngineExecutionProvider.

        :param class_name_file: text file with one class name per line; the line
            order is assumed to match the model's output indices.
        :param model_file: model path handed to ``axe.InferenceSession``.
        :param mean: per-channel RGB mean (3 values), see get_transform_params.
        :param std: per-channel RGB std (3 values), see get_transform_params.
        :param image_size: side length images are resized to before inference.
        :raises Exception: re-raises whatever ``axe.InferenceSession`` raises.
        """
        self.rgb_mean = mean
        self.rgb_std = std
        self.image_size = image_size
        self.classes = self.load_classes(class_name_file)
        print(f"build predictor with {model_file}...")
        providers = ['AxEngineExecutionProvider']
        print(f"Loading ONNX model with providers: {providers}")

        try:
            self.session = axe.InferenceSession(model_file, providers=providers)
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise

        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        self.input_shape = model_input.shape

        self.transform = self.get_transform_params()

    def load_classes(self, class_name_file):
        """Return the non-blank, whitespace-stripped lines of *class_name_file*.

        The file is read as ``utf-8-sig`` so a leading byte-order mark (common for
        files saved on Windows) is not glued onto the first class name.
        """
        with open(class_name_file, 'r', encoding='utf-8-sig') as f:
            return [line.strip() for line in f if line.strip()]

    def get_transform_params(self):
        """Return ``rgb_mean`` / ``rgb_std`` as float32 arrays of shape (1, 3, 1, 1).

        NOTE(review): preprocess_image feeds raw uint8 pixels and never applies
        these values; presumably normalization is compiled into the model --
        confirm before relying on them.
        """
        shape = (1, 3, 1, 1)
        return {
            'mean': np.asarray(self.rgb_mean, dtype=np.float32).reshape(shape),
            'std': np.asarray(self.rgb_std, dtype=np.float32).reshape(shape),
        }

    def preprocess_image(self, image_path):
        """Load *image_path* as a uint8 RGB array of shape (1, 3, size, size).

        The image is resized with bicubic interpolation to ``image_size`` on both
        sides. No mean/std normalization is applied here.
        """
        size = int(self.image_size)
        # Context manager closes the underlying file handle (multi-frame formats
        # such as GIF/TIFF otherwise keep it open).
        with Image.open(image_path) as img:
            rgb = img.convert('RGB').resize((size, size), Image.BICUBIC)
        pixels = np.asarray(rgb, dtype=np.uint8)  # (H, W, 3)
        # The HWC -> CHW transpose is a strided view; copy it into a packed
        # C-contiguous buffer before handing it to the runtime.
        return np.ascontiguousarray(pixels.transpose(2, 0, 1))[np.newaxis]

    @staticmethod
    def _softmax(logits):
        """Numerically stable softmax over a 1-D array."""
        exp_scores = np.exp(logits - np.max(logits))
        return exp_scores / exp_scores.sum()

    def predict_image_topk(self, image_path, k=5):
        """Run inference on one image and return its top-*k* predictions.

        :return: list of ``(class_name, probability)`` tuples, most likely first;
            probabilities are a softmax over the first model output.
        :raises IndexError: if the model predicts an index beyond the class list.
        """
        input_data = self.preprocess_image(image_path)
        outputs = self.session.run(None, {self.input_name: input_data})

        # The input batch is always 1, so flattening accepts both (1, C) and
        # (C,)-shaped outputs.
        logits = np.asarray(outputs[0], dtype=np.float64).reshape(-1)
        probs = self._softmax(logits)

        k = max(0, int(k))  # a negative k would otherwise slice "all but last"
        top_k_indices = np.argsort(probs)[::-1][:k]
        return [(self.classes[idx], float(probs[idx])) for idx in top_k_indices]

    def predict_batch_topk(self, image_dir, k=5):
        """Run predict_image_topk on every image file directly inside *image_dir*.

        Files are matched by extension (case-insensitive) and processed in sorted
        order. Images that fail are reported and skipped, so the returned list
        may be shorter than the number of files found.

        :return: list of dicts with keys ``filename``, ``path`` and ``top_k``.
        """
        files = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(image_dir, name))
        )
        print(f"Found {len(files)} images, starting inference (Top-{k})...")

        results = []
        for filename in tqdm(files):
            image_path = os.path.join(image_dir, filename)
            try:
                top_k_results = self.predict_image_topk(image_path, k=k)
            except Exception as e:
                # Best-effort batch: report and move on. tqdm.write keeps the
                # progress bar from being garbled by the message.
                tqdm.write(f"Error processing image {filename}: {str(e)}")
                continue
            results.append({
                'filename': filename,
                'path': image_path,
                'top_k': top_k_results,
            })

        return results

    def _wrap_text(self, text, max_chars=25):
        """Wrap *text* into lines of at most *max_chars* characters for table cells.

        Each break goes right after the last ``_`` or ``-`` inside the current
        window; if there is none, the line is cut at ``max_chars // 2`` and a
        ``-`` is appended to mark the split. Breaking repeats until the remainder
        fits, so very long names no longer overflow on the second line.
        """
        # max_chars < 2 would give a zero-width cut and never make progress.
        if max_chars < 2 or len(text) <= max_chars:
            return text

        lines = []
        rest = text
        while len(rest) > max_chars:
            window = rest[:max_chars]
            cut = max(window.rfind('_'), window.rfind('-'))
            if cut >= 0:
                lines.append(rest[:cut + 1])
                rest = rest[cut + 1:]
            else:
                mid = max_chars // 2
                lines.append(rest[:mid] + "-")
                rest = rest[mid:]
        lines.append(rest)
        return "\n".join(lines)

    def visualize_prediction_topk(self, image_path, top_k_results, save_path=None):
        """Save a figure showing the input image next to a table of predictions.

        :param image_path: image to display (read with OpenCV).
        :param top_k_results: ``(class_name, confidence)`` pairs, e.g. from
            predict_image_topk.
        :param save_path: output file; defaults to ``prediction_result_top5.png``.
        :raises ValueError: if there are no results or the image cannot be read.
        """
        if not top_k_results:
            raise ValueError("No predictions to visualize")
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        header = ["Rank", "Class Name", "Confidence"]
        rows = [
            [f"#{rank}",
             self._wrap_text(cls_name, max_chars=28),
             f"{conf:.4f} ({conf*100:.2f}%)"]
            for rank, (cls_name, conf) in enumerate(top_k_results, start=1)
        ]
        has_wrapped = any('\n' in row[1] for row in rows)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))
        try:
            ax1.imshow(image)
            ax1.set_title('Input Image', fontsize=14, fontweight='bold')
            ax1.axis('off')
            ax2.axis('off')

            table = ax2.table(cellText=rows,
                              colLabels=header,
                              loc='center',
                              cellLoc='left',
                              colWidths=[0.1, 0.6, 0.3],
                              bbox=[0.05, 0.1, 0.9, 0.75])
            table.auto_set_font_size(False)
            # Smaller font and taller rows when a class name spans several lines.
            table.set_fontsize(8 if has_wrapped else 10)
            table.scale(1, 1.8 if has_wrapped else 1.5)

            # Header row (row 0 holds the column labels).
            for col in range(len(header)):
                cell = table[(0, col)]
                cell.set_text_props(fontweight='bold', color='white', ha='center')
                cell.set_facecolor('#4472C4')

            # Data rows with alternating background.
            for row_idx in range(1, len(rows) + 1):
                for col in range(len(header)):
                    cell = table[(row_idx, col)]
                    cell.set_facecolor('#ffffff' if row_idx % 2 == 0 else '#f9f9f9')
                    cell.set_edgecolor('#dddddd')
                    cell.set_linewidth(1)
                    if col == 1:
                        cell.set_text_props(ha='left', va='top', wrap=True)
                    else:
                        cell.set_text_props(ha='center', va='center')

            display_path = image_path if len(image_path) <= 50 else "..." + image_path[-47:]
            ax2.text(0.5, 0.92, f"File Path:\n{display_path}",
                     ha='center', va='center', fontsize=9, color='#555555',
                     bbox=dict(boxstyle="round,pad=0.5", fc="#eeeeee", ec="#cccccc", alpha=0.8))

            # Title reflects the actual number of predictions instead of a fixed 5.
            ax2.set_title(f'Top-{len(rows)} Prediction Results',
                          fontsize=14, fontweight='bold', pad=20)

            fig.tight_layout()
            out_path = save_path if save_path else 'prediction_result_top5.png'
            fig.savefig(out_path, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving fails.
            plt.close(fig)
        print(f"Result saved to: {out_path}")

    @staticmethod
    def _class_from_filename(filename):
        """Derive a class name from *filename* for the no-ground-truth fallback.

        The extension is dropped and the stem is cut at the first ``_``; if there
        is no ``_``, at the first ``-``; if neither, at the first space.
        """
        stem = os.path.splitext(filename)[0]
        for sep in ('_', '-', ' '):
            if sep in stem:
                return stem.split(sep, 1)[0].strip()
        return stem.strip()

    def calculate_batch_accuracy(self, batch_results, ground_truth_mapping=None):
        """Compute top-1 and top-5 accuracy over batch results.

        :param batch_results: list returned by predict_batch_topk.
        :param ground_truth_mapping: optional dict ``{filename: true_class_name}``.
            When omitted, the true class is derived from each filename (see
            _class_from_filename). Files missing from a supplied mapping count
            as incorrect.
        :return: dict with ``total_images``, ``correct_top1``, ``correct_top5``,
            ``top1_acc`` / ``top5_acc`` (rounded to 4 places) and the percentage
            strings ``top1_acc_pct`` / ``top5_acc_pct``. All keys are present even
            for an empty batch.

        Both metrics use exact string equality, so top-1 can never exceed top-5.
        "Top-5" means the first five entries of each ``top_k`` list; when k < 5
        it effectively measures top-k.
        """
        total = len(batch_results) if batch_results else 0
        correct_top1 = 0
        correct_top5 = 0

        for res in batch_results or []:
            filename = res['filename']
            if ground_truth_mapping is not None:
                true_class = ground_truth_mapping.get(filename, "")
            else:
                true_class = self._class_from_filename(filename)
            if not true_class:
                continue

            top_k_preds = [cls_name for cls_name, _ in res['top_k']]
            # Slicing (not indexing) tolerates an empty prediction list.
            if true_class in top_k_preds[:1]:
                correct_top1 += 1
            if true_class in top_k_preds[:5]:
                correct_top5 += 1

        top1_acc = correct_top1 / total if total > 0 else 0.0
        top5_acc = correct_top5 / total if total > 0 else 0.0

        return {
            "total_images": total,
            "correct_top1": correct_top1,
            "correct_top5": correct_top5,
            "top1_acc": round(top1_acc, 4),
            "top5_acc": round(top5_acc, 4),
            "top1_acc_pct": f"{top1_acc*100:.2f}%",
            "top5_acc_pct": f"{top5_acc*100:.2f}%",
        }
|
|
def _load_ground_truth(gt_file):
    """Parse a ground-truth file into ``{filename: class_name}``.

    Each non-blank line is ``<filename> <class name>``; the class name is
    everything after the first run of whitespace, so it may contain spaces.
    Lines with a single field are ignored. Read as ``utf-8-sig`` so a leading
    byte-order mark does not corrupt the first filename.
    """
    ground_truth = {}
    with open(gt_file, 'r', encoding='utf-8-sig') as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if len(parts) == 2:
                ground_truth[parts[0]] = parts[1]
    return ground_truth


def main():
    """Command-line entry point: Top-K inference on one image or a directory."""
    parser = argparse.ArgumentParser(description="ONNX Runtime Bird Classification (Top-5)")
    parser.add_argument("-c", "--class_map_file",
                        default="./class_name.txt",
                        help="Path to class-name file (one class per line)")
    parser.add_argument("-m", "--model_file",
                        default="./bird_rec.axmodel",
                        help="Path to model file")
    parser.add_argument("-imgsz", "--image_size",
                        type=int,
                        default=224,
                        help="Input image size")
    parser.add_argument("-mean", "--mean",
                        type=float,
                        nargs='+',
                        default=[0.485, 0.456, 0.406],
                        help="Mean normalization values")
    parser.add_argument("-std", "--std",
                        type=float,
                        nargs='+',
                        default=[0.229, 0.224, 0.225],
                        help="Standard deviation normalization values")
    parser.add_argument("--image_dir",
                        default="./valid",
                        help="Directory containing test images")
    parser.add_argument("--image",
                        help="Path to a single test image")
    parser.add_argument("--top_k",
                        type=int,
                        default=5,
                        help="Number of top predictions to show (default: 5)")
    # Optional ground-truth file; each line is '<filename> <true class name>'.
    parser.add_argument("--gt_file",
                        default='./val_list_flat.txt',
                        help="可选,真实标签文件路径,格式:每行 '文件名 真实类别名'")

    args = parser.parse_args()

    # The predictor reshapes mean/std to (1, 3, 1, 1): exactly one value per RGB channel.
    for name in ("mean", "std"):
        if len(getattr(args, name)) != 3:
            parser.error(f"--{name} expects exactly 3 values (R G B)")

    # Validate paths before loading the model. An explicit --image that does not
    # exist must not silently fall back to batch mode on --image_dir.
    if args.image and not os.path.exists(args.image):
        print(f"Specified image not found: {args.image}")
        return
    if not args.image and not os.path.exists(args.image_dir):
        print("Specified image or directory not found.")
        return

    ground_truth = None
    if args.gt_file and os.path.exists(args.gt_file):
        ground_truth = _load_ground_truth(args.gt_file)
        print(f"Ground truth loaded from {args.gt_file}")

    predictor = BirdPredictorONNX(args.class_map_file, args.model_file, args.mean, args.std, args.image_size)

    if args.image:
        try:
            top_k_results = predictor.predict_image_topk(args.image, k=args.top_k)
        except Exception as e:
            print(f"Inference failed: {e}")
            return

        print(f"\nImage: {args.image}")
        print(f"Top-{args.top_k} Predictions:")
        for rank, (cls_name, conf) in enumerate(top_k_results, start=1):
            print(f"#{rank}: {cls_name} ({conf:.4f})")
        return

    results = predictor.predict_batch_topk(args.image_dir, k=args.top_k)
    accuracy_stats = predictor.calculate_batch_accuracy(results, ground_truth)

    # Batch summary: images processed, top-1 / top-5 correct counts and accuracy.
    print("\n=== 批量推理结果汇总 ===")
    print(f"总处理图片数: {accuracy_stats['total_images']}")
    print(f"Top1正确数: {accuracy_stats['correct_top1']} | Top1准确率: {accuracy_stats['top1_acc_pct']}")
    print(f"Top5正确数: {accuracy_stats['correct_top5']} | Top5准确率: {accuracy_stats['top5_acc_pct']}")
    print("========================\n")


if __name__ == "__main__":
    main()