# Author: louiecerv
# Commit 4601591: "sync with remote"
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
import os
from collections import Counter
# --- Configuration -----------------------------------------------------------
# Optional Hugging Face access token taken from the environment (None if unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub repository hosting pytorch_model.bin; replace with your own repo ID.
MODEL_REPO_ID = "louiecerv/amer_sign_lang_data_augmentation"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model input geometry: 28 x 28 images with a single (grayscale) channel.
IMG_HEIGHT = 28
IMG_WIDTH = 28
IMG_CHS = 1
# Number of output classes. NOTE(review): presumably the static letters only
# (the alphabet without J and Z) -- confirm against the training labels.
N_CLASSES = 24
class MyConvBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU -> Dropout -> 2x2 max-pool feature block.

    The 3x3 convolution uses padding 1, so height and width are preserved;
    the trailing ``MaxPool2d(2, stride=2)`` then halves them (rounding down).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        dropout_p: dropout probability applied after the activation.
    """

    def __init__(self, in_ch, out_ch, dropout_p):
        super(MyConvBlock, self).__init__()
        # The attribute name ``model`` and this layer order define the
        # state_dict keys (model.0.*, model.1.*) that a saved checkpoint must
        # match -- do not rename or reorder.
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.MaxPool2d(2, stride=2),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to an (N, in_ch, H, W) batch."""
        return self.model(x)
# Flattened feature size after the three conv blocks: 75 channels at 3 x 3
# (28 -> 14 -> 7 -> 3 per side for a 28 x 28 input).
flattened_img_size = 75 * 3 * 3


class ASL_CNN(nn.Module):
    """Three-block CNN followed by a two-layer fully connected classifier head.

    Three ``MyConvBlock`` stages (25, 50, 75 channels) each halve the spatial
    size. A 512-unit hidden layer with dropout 0.3 then produces ``N_CLASSES``
    scores per image.
    """

    def __init__(self):
        super().__init__()
        # Index positions inside ``base_model`` are part of the state_dict keys
        # (base_model.0.*, base_model.4.*, base_model.7.*) -- keep this order.
        feature_blocks = [
            MyConvBlock(IMG_CHS, 25, 0),  # 25 x 14 x 14
            MyConvBlock(25, 50, 0.2),     # 50 x 7 x 7
            MyConvBlock(50, 75, 0),       # 75 x 3 x 3
        ]
        classifier_head = [
            nn.Flatten(),
            nn.Linear(flattened_img_size, 512),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(512, N_CLASSES),
        ]
        self.base_model = nn.Sequential(*feature_blocks, *classifier_head)

    def forward(self, x):
        """Return class scores of shape (N, N_CLASSES) for an input batch."""
        return self.base_model(x)
@st.cache_resource
def load_model():
    """Download the pretrained ``ASL_CNN`` weights from the Hugging Face Hub.

    Decorated with ``st.cache_resource`` so the checkpoint is downloaded and
    the model built once per server process. Without it, this would run on
    every Streamlit rerun, because ``main`` calls it each time an image is
    processed.

    Returns:
        ASL_CNN: the model on ``device``, switched to eval mode.

    Raises:
        requests.HTTPError: if the Hub answers with an error status
            (e.g. 401 or 404).
        requests.RequestException: on connection failures or timeout.
    """
    model = ASL_CNN().to(device)
    model_url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/pytorch_model.bin"
    # HF_TOKEN is read from the environment at import time; send it when set
    # so the download also works for private or gated repositories.
    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    response = requests.get(model_url, headers=headers, timeout=60)
    # Fail here instead of handing an HTML error page to torch.load, which
    # would raise a confusing unpickling error.
    response.raise_for_status()
    # The file comes from the network: weights_only=True restricts
    # unpickling to tensors and plain containers, not arbitrary objects.
    state_dict = torch.load(
        BytesIO(response.content), map_location=device, weights_only=True
    )
    # strict=False tolerates non-matching keys, but a checkpoint that matches
    # nothing would leave the network at its random initialisation -- say so.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        st.warning(
            f"{len(incompatible.missing_keys)} model weight/buffer entries were "
            "missing from the checkpoint; predictions may be unreliable."
        )
    model.eval()
    return model
def preprocess_image(image):
    """Center-crop a PIL image to a square and turn it into a model input tensor.

    Args:
        image: a PIL image (``main`` passes RGB images).

    Returns:
        torch.Tensor of shape (1, 1, IMG_HEIGHT, IMG_WIDTH) on ``device``:
        grayscale, scaled to [0, 1] by ``ToTensor`` and then normalized with
        mean 0.5 / std 0.5.
    """
    width, height = image.size
    side = min(width, height)
    # Integer offsets guarantee a side x side crop. The previous float box was
    # rounded half-to-even by PIL, so an odd short side produced a non-square
    # crop (e.g. a 4x3 image stayed 4x3) that Resize then distorted.
    left = (width - side) // 2
    top = (height - side) // 2
    image = image.crop((left, top, left + side, top + side))

    # NOTE(review): confirm this mean/std matches the normalization used when
    # the checkpoint was trained.
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    # unsqueeze(0) adds the batch dimension expected by the model.
    return transform(image).unsqueeze(0).to(device)
