"""Streamlit app that classifies an uploaded spiral drawing as Parkinson's or Healthy.

NOTE(review): see the comment on ``feature_reducer`` below. The classification
head is constructed with random weights on every run, so the displayed
prediction is not produced by trained parameters.
"""
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModel

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_model():
    """Load the backbone from the Hugging Face Hub and move it to ``device``.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one loaded model across those reruns instead
    of re-loading the weights each time.
    """
    backbone = AutoModel.from_pretrained('dhhd255/EfficientNet_ParkinsonsPred')
    # Inference only: make sure dropout / batch-norm run in evaluation mode.
    backbone.eval()
    return backbone.to(device)


# Load the trained model from the Hugging Face Hub (cached across reruns).
model = _load_model()

# Inject custom HTML/CSS. The original comment says this defines the Inter
# font, classes for the healthy/parkinsons results, larger bold text, and
# footer styles.
# NOTE(review): in this copy the string is effectively empty, so nothing is
# injected. Confirm the intended markup.
st.markdown(""" """, unsafe_allow_html=True)

st.title("Parkinson's Disease Prediction")

uploaded_file = st.file_uploader("Upload your :blue[Spiral] drawing here", type=["png", "jpg", "jpeg"])
st.empty()

if uploaded_file is not None:
    col1, col2 = st.columns(2)

    # Load the upload, force 3-channel RGB, and resize to 224x224.
    # The context manager closes the underlying file handle; convert/resize
    # return new, independent images.
    image_size = (224, 224)
    with Image.open(uploaded_file) as img:
        new_image = img.convert('RGB').resize(image_size)
    col1.image(new_image, width=255)

    # (H, W, C) uint8 array -> (1, C, H, W) float tensor.
    # permute(2, 0, 1) keeps height before width. The previous transpose(0, 2)
    # produced (C, W, H), i.e. a spatially transposed image.
    # NOTE(review): pixel values stay in [0, 255] with no normalization.
    # Confirm this matches the preprocessing used when the model was trained.
    new_image = np.array(new_image)
    new_image = torch.from_numpy(new_image).permute(2, 0, 1).float().unsqueeze(0)
    new_image = new_image.to(device)

    num_classes = 2
    # Run the backbone and the head without building an autograd graph.
    with torch.no_grad():
        predictions = model(new_image)
        # Flatten the backbone's last hidden state to one feature vector per image.
        logits = predictions.last_hidden_state.flatten(1)

        # NOTE(review): this Linear layer is freshly constructed (random
        # initialization) on every script run, and no trained weights are
        # loaded into it. The predicted class and confidence below therefore
        # do not come from trained parameters and can change between reruns.
        # TODO: load a trained classification head, or use a model class that
        # includes one.
        feature_reducer = nn.Linear(logits.shape[1], num_classes).to(device)
        logits = feature_reducer(logits)

    predicted_class = torch.argmax(logits, dim=1).item()
    confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()

    # Class index 0 is displayed as Parkinson's, any other index as Healthy.
    label = "Parkinson's" if predicted_class == 0 else 'Healthy'
    col2.markdown(f'Predicted class: {label}', unsafe_allow_html=True)
    col2.caption(f'{confidence*100:.0f}% sure')

# NOTE(review): the Wave upload is never read. `uploaded_file` is rebound here
# and nothing in the visible source uses it afterwards. Confirm whether wave
# drawings are meant to be classified too.
uploaded_file = st.file_uploader("Upload your :blue[Wave] drawing here", type=["png", "jpg", "jpeg"])
# Emit eight empty placeholder slots before the footer.
# Presumably these act as vertical spacing between the uploaders and the
# footer; confirm against the rendered layout.
for _ in range(8):
    st.empty()

# Footer markup, rendered as raw HTML. In this copy the literal is a single
# newline, so no visible footer content is produced here.
st.markdown(
    """
""",
    unsafe_allow_html=True,
)