Spaces:
Build error
Build error
import copy | |
import torch | |
import numpy as np | |
import gradio as gr | |
from spoter_mod.skeleton_extractor import obtain_pose_data | |
from spoter_mod.normalization.body_normalization import normalize_single_dict as normalize_single_body_dict, BODY_IDENTIFIERS | |
from spoter_mod.normalization.hand_normalization import normalize_single_dict as normalize_single_hand_dict, HAND_IDENTIFIERS | |
# Load the complete pickled SPOTER model object (not only a state_dict) onto the
# CPU, so loading never requires a GPU; it is moved to `device` further below.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load this
# trusted, bundled checkpoint. Recent PyTorch releases default to
# weights_only=True, which may reject a full-model pickle; verify against the
# installed torch version.
model = torch.load("spoter-checkpoint.pth", map_location=torch.device('cpu'))
# Inference only (equivalent to model.eval()).
model.train(False)

# Hand landmarks exist once per hand, so every imported hand identifier is
# expanded into a "_Left" and a "_Right" variant (all left-hand names first).
# This deliberately rebinds the imported name: the landmark/tensor helpers in
# this module index landmarks with the suffixed list.
HAND_IDENTIFIERS = ([name + "_Left" for name in HAND_IDENTIFIERS]
                    + [name + "_Right" for name in HAND_IDENTIFIERS])

# Class labels. Classifier output index i is reported as GLOSS[i], so this order
# must match the checkpoint's class order -- do not reorder. Presumably the
# WLASL100 vocabulary the model was trained on (per the UI description); confirm.
GLOSS = ['book', 'drink', 'computer', 'before', 'chair', 'go', 'clothes', 'who', 'candy', 'cousin', 'deaf', 'fine',
         'help', 'no', 'thin', 'walk', 'year', 'yes', 'all', 'black', 'cool', 'finish', 'hot', 'like', 'many', 'mother',
         'now', 'orange', 'table', 'thanksgiving', 'what', 'woman', 'bed', 'blue', 'bowling', 'can', 'dog', 'family',
         'fish', 'graduate', 'hat', 'hearing', 'kiss', 'language', 'later', 'man', 'shirt', 'study', 'tall', 'white',
         'wrong', 'accident', 'apple', 'bird', 'change', 'color', 'corn', 'cow', 'dance', 'dark', 'doctor', 'eat',
         'enjoy', 'forget', 'give', 'last', 'meet', 'pink', 'pizza', 'play', 'school', 'secretary', 'short', 'time',
         'want', 'work', 'africa', 'basketball', 'birthday', 'brown', 'but', 'cheat', 'city', 'cook', 'decide', 'full',
         'how', 'jacket', 'letter', 'medicine', 'need', 'paint', 'paper', 'pull', 'purple', 'right', 'same', 'son',
         'tell', 'thursday']

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The weights must live on the same device as the inputs fed to them.
model = model.to(device)
def tensor_to_dictionary(landmarks_tensor: torch.Tensor, identifiers=None) -> dict:
    """Split a landmark tensor into one NumPy array per named landmark.

    Args:
        landmarks_tensor: Tensor whose axis 1 indexes landmarks. In this app it
            is built as (frames, landmarks, 2) with (x, y) on the last axis.
        identifiers: Landmark names in axis-1 order. Defaults to
            ``BODY_IDENTIFIERS + HAND_IDENTIFIERS``.

    Returns:
        Dict mapping each identifier, in order, to ``array[:, index]``.
        Landmarks beyond ``len(identifiers)`` are ignored.

    Raises:
        IndexError: if the tensor has fewer landmarks than identifiers.
    """
    if identifiers is None:
        identifiers = BODY_IDENTIFIERS + HAND_IDENTIFIERS
    # detach()/cpu() make this work for tensors that require grad or live on a
    # GPU; for a plain CPU tensor both are no-ops (memory is still shared).
    data_array = landmarks_tensor.detach().cpu().numpy()
    return {identifier: data_array[:, index] for index, identifier in enumerate(identifiers)}
