# SAM-SLR-V1 / app.py
import time
import torch
import pandas as pd
import gradio as gr
import onnxruntime as ort
from mediapipe.python.solutions import holistic
from utils.model import get_predictions
from utils.data import preprocess
title = '''
'''
cite_markdown = '''
'''
description = '''
'''
examples = []
# Load the ONNX inference session for the joint-based SAM-SLR model.
ort_session = ort.InferenceSession('VSL_SAM_SLR_V2_joint.onnx')
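# With the CPU build of onnxruntime this defaults to CPUExecutionProvider;
# pass providers=[...] explicitly if a GPU execution provider should be used.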
# Load id-to-gloss mapping.
id2gloss = pd.read_csv('gloss.csv', names=['id', 'gloss']).to_dict()['gloss']
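# gloss.csv is assumed to list one `id,gloss` pair per line in class-index order,
# so the dictionary maps each class index to its Vietnamese gloss.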
def inference(
    video: str,
    progress: gr.Progress = gr.Progress(),
) -> str:
    '''
    Video-based inference for Vietnamese Sign Language recognition.

    Parameters
    ----------
    video : str
        The path to the video.
    progress : gr.Progress, optional
        The progress bar, by default gr.Progress().

    Returns
    -------
    str
        The inference message.
    '''
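    # MediaPipe Holistic provides per-frame pose, hand, and face landmarks.
    # model_complexity=2 selects the largest (most accurate, slowest) pose model;
    # refine_face_landmarks adds the refined mesh around the eyes and lips.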
    keypoints_detector = holistic.Holistic(
        static_image_mode=False,
        model_complexity=2,
        enable_segmentation=True,
        refine_face_landmarks=True,
    )
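    # Stage 1: run the keypoint detector over the video and assemble the model input.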
    progress(0, desc='Preprocessing video')
    start_time = time.time()
    inputs = preprocess(
        source=video,
        keypoints_detector=keypoints_detector,
    )
    end_time = time.time()
    data_time = end_time - start_time
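    # Stage 2: run the ONNX model on the keypoint sequence and decode the top-3 glosses.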
    progress(1/2, desc='Getting predictions')
    start_time = time.time()
    predictions = get_predictions(
        inputs=inputs, ort_session=ort_session, id2gloss=id2gloss, k=3
    )
    end_time = time.time()
    model_time = end_time - start_time
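    # Assemble the user-facing message: predictions (if any) plus a timing breakdown.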
    if len(predictions) == 0:
        output_message = 'No sign language detected in the video. Please try again.'
    else:
        output_message = 'The top-3 predictions are:\n'
        for i, prediction in enumerate(predictions):
            output_message += f'\t{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
    output_message += f'Data processing time: {data_time:.2f} seconds\n'
    output_message += f'Model inference time: {model_time:.2f} seconds\n'
    output_message += f'Total time: {data_time + model_time:.2f} seconds'
    output_message += f'\nInput shape: {inputs.shape}'

    progress(1, desc='Completed')
    return output_message
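# Expose the model as a simple Gradio demo: a video goes in, the predicted glosses come out.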
iface = gr.Interface(
    fn=inference,
    inputs='video',
    outputs='text',
    examples=examples,
    title=title,
    description=description,
)
iface.launch()
# print(inference('000_con_cho.mp4'))