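"""Streamlit app for American Sign Language letter recognition.

Downloads a small CNN (trained on 28x28 grayscale hand-sign images) from the
Hugging Face Hub, classifies an uploaded photo or webcam capture, and combines
predictions over a few simple image transformations by majority vote.
"""
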
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
import os
from collections import Counter

# Hugging Face Hub credentials
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_REPO_ID = "louiecerv/amer_sign_lang_data_augmentation"  # Replace with your repo ID

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the CNN model
IMG_HEIGHT = 28
IMG_WIDTH = 28
IMG_CHS = 1
N_CLASSES = 24  # 24 classes (static ASL letters; J and Z involve motion)

class MyConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dropout_p):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.MaxPool2d(2, stride=2)
        )

    def forward(self, x):
        return self.model(x)
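
# Each MyConvBlock halves the spatial size with its 2x2 max-pool:
# 28 -> 14 -> 7 -> 3 (floor division), so the last block's 75 channels
# flatten to 75 * 3 * 3 features.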
flattened_img_size = 75 * 3 * 3  # Adjusted based on the final output size

class ASL_CNN(nn.Module):
    def __init__(self):
        super(ASL_CNN, self).__init__()
        self.base_model = nn.Sequential(
            MyConvBlock(IMG_CHS, 25, 0),   # 25 x 14 x 14
            MyConvBlock(25, 50, 0.2),      # 50 x 7 x 7
            MyConvBlock(50, 75, 0),        # 75 x 3 x 3
            nn.Flatten(),
            nn.Linear(flattened_img_size, 512),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(512, N_CLASSES)
        )

    def forward(self, x):
        return self.base_model(x)
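
# Quick shape sanity check (commented out; for local debugging only):
#   _model = ASL_CNN()
#   _out = _model(torch.zeros(1, IMG_CHS, IMG_HEIGHT, IMG_WIDTH))
#   assert _out.shape == (1, N_CLASSES)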

# Load the model from the Hugging Face Hub
@st.cache_resource  # cache the model so the weights are downloaded only once
def load_model():
    model = ASL_CNN().to(device)
    model_url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/pytorch_model.bin"
    # Send the token only if one is configured (required for private repos)
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None
    response = requests.get(model_url, headers=headers)
    response.raise_for_status()
    state_dict = torch.load(BytesIO(response.content), map_location=device)
    model.load_state_dict(state_dict, strict=False)  # strict=False ignores non-matching keys
    model.eval()
    return model
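
# Alternative (assumption, not used here): huggingface_hub.hf_hub_download(
# MODEL_REPO_ID, "pytorch_model.bin", token=HF_TOKEN) would also handle
# authentication and local caching of the weights file.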

# Preprocess the image
def preprocess_image(image):
    # Crop the image to a centered square
    width, height = image.size
    min_dim = min(width, height)
    left = (width - min_dim) / 2
    top = (height - min_dim) / 2
    right = (width + min_dim) / 2
    bottom = (height + min_dim) / 2
    image = image.crop((left, top, right, bottom))

    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])  # assumes the model was trained with this normalization
    ])
    image = transform(image).unsqueeze(0).to(device)
    return image

# Convert tensor to PIL image
def tensor_to_pil(tensor):
    tensor = tensor.squeeze().cpu()
    tensor = tensor * 0.5 + 0.5  # Unnormalize
    tensor = tensor.numpy()
    image = Image.fromarray((tensor * 255).astype('uint8'), mode='L')
    return image

# Get predictions for different transformations
def get_predictions(model, image):
    predictions = []
    transformed_images = []

    # Original image
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output.data, 1)
    predictions.append(predicted.item())
    transformed_images.append(("Original", image))

    # Mirror image
    mirror_image = torch.flip(image, [3])
    with torch.no_grad():
        output = model(mirror_image)
        _, predicted = torch.max(output.data, 1)
    predictions.append(predicted.item())
    transformed_images.append(("Mirror", mirror_image))

    # Rotated images
    for angle in [90, 180, 270]:
        rotated_image = torch.rot90(image, k=angle // 90, dims=[2, 3])
        with torch.no_grad():
            output = model(rotated_image)
            _, predicted = torch.max(output.data, 1)
        predictions.append(predicted.item())
        transformed_images.append((f"Rotated {angle}°", rotated_image))

    return predictions, transformed_images
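
# The five predictions above (original, mirror, three rotations) act as simple
# test-time augmentation; main() below takes a majority vote over them.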

# Streamlit app
def main():
    st.title("American Sign Language Recognition")

    image = Image.open("asl.png")
    st.image(image, caption="American Sign Language", use_container_width=True)

    # Initialize the image variable
    image = None

    # Sidebar for image input
    option = st.sidebar.selectbox("Choose an option", ("Upload an Image", "Take a Photo"))

    if option == "Upload an Image":
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file).convert("RGB")
    elif option == "Take a Photo":
        image = st.camera_input("Take a photo")
        if image is not None:
            image = Image.open(BytesIO(image.getvalue())).convert("RGB")

    if image is not None:
        st.image(image, caption="Input Image", use_container_width=True)
        st.write("")

        # Preprocess the image
        processed_image = preprocess_image(image)

        # Load the model
        model = load_model()

        # Get predictions
        predictions, transformed_images = get_predictions(model, processed_image)
        # Map class indices to letters (assumes indices 0-23 correspond directly to A-X)
        predicted_classes = [chr(pred + 65) for pred in predictions]

        # Display each transformed image with its prediction
        for (label, img_tensor), letter in zip(transformed_images, predicted_classes):
            img_pil = tensor_to_pil(img_tensor)
            st.image(img_pil, caption=f"{label} Image", use_container_width=True)
            st.write(f"Prediction for {label} Image: {letter}")

        # Majority vote across all transformations
        final_prediction = Counter(predicted_classes).most_common(1)[0][0]

        # Print summary and final prediction
        st.write(f"Predictions: {predicted_classes}")
        st.write(f"Final Predicted ASL letter: {final_prediction}")

if __name__ == "__main__":
    main()
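
# To run locally: streamlit run app.py
# (assumes this file is saved as app.py and asl.png is in the working directory)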