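'''Gradio demo app for Vietnamese Sign Language recognition using the SAM-SLR-V2 ONNX model.'''
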
import time
import pandas as pd
import gradio as gr
import onnxruntime as ort
from mediapipe.python.solutions import holistic
from utils.model import get_predictions
from utils.data import preprocess
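
# UI text and example videos for the Gradio interface, left empty as placeholders.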
title = '''
'''
cite_markdown = '''
'''
description = '''
'''
examples = []
# Load the ONNX model for inference.
ort_session = ort.InferenceSession('VSL_SAM_SLR_V2.onnx')
# Load the id-to-gloss mapping, keyed by the class id column.
id2gloss = pd.read_csv('gloss.csv', names=['id', 'gloss'], index_col='id')['gloss'].to_dict()

def inference(
video: str,
progress: gr.Progress = gr.Progress(),
) -> str:
    '''
    Video-based inference for Vietnamese Sign Language recognition.

    Parameters
    ----------
    video : str
        The path to the input video.
    progress : gr.Progress, optional
        The Gradio progress bar, by default gr.Progress().

    Returns
    -------
    str
        The formatted prediction message, including timing information.
    '''
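    # Create a MediaPipe Holistic detector per request: video mode
    # (static_image_mode=False) with the highest model complexity and
    # refined face landmarks for more precise keypoints.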
keypoints_detector = holistic.Holistic(
static_image_mode=False,
model_complexity=2,
enable_segmentation=True,
refine_face_landmarks=True,
)
progress(0, desc='Preprocessing video')
start_time = time.time()
inputs = preprocess(
source=video,
keypoints_detector=keypoints_detector,
)
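    # `inputs` holds the keypoint features extracted from the video in the
    # layout expected by the ONNX model (see utils.data.preprocess).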
end_time = time.time()
data_time = end_time - start_time
progress(1/2, desc='Getting predictions')
start_time = time.time()
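    # Run the ONNX session and map the top-3 class ids to gloss labels.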
predictions = get_predictions(
inputs=inputs, ort_session=ort_session, id2gloss=id2gloss, k=3
)
end_time = time.time()
model_time = end_time - start_time
if len(predictions) == 0:
output_message = 'No sign language detected in the video. Please try again.'
else:
output_message = 'The top-3 predictions are:\n'
for i, prediction in enumerate(predictions):
            output_message += f'\t{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
output_message += f'Data processing time: {data_time:.2f} seconds\n'
output_message += f'Model inference time: {model_time:.2f} seconds\n'
output_message += f'Total time: {data_time + model_time:.2f} seconds'
    progress(1, desc='Completed')
return output_message
iface = gr.Interface(
    fn=inference,
    inputs='video',
    outputs='text',
    examples=examples,
    title=title,
    description=description,
    article=cite_markdown,
)

if __name__ == '__main__':
    iface.launch()
    # For a quick local test without the UI:
    # print(inference('000_con_cho.mp4'))