def tensor_to_pil(tensor):
    """Convert a normalized image tensor back into an 8-bit grayscale PIL image.

    Inverts the ``Normalize(mean=0.5, std=0.5)`` step of ``preprocess_image``.

    Args:
        tensor: image tensor that squeezes to 2-D (H, W), e.g. the
            (1, 1, H, W) tensors produced by ``preprocess_image``.

    Returns:
        PIL.Image.Image in mode "L".
    """
    pixels = tensor.detach().squeeze().cpu()
    pixels = pixels * 0.5 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    # Clamp before the uint8 cast so out-of-range values cannot wrap around,
    # and round instead of truncating (truncation biases every pixel darker).
    pixels = (pixels.clamp(0.0, 1.0) * 255).round().to(torch.uint8).numpy()
    # A 2-D uint8 array is inferred as mode "L"; the explicit ``mode=``
    # argument of fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(pixels)
def get_predictions(model, image):
    """Classify ``image`` together with four geometric variants of it.

    Variants, in order: the original; a mirror made by flipping dim 3 (the
    width axis of an (N, C, H, W) tensor); and ``torch.rot90`` rotations over
    dims 2 and 3 by one, two and three quarter turns.

    Args:
        model: callable returning class scores of shape (1, n_classes).
        image: tensor of shape (1, C, H, W).

    Returns:
        A ``(predictions, transformed_images)`` tuple. ``predictions`` holds the
        predicted class index (int) for each variant. ``transformed_images`` is
        the parallel list of ``(label, tensor)`` pairs.
    """
    variants = [("Original", image), ("Mirror", torch.flip(image, [3]))]
    variants += [
        (f"Rotated {quarter_turns * 90}°", torch.rot90(image, k=quarter_turns, dims=[2, 3]))
        for quarter_turns in (1, 2, 3)
    ]
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        class_ids = [model(tensor).argmax(dim=1).item() for _, tensor in variants]
    return class_ids, variants
# Class index -> letter table for the N_CLASSES = 24 outputs. Assumes the
# Sign-Language-MNIST-style labelling: the static letters A-Y without J
# (J and Z need motion). The old chr(index + 65) mapping showed "J" for
# index 9 and could never show "Y".
# NOTE(review): confirm against the labels the checkpoint was trained with.
ASL_LETTERS = "ABCDEFGHIKLMNOPQRSTUVWXY"


def class_index_to_letter(index):
    """Return the ASL letter for model class ``index``.

    Raises:
        IndexError: if ``index`` is outside ``range(len(ASL_LETTERS))``.
    """
    if not 0 <= index < len(ASL_LETTERS):
        raise IndexError(f"class index {index} out of range 0..{len(ASL_LETTERS) - 1}")
    return ASL_LETTERS[index]


def main():
    """Render the Streamlit ASL recognition page.

    The page shows the ``asl.png`` banner and takes an image from a file
    upload or the camera. It then classifies the image plus its mirrored and
    rotated variants and reports the majority-vote letter.
    """
    st.title("American Sign Language Recognition")
    banner = Image.open("asl.png")
    st.image(banner, caption="American Sign Language", use_container_width=True)

    image = None
    # Sidebar selector for the input source.
    option = st.sidebar.selectbox("Choose an option", ("Upload an Image", "Take a Photo"))
    if option == "Upload an Image":
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file).convert("RGB")
    elif option == "Take a Photo":
        photo = st.camera_input("Take a photo")
        if photo is not None:
            image = Image.open(BytesIO(photo.getvalue())).convert("RGB")

    if image is None:
        return

    st.image(image, caption="Uploaded Image", use_container_width=True)
    st.write("")

    processed_image = preprocess_image(image)
    model = load_model()
    predictions, transformed_images = get_predictions(model, processed_image)
    predicted_classes = [class_index_to_letter(pred) for pred in predictions]

    # Show each variant with its own prediction. zip keeps the two lists
    # aligned without the previous O(n^2) list.index lookup.
    for (label, img_tensor), letter in zip(transformed_images, predicted_classes):
        st.image(tensor_to_pil(img_tensor), caption=f"{label} Image", use_container_width=True)
        st.write(f"Prediction for {label} Image: {letter}")

    # Majority vote. Counter.most_common orders equal counts by first
    # occurrence, so a tie that includes the original image's letter goes to it.
    final_prediction = Counter(predicted_classes).most_common(1)[0][0]
    st.write(f"Predictions: {predicted_classes}")
    st.write(f"Final Predicted ASL letter: {final_prediction}")


if __name__ == "__main__":
    main()