athulnambiar committed
Commit 391ecac
1 Parent(s): 274bdd1

Upload 3 files

Files changed (3)
  1. app.py +185 -0
  2. multi_weight.pth +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,185 @@
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+ from PIL import Image
+ import os
+ import time
+
+ ########################
+ # MODEL DEFINITION
+ ########################
+ class MelanomaModel(nn.Module):
+     def __init__(self, out_size, dropout_prob=0.5):
+         super(MelanomaModel, self).__init__()
+         from efficientnet_pytorch import EfficientNet
+         self.efficient_net = EfficientNet.from_pretrained('efficientnet-b0')
+         # Remove the original FC layer
+         self.efficient_net._fc = nn.Identity()
+
+         self.fc1 = nn.Linear(1280, 512)
+         self.fc2 = nn.Linear(512, 256)
+         self.fc3 = nn.Linear(256, out_size)
+
+         self.dropout = nn.Dropout(dropout_prob)
+
+     def forward(self, x):
+         x = self.efficient_net(x)
+         x = x.view(x.size(0), -1)
+         x = F.relu(self.fc1(x))
+         x = self.dropout(x)
+         x = F.relu(self.fc2(x))
+         x = self.dropout(x)
+         x = self.fc3(x)
+         return x
+
+
+ ########################
+ # DIAGNOSIS MAP
+ ########################
+ DIAGNOSIS_MAP = {
+     0: 'Melanoma',
+     1: 'Melanocytic nevus',
+     2: 'Basal cell carcinoma',
+     3: 'Actinic keratosis',
+     4: 'Benign keratosis',
+     5: 'Dermatofibroma',
+     6: 'Vascular lesion',
+     7: 'Squamous cell carcinoma',
+     8: 'Unknown'
+ }
+
+ ########################
+ # LOAD MODEL FUNCTION
+ ########################
+ @st.cache_resource
+ def load_model():
+     """
+     Loads the model checkpoint.
+     Using weights_only=False (if you trust the .pth file).
+     If you prefer a more secure approach, re-save your checkpoint
+     to only contain raw state_dict and set weights_only=True.
+     """
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = MelanomaModel(out_size=9)
+
+     # Path to your model file
+     model_path = os.path.join("model", "multi_weight.pth")
+
+     # If you trust the checkpoint file, set weights_only=False
+     checkpoint = torch.load(
+         model_path,
+         map_location=device,
+         weights_only=False  # if you have a purely raw state_dict, you can use True
+     )
+     model.load_state_dict(checkpoint["model_state_dict"])
+
+     model.to(device)
+     model.eval()
+
+     return model, device
+
+ ########################
+ # IMAGE TRANSFORM
+ ########################
+ transform = transforms.Compose([
+     transforms.Resize(256),
+     transforms.CenterCrop(224),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406],
+                          [0.229, 0.224, 0.225])
+ ])
+
+ ########################
+ # PREDICTION UTILS
+ ########################
+ def predict_skin_lesion(img: Image.Image, model: nn.Module, device: torch.device):
+     # Transform and move image to device
+     img_tensor = transform(img).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         outputs = model(img_tensor)
+         probs = F.softmax(outputs, dim=1)
+         top_probs, top_idxs = torch.topk(probs, 3, dim=1)  # top 3 predictions
+
+     predictions = []
+     for prob, idx in zip(top_probs[0], top_idxs[0]):
+         label = DIAGNOSIS_MAP.get(idx.item(), "Unknown")
+         confidence = prob.item() * 100
+         predictions.append((label, confidence))
+
+     return predictions
+
+ ########################
+ # PAGE CONFIG & STYLE
+ ########################
+ st.set_page_config(
+     page_title="Skin Lesion Classifier",
+     page_icon=":microscope:",
+     layout="centered",
+     initial_sidebar_state="expanded"
+ )
+
+ def set_background_color():
+     st.markdown(
+         """
+         <style>
+         .stApp {
+             background-color: #FDEAE0; /* A pale peach/light skin tone */
+         }
+         </style>
+         """,
+         unsafe_allow_html=True
+     )
+
+ set_background_color()
+
+ ########################
+ # STREAMLIT APP
+ ########################
+ def main():
+     st.title("Skin Lesion Classifier")
+     st.write("Upload an image of a skin lesion to see the top-3 predicted diagnoses.")
+
+     # Create a stylish sidebar
+     st.sidebar.title("Possible Diagnoses")
+     st.sidebar.markdown("Here are the categories the model can distinguish:")
+     for idx, diag in DIAGNOSIS_MAP.items():
+         st.sidebar.markdown(f"- **{diag}**")
+
+     # Add the names to the sidebar in a new section
+     st.sidebar.title("Team Members")
+     st.sidebar.markdown(
+         """
+         - **PRATHUSH MON**
+         - **PRATIK J**
+         - **RAYAN NASAR**
+         - **R HARIMURALI**
+         - **WASEEM AHAMMED**
+         """
+     )
+
+     # Load the model once (cached)
+     model, device = load_model()
+
+     # File uploader
+     uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
+
+     if uploaded_file is not None:
+         # Display the image (convert to RGB so PNGs with alpha match the 3-channel transform)
+         image = Image.open(uploaded_file).convert("RGB")
+         st.image(image, caption="Uploaded Image", use_column_width=True)
+
+         # Predict on button click
+         if st.button("Classify"):
+             with st.spinner("Analyzing..."):
+                 time.sleep(3)  # 3-second spinner
+                 results = predict_skin_lesion(image, model, device)
+
+             st.subheader("Top-3 Predictions")
+             for i, (diagnosis, confidence) in enumerate(results, start=1):
+                 st.write(f"{i}. **{diagnosis}**: {confidence:.2f}%")
+
+ if __name__ == "__main__":
+     main()
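
Note on the checkpoint: load_model() reads model/multi_weight.pth with weights_only=False, while this commit uploads multi_weight.pth at the repository root, so the path has to match wherever the file actually lives in the Space. As the load_model docstring suggests, the checkpoint can also be re-saved as a raw state_dict so the safer weights_only=True load works. A minimal one-off sketch, assuming the checkpoint stores the weights under the "model_state_dict" key; the output filename is only illustrative:

import torch

# Keep only the raw weights from the trusted checkpoint.
checkpoint = torch.load("model/multi_weight.pth", map_location="cpu", weights_only=False)
torch.save(checkpoint["model_state_dict"], "model/multi_weight_state_dict.pth")

# load_model() could then switch to:
#   state_dict = torch.load(model_path, map_location=device, weights_only=True)
#   model.load_state_dict(state_dict)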
multi_weight.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:79a52c72fc2442a3e5a178c2b47b307b701206f75226b1f0aa6241478a745a66
+ size 19614586
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ streamlit==1.25.0
+ torch==2.0.1
+ torchvision==0.15.2
+ efficientnet-pytorch==0.7.1
+ Pillow==9.5.0
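
For a local run, the pinned dependencies can be installed with pip install -r requirements.txt and the app started with streamlit run app.py; on a Streamlit-SDK Space the app is launched automatically from app.py.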