def dictionary_to_tensor(landmarks_dict: dict, identifiers=None) -> torch.Tensor:
    """Stack per-landmark (x, y) sequences back into a single tensor.

    This is the inverse of ``tensor_to_dictionary`` for the 2-coordinate layout
    used in this module.

    Args:
        landmarks_dict: Maps each identifier to a per-frame sequence. Each frame
            must be indexable; only ``frame[0]`` (x) and ``frame[1]`` (y) are read.
        identifiers: Landmark names giving the order of axis 1. Defaults to
            ``BODY_IDENTIFIERS + HAND_IDENTIFIERS``.

    Returns:
        float64 tensor of shape (frames, len(identifiers), 2).

    Raises:
        KeyError: if an identifier is missing from ``landmarks_dict``.
    """
    if identifiers is None:
        identifiers = BODY_IDENTIFIERS + HAND_IDENTIFIERS
    # Take the frame count from a landmark that is actually being stacked rather
    # than a hard-coded key; every sequence is expected to share this length.
    num_frames = len(landmarks_dict[identifiers[0]]) if identifiers else 0
    output = np.empty(shape=(num_frames, len(identifiers), 2))
    for index, identifier in enumerate(identifiers):
        frames = landmarks_dict[identifier]
        output[:, index, 0] = [frame[0] for frame in frames]
        output[:, index, 1] = [frame[1] for frame in frames]
    return torch.from_numpy(output)
# Hand-landmark keys use "_Left"/"_Right" suffixes in this app. They are renamed
# to "_0"/"_1" before normalization and back afterwards -- presumably the naming
# the normalization helpers expect; confirm before changing.
_TO_NORMALIZER_SUFFIXES = (("_Left", "_0"), ("_Right", "_1"))
_FROM_NORMALIZER_SUFFIXES = (("_0", "_Left"), ("_1", "_Right"))


def _swap_key_suffixes(landmarks: dict, suffix_pairs) -> dict:
    """Return a copy of ``landmarks`` with the first matching key suffix replaced.

    Only the end of each key is rewritten, so a suffix-like substring elsewhere
    in a key (e.g. the "_1" in "x_1y") is left alone. Keys with none of the
    listed suffixes are kept unchanged; values and key order are preserved.
    """
    renamed = {}
    for key, value in landmarks.items():
        for old, new in suffix_pairs:
            if key.endswith(old):
                key = key[: -len(old)] + new
                break
        renamed[key] = value
    return renamed


def _extract_landmarks(video) -> torch.Tensor:
    """Run pose extraction on ``video`` and stack (x, y) as (frames, landmarks, 2)."""
    data = obtain_pose_data(video)
    identifiers = BODY_IDENTIFIERS + HAND_IDENTIFIERS
    # Frame count is read from the nose track; assumes every landmark track has
    # the same number of frames.
    num_frames = len(data.data_hub["nose_X"])
    landmarks = np.empty(shape=(num_frames, len(identifiers), 2))
    for index, identifier in enumerate(identifiers):
        landmarks[:, index, 0] = data.data_hub[identifier + "_X"]
        landmarks[:, index, 1] = data.data_hub[identifier + "_Y"]
    return torch.from_numpy(landmarks)


def _normalize_landmarks(landmarks: torch.Tensor) -> torch.Tensor:
    """Apply body and hand normalization and return the model-ready tensor."""
    landmark_dict = tensor_to_dictionary(landmarks)
    landmark_dict = _swap_key_suffixes(landmark_dict, _TO_NORMALIZER_SUFFIXES)
    landmark_dict = normalize_single_body_dict(landmark_dict)
    landmark_dict = normalize_single_hand_dict(landmark_dict)
    landmark_dict = _swap_key_suffixes(landmark_dict, _FROM_NORMALIZER_SUFFIXES)
    # Constant shift after normalization; presumably mirrors the training-time
    # preprocessing -- confirm before changing.
    return dictionary_to_tensor(landmark_dict) - 0.5


def _classify(landmarks: torch.Tensor) -> dict:
    """Run the model on one (frames, landmarks, 2) tensor; return {gloss: probability}."""
    # Follow wherever the model's weights actually live, so inputs and weights
    # are always on the same device (CPU or CUDA).
    first_parameter = next(model.parameters(), None)
    model_device = first_parameter.device if first_parameter is not None else torch.device("cpu")
    # No squeeze(0): this tensor has no batch axis, so squeezing would drop the
    # frame axis of a single-frame clip.
    inputs = landmarks.to(model_device)
    with torch.no_grad():  # inference only; no autograd graph needed
        # expand() adds a leading axis so that dim=2 is the class axis;
        # assumes the model returns (1, num_classes) -- TODO confirm.
        outputs = model(inputs).expand(1, -1, -1)
        probabilities = torch.nn.functional.softmax(outputs, dim=2).cpu().numpy()[0, 0]
    # Output index i corresponds to GLOSS[i].
    return {gloss: float(probability) for gloss, probability in zip(GLOSS, probabilities)}


