dhhd255 commited on
Commit
36a55af
1 Parent(s): 2da9cbe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import numpy as np
6
+ import streamlit as st
7
+
8
+ # Set the device
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ # Load the trained model from the Hugging Face Hub
12
+ model = AutoModel.from_pretrained('dhhd255/parkinsons_pred0.1')
13
+
14
+ # Move the model to the device
15
+ model = model.to(device)
16
+
17
+ # Add custom CSS to use the Inter font, define custom classes for healthy and parkinsons results, increase the font size, make the text bold, and define the footer styles
18
+ st.markdown("""
19
+ <style>
20
+ @import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
21
+ body {
22
+ font-family: 'Inter', sans-serif;
23
+ }
24
+ .result {
25
+ font-size: 24px;
26
+ font-weight: bold;
27
+ }
28
+ .healthy {
29
+ color: #007E3F;
30
+ }
31
+ .parkinsons {
32
+ color: #C30000;
33
+ }
34
+ .caption_c{
35
+ position: relative;
36
+ display: flex;
37
+ flex-directon: column;
38
+ align-items: center;
39
+ top: calc(99vh - 370px);
40
+ }
41
+ .caption {
42
+
43
+ text-align: center;
44
+ color: #646464;
45
+ font-size: 14px;
46
+ }
47
+ button:hover {
48
+ background-color: lightblue !important;
49
+ outline-color: lightblue !important;
50
+ }
51
+ button:focus {
52
+ background-color: lightblue !important;
53
+ outline-color: lightblue !important;
54
+ }
55
+ </style>
56
+ """, unsafe_allow_html=True)
57
+
58
+ st.title("Parkinson's Disease Prediction")
59
+
60
+ uploaded_file = st.file_uploader("Upload your :blue[Spiral] drawing here", type=["png", "jpg", "jpeg"])
61
+ if uploaded_file is not None:
62
+ col1, col2 = st.columns(2)
63
+
64
+ # Load and resize the image
65
+ image_size = (224, 224)
66
+ new_image = Image.open(uploaded_file).convert('RGB').resize(image_size)
67
+ col1.image(new_image, use_column_width=True)
68
+ new_image = np.array(new_image)
69
+ new_image = torch.from_numpy(new_image).transpose(0, 2).float().unsqueeze(0)
70
+
71
+ # Move the data to the device
72
+ new_image = new_image.to(device)
73
+
74
+ # Make predictions using the trained model
75
+ with torch.no_grad():
76
+ predictions = model(new_image)
77
+ logits = predictions.last_hidden_state
78
+ logits = logits.view(logits.shape[0], -1)
79
+ num_classes=2
80
+ feature_reducer = nn.Linear(logits.shape[1], num_classes)
81
+
82
+ logits = logits.to(device)
83
+ feature_reducer = feature_reducer.to(device)
84
+
85
+ logits = feature_reducer(logits)
86
+ predicted_class = torch.argmax(logits, dim=1).item()
87
+ confidence = torch.softmax(logits, dim=1)[0][predicted_class].item()
88
+ if(predicted_class == 0):
89
+ col2.markdown('<span class="result parkinsons">Predicted class: Parkinson\'s</span>', unsafe_allow_html=True)
90
+ col2.caption(f'{confidence*100:.0f}% sure')
91
+ else:
92
+ col2.markdown('<span class="result healthy">Predicted class: Healthy</span>', unsafe_allow_html=True)
93
+ col2.caption(f'{confidence*100:.0f}% sure')
94
+
95
+ # Add a caption at the bottom of the page
96
+ st.markdown('<div class="caption_c"><p class="caption">Made with love by Jayant</p></div>', unsafe_allow_html=True)