SAM-SLR-V1 / utils /model.py
votuongquan2004@gmail.com
init
e317171
raw
history blame contribute delete
No virus
1.09 kB
import torch
import numpy as np
import onnxruntime as ort
def get_predictions(
    inputs: np.ndarray,
    ort_session: ort.InferenceSession,
    id2gloss: dict,
    k: int = 3,
) -> list:
    '''
    Get the top-k predictions.

    Parameters
    ----------
    inputs : np.ndarray
        Model inputs, fed to the session under the input name ``'x'``.
        If None, an empty list is returned without running the session.
    ort_session : ort.InferenceSession
        ONNX Runtime session.
    id2gloss : dict
        Mapping from class index (int) to class label.
    k : int, optional
        Number of predictions to return, by default 3. Clamped to the
        number of classes in the model output.

    Returns
    -------
    list
        Up to ``k`` dicts of the form ``{'label': ..., 'score': float}``,
        ordered by descending logit. Scores are a softmax over the top-k
        logits only, so they sum to 1 across the returned entries; they are
        not probabilities over the full class set.
    '''
    if inputs is None:
        return []

    # First session output. topk(dim=1) below expects a (batch, num_classes)
    # layout; presumably batch == 1 here — TODO confirm against callers.
    logits = torch.from_numpy(ort_session.run(None, {'x': inputs})[0])

    # torch.topk raises when k exceeds the class count; return what exists.
    k = min(k, logits.shape[-1])
    topk_scores, topk_indices = torch.topk(logits, k, dim=1)
    # Renormalise among the top-k logits only (see Returns).
    topk_scores = torch.nn.functional.softmax(topk_scores, dim=1)

    # squeeze() drops the batch axis, but when k == 1 it also drops the class
    # axis and leaves a 0-d array that cannot be indexed; atleast_1d restores it.
    scores = np.atleast_1d(topk_scores.squeeze().numpy())
    indices = np.atleast_1d(topk_indices.squeeze().numpy())

    # Plain Python int/float so dict lookup uses native keys and the result
    # is JSON-serialisable.
    return [
        {
            'label': id2gloss[int(idx)],
            'score': float(score),
        }
        for idx, score in zip(indices, scores)
    ]