SignLanguage-pro / tools /models.py
thienphuc12339's picture
Update tools/models.py
74050d9 verified
# tools/models.py
import torch
import logging
import onnxruntime as ort
from time import time
from typing import Union
from configs import ModelConfig, InferenceConfig
from visualization import draw_text_on_image
from pipelines import VideoClassificationPipeline # Nếu cần thiết
import numpy as np
class Predictions:
def __init__(
self,
predictions: list[dict] = None,
inference_time: float = 0,
start_time: float = 0,
end_time: float = 0,
) -> None:
self.predictions = predictions
self.inference_time = inference_time
self.start_time = start_time
self.end_time = end_time
def visualize(
self,
frame: np.ndarray,
position: tuple = (20, 100),
prefix: str = "Predictions",
color: tuple = (0, 0, 255),
) -> np.ndarray:
text = prefix + ": " + self.get_pred_message()
return draw_text_on_image(
image=frame,
text=text,
position=position,
color=color,
font_size=20,
)
def get_pred_message(self) -> str:
if not any((
self.start_time,
self.end_time,
self.inference_time,
self.predictions
)):
return ""
return ', '.join(
[
f"{pred['gloss']} ({pred['score']*100:.2f}%)"
for pred in self.predictions
]
)
def __str__(self) -> str:
if not any((
self.start_time,
self.end_time,
self.inference_time,
self.predictions
)):
return ""
predictions = self.get_pred_message()
message = "Sample start: {:.2f}s - end: {:.2f}s | Runtime: {:.2f}s | Predictions: {}"
return message.format(self.start_time, self.end_time, self.inference_time, predictions)
def merge_results(self, results: dict = None) -> dict:
if results is None:
results = {
"start_time": [],
"end_time": [],
"inference_time": [],
"prediction": [],
}
results["start_time"].append(self.start_time)
results["end_time"].append(self.end_time)
results["inference_time"].append(self.inference_time)
results["prediction"].append(self.predictions)
return results
def load_model(
model_config: ModelConfig,
inference_config: InferenceConfig,
label2id: dict = None,
id2label: dict = None,
) -> ort.InferenceSession:
'''
Tải mô hình ONNX sử dụng onnxruntime.
'''
try:
session = ort.InferenceSession(model_config.pretrained)
logging.info(f"ONNX model loaded from {model_config.pretrained}")
except Exception as e:
logging.error(f"Failed to load ONNX model: {e}")
raise e
return session
def load_pipeline(
model_config: ModelConfig,
inference_config: InferenceConfig,
) -> ort.InferenceSession:
'''
Tải onnxruntime session dựa trên cấu hình mô hình.
'''
session = load_model(model_config, inference_config)
return session
def preprocess_inputs_onnx(inputs: np.ndarray, processor=None) -> dict:
'''
Chuyển đổi đầu vào cho mô hình ONNX nếu cần.
Bạn có thể thêm các bước tiền xử lý cụ thể ở đây nếu cần.
'''
# Ví dụ: Đảm bảo rằng đầu vào có định dạng phù hợp
# inputs = processor(inputs) # Nếu cần thiết
return {"pixel_values": inputs.astype(np.float32)} # Điều chỉnh tùy thuộc vào yêu cầu của mô hình
def get_predictions(
inputs: np.ndarray,
model: ort.InferenceSession,
id2gloss: dict,
k: int = 3,
) -> Predictions:
'''
Lấy top-k dự đoán từ mô hình ONNX.
Parameters
----------
inputs : np.ndarray
Dữ liệu đầu vào đã được tiền xử lý.
model : ort.InferenceSession
Mô hình ONNX đã được tải.
id2gloss : dict
Bản đồ từ ID lớp sang gloss.
k : int, optional
Số lượng dự đoán cần trả về, mặc định là 3.
Returns
-------
Predictions
Đối tượng chứa các dự đoán và thời gian suy luận.
'''
if inputs is None:
return Predictions()
# Tiền xử lý đầu vào cho ONNX
preprocessed_inputs = preprocess_inputs_onnx(inputs)
# Lấy logits
start_time = time()
try:
logits = model.run(None, preprocessed_inputs)[0]
except Exception as e:
logging.error(f"Error during ONNX inference: {e}")
raise e
inference_time = time() - start_time
logits = torch.from_numpy(logits)
# Lấy top-k dự đoán
topk_scores, topk_indices = torch.topk(logits, k, dim=1)
topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy()
topk_indices = topk_indices.squeeze().detach().numpy()
predictions = []
for i in range(k):
class_idx = str(topk_indices[i])
gloss = id2gloss.get(class_idx, "Unknown")
score = topk_scores[i]
predictions.append({
'gloss': gloss,
'score': score,
})
return Predictions(predictions=predictions, inference_time=inference_time)