def greet(label, video0, video1):
    """Classify the sign performed in the selected video.

    Args:
        label: Input type chosen in the dropdown. "Webcam" uses ``video0`` and
            "Video" uses ``video1``. "X" returns a fixed dummy distribution
            (a debug stub that the dropdown does not offer).
        video0: Value of the webcam-recording input.
        video1: Value of the video-upload input.

    Returns:
        Mapping from gloss to softmax probability. Returns ``{}`` when the
        label is unrecognized or the selected input holds no video.
    """
    if label == "Webcam":
        video = video0
    elif label == "Video":
        video = video1
    elif label == "X":
        return {"A": 0.8, "B": 0.1, "C": 0.1}
    else:
        return {}
    # Nothing recorded/uploaded for the selected source: avoid passing None
    # into pose extraction.
    if video is None:
        return {}
    return _classify(_normalize_landmarks(_extract_landmarks(video)))
# Output widget showing the five most probable glosses returned by greet().
# gr.Label is the current component API (gr.outputs.* was deprecated in
# Gradio 3 and removed in Gradio 4) and matches the gr.Dropdown / gr.Video
# components used for the inputs.
label = gr.Label(num_top_classes=5, label="Top class probabilities")

# Inputs are passed to greet() positionally as (label, video0, video1).
input_type = gr.Dropdown(["Webcam", "Video"], label="Please select the input type:", type="value")
# NOTE(review): Gradio documents gr.Video's container format as `format=`;
# `type="mp4"` is kept unchanged but may be ignored depending on the version.
webcam_input = gr.Video(source="webcam", label="Webcam recording", type="mp4")
upload_input = gr.Video(source="upload", label="Video upload", type="mp4")

demo = gr.Interface(fn=greet, inputs=[input_type, webcam_input, upload_input], outputs=label,
                    title="🤟 SPOTER Sign language recognition",
                    description="""Try out our recent model for sign language recognition right in your browser! The model below takes a video of a single sign in the American Sign Language at the input and provides you with probabilities of the lemmas (equivalent to words in natural language).
### Our work at CVPR
Our efforts on lightweight and efficient models for sign language recognition were first introduced at WACV with our SPOTER paper. We now presented a work-in-progress follow-up here at CVPR's AVA workshop. Be sure to check our work and code below:
- **WACV2022** - Original SPOTER paper - [Paper](), [Code]()
- **CVPR2022 AVA Workshop** - Follow-up WIP – [Extended Abstract](), [Poster]()
### How to sign?
The model wrapped in this demo was trained on [WLASL100](https://dxli94.github.io/WLASL/), so it only knows selected ASL vocabulary. Take a look at these tutorial video examples, try to replicate them yourself, and have them recognized using the webcam capture below. Have fun!""",
                    article="This is joint work of [Matyas Bohacek](https://scholar.google.cz/citations?user=wDy1xBwAAAAJ) and [Zhuo Cao](https://www.linkedin.com/in/zhuo-cao-b0787a1aa/?originalSubdomain=hk). For more info, visit [our website.](https://www.signlanguagerecognition.com)",
                    css="""
@font-face {
font-family: Graphik;
font-weight: regular;
src: url("https://www.signlanguagerecognition.com/supplementary/GraphikRegular.otf") format("opentype");
}
@font-face {
font-family: Graphik;
font-weight: bold;
src: url("https://www.signlanguagerecognition.com/supplementary/GraphikBold.otf") format("opentype");
}
@font-face {
font-family: MonumentExpanded;
font-weight: regular;
src: url("https://www.signlanguagerecognition.com/supplementary/MonumentExtended-Regular.otf") format("opentype");
}
@font-face {
font-family: MonumentExpanded;
font-weight: bold;
src: url("https://www.signlanguagerecognition.com/supplementary/MonumentExtended-Bold.otf") format("opentype");
}
html {
font-family: "Graphik";
}
h1 {
font-family: "MonumentExpanded";
}
#12 {
- background-image: linear-gradient(to left, #61D836, #6CB346) !important;
background-color: #61D836 !important;
}
#12:hover {
- background-image: linear-gradient(to left, #61D836, #6CB346) !important;
background-color: #6CB346 !important;
border: 0 !important;
border-color: 0 !important;
}
.dark .gr-button-primary {
--tw-gradient-from: #61D836;
--tw-gradient-to: #6CB346;
border: 0 !important;
border-color: 0 !important;
}
.dark .gr-button-primary:hover {
--tw-gradient-from: #64A642;
--tw-gradient-to: #58933B;
border: 0 !important;
border-color: 0 !important;
}
""",
                    # NOTE(review): no `examples` are passed to this Interface,
                    # so example caching currently has no effect.
                    cache_examples=True
                    )
demo.launch(debug=True)