Spaces:
Runtime error
Runtime error
| # tools/models.py | |
| import torch | |
| import logging | |
| import onnxruntime as ort | |
| from time import time | |
| from typing import Union | |
| from configs import ModelConfig, InferenceConfig | |
| from visualization import draw_text_on_image | |
| from pipelines import VideoClassificationPipeline # Nếu cần thiết | |
| import numpy as np | |
| class Predictions: | |
| def __init__( | |
| self, | |
| predictions: list[dict] = None, | |
| inference_time: float = 0, | |
| start_time: float = 0, | |
| end_time: float = 0, | |
| ) -> None: | |
| self.predictions = predictions | |
| self.inference_time = inference_time | |
| self.start_time = start_time | |
| self.end_time = end_time | |
| def visualize( | |
| self, | |
| frame: np.ndarray, | |
| position: tuple = (20, 100), | |
| prefix: str = "Predictions", | |
| color: tuple = (0, 0, 255), | |
| ) -> np.ndarray: | |
| text = prefix + ": " + self.get_pred_message() | |
| return draw_text_on_image( | |
| image=frame, | |
| text=text, | |
| position=position, | |
| color=color, | |
| font_size=20, | |
| ) | |
| def get_pred_message(self) -> str: | |
| if not any(( | |
| self.start_time, | |
| self.end_time, | |
| self.inference_time, | |
| self.predictions | |
| )): | |
| return "" | |
| return ', '.join( | |
| [ | |
| f"{pred['gloss']} ({pred['score']*100:.2f}%)" | |
| for pred in self.predictions | |
| ] | |
| ) | |
| def __str__(self) -> str: | |
| if not any(( | |
| self.start_time, | |
| self.end_time, | |
| self.inference_time, | |
| self.predictions | |
| )): | |
| return "" | |
| predictions = self.get_pred_message() | |
| message = "Sample start: {:.2f}s - end: {:.2f}s | Runtime: {:.2f}s | Predictions: {}" | |
| return message.format(self.start_time, self.end_time, self.inference_time, predictions) | |
| def merge_results(self, results: dict = None) -> dict: | |
| if results is None: | |
| results = { | |
| "start_time": [], | |
| "end_time": [], | |
| "inference_time": [], | |
| "prediction": [], | |
| } | |
| results["start_time"].append(self.start_time) | |
| results["end_time"].append(self.end_time) | |
| results["inference_time"].append(self.inference_time) | |
| results["prediction"].append(self.predictions) | |
| return results | |
| def load_model( | |
| model_config: ModelConfig, | |
| inference_config: InferenceConfig, | |
| label2id: dict = None, | |
| id2label: dict = None, | |
| ) -> ort.InferenceSession: | |
| ''' | |
| Tải mô hình ONNX sử dụng onnxruntime. | |
| ''' | |
| try: | |
| session = ort.InferenceSession(model_config.pretrained) | |
| logging.info(f"ONNX model loaded from {model_config.pretrained}") | |
| except Exception as e: | |
| logging.error(f"Failed to load ONNX model: {e}") | |
| raise e | |
| return session | |
| def load_pipeline( | |
| model_config: ModelConfig, | |
| inference_config: InferenceConfig, | |
| ) -> ort.InferenceSession: | |
| ''' | |
| Tải onnxruntime session dựa trên cấu hình mô hình. | |
| ''' | |
| session = load_model(model_config, inference_config) | |
| return session | |
| def preprocess_inputs_onnx(inputs: np.ndarray, processor=None) -> dict: | |
| ''' | |
| Chuyển đổi đầu vào cho mô hình ONNX nếu cần. | |
| Bạn có thể thêm các bước tiền xử lý cụ thể ở đây nếu cần. | |
| ''' | |
| # Ví dụ: Đảm bảo rằng đầu vào có định dạng phù hợp | |
| # inputs = processor(inputs) # Nếu cần thiết | |
| return {"pixel_values": inputs.astype(np.float32)} # Điều chỉnh tùy thuộc vào yêu cầu của mô hình | |
| def get_predictions( | |
| inputs: np.ndarray, | |
| model: ort.InferenceSession, | |
| id2gloss: dict, | |
| k: int = 3, | |
| ) -> Predictions: | |
| ''' | |
| Lấy top-k dự đoán từ mô hình ONNX. | |
| Parameters | |
| ---------- | |
| inputs : np.ndarray | |
| Dữ liệu đầu vào đã được tiền xử lý. | |
| model : ort.InferenceSession | |
| Mô hình ONNX đã được tải. | |
| id2gloss : dict | |
| Bản đồ từ ID lớp sang gloss. | |
| k : int, optional | |
| Số lượng dự đoán cần trả về, mặc định là 3. | |
| Returns | |
| ------- | |
| Predictions | |
| Đối tượng chứa các dự đoán và thời gian suy luận. | |
| ''' | |
| if inputs is None: | |
| return Predictions() | |
| # Tiền xử lý đầu vào cho ONNX | |
| preprocessed_inputs = preprocess_inputs_onnx(inputs) | |
| # Lấy logits | |
| start_time = time() | |
| try: | |
| logits = model.run(None, preprocessed_inputs)[0] | |
| except Exception as e: | |
| logging.error(f"Error during ONNX inference: {e}") | |
| raise e | |
| inference_time = time() - start_time | |
| logits = torch.from_numpy(logits) | |
| # Lấy top-k dự đoán | |
| topk_scores, topk_indices = torch.topk(logits, k, dim=1) | |
| topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy() | |
| topk_indices = topk_indices.squeeze().detach().numpy() | |
| predictions = [] | |
| for i in range(k): | |
| class_idx = str(topk_indices[i]) | |
| gloss = id2gloss.get(class_idx, "Unknown") | |
| score = topk_scores[i] | |
| predictions.append({ | |
| 'gloss': gloss, | |
| 'score': score, | |
| }) | |
| return Predictions(predictions=predictions, inference_time=inference_time) | |