iBrokeTheCode commited on
Commit
af98bbd
Β·
1 Parent(s): 3f0ccc8

feat: Add predict functionality

Browse files
Files changed (3) hide show
  1. requirements.txt +4 -1
  2. src/predictor.py +50 -0
  3. src/streamlit_app.py +8 -4
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- streamlit
 
 
 
 
1
+ numpy==2.3.2
2
+ pillow==11.3.0
3
+ streamlit==1.48.1
4
+ tensorflow==2.20.0
src/predictor.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from tensorflow.keras.applications.resnet50 import (
3
+ ResNet50,
4
+ decode_predictions,
5
+ preprocess_input,
6
+ )
7
+ from tensorflow.keras.preprocessing import image
8
+
9
+ # Load the model outside the function to ensure it's loaded only once
10
+ model = ResNet50(include_top=True, weights="imagenet")
11
+
12
+
13
+ def predict_image(img):
14
+ """
15
+ Preprocesses an image and runs a pre-trained ResNet50 model to get a prediction.
16
+
17
+ Parameters
18
+ ----------
19
+ img : PIL.Image
20
+ The image object to classify.
21
+
22
+ Returns
23
+ -------
24
+ class_name, pred_probability : tuple(str, float)
25
+ The model's predicted class as a string and the corresponding confidence
26
+ score as a number.
27
+ """
28
+ # Resize the image to match model input dimensions (224, 224)
29
+ img = img.resize((224, 224))
30
+
31
+ # Convert Pillow image to np.array
32
+ x = image.img_to_array(img)
33
+
34
+ # Add an extra dimension for the batch size
35
+ x_batch = np.expand_dims(x, axis=0)
36
+
37
+ # Apply ResNet50-specific preprocessing
38
+ x_batch = preprocess_input(x_batch)
39
+
40
+ # Make predictions
41
+ predictions = model.predict(x_batch, verbose=0)
42
+
43
+ # Get predictions using model methods and decode predictions
44
+ top_pred = decode_predictions(predictions, top=1)[0][0] # imagenet_id, label, score
45
+ _, class_name, pred_probability = top_pred
46
+
47
+ # Convert probability to float and round it
48
+ pred_probability = round(float(pred_probability), 4)
49
+
50
+ return class_name, pred_probability
src/streamlit_app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  from PIL import Image
3
 
 
 
4
  # πŸ“Œ PAGE SETUP
5
  st.set_page_config(page_title="Image Classifier App", page_icon="πŸ€–", layout="centered")
6
  st.html("""
@@ -79,7 +81,11 @@ with st.container():
79
  if classify_button:
80
  # Check if an image is selected before running prediction
81
  if uploaded_image is not None:
82
- st.session_state["selected_image"] = uploaded_image
 
 
 
 
83
  elif selected_example:
84
  # Load the selected example image
85
  try:
@@ -96,8 +102,6 @@ with st.container():
96
  st.session_state["selected_image"],
97
  caption="Image to be classified",
98
  )
99
- st.markdown("---")
100
- st.subheader("Prediction")
101
 
102
  # Call the prediction function and display results
103
  with st.spinner("Analyzing image..."):
@@ -109,7 +113,7 @@ with st.container():
109
 
110
  st.metric(
111
  label="Prediction",
112
- value=f"Prediction: {predicted_label.replace('_', ' ').title()}",
113
  delta=f"{predicted_score * 100:.2f}%",
114
  help="The predicted category and its confidence score.",
115
  delta_color="normal",
 
1
  import streamlit as st
2
  from PIL import Image
3
 
4
+ from predictor import predict_image
5
+
6
  # πŸ“Œ PAGE SETUP
7
  st.set_page_config(page_title="Image Classifier App", page_icon="πŸ€–", layout="centered")
8
  st.html("""
 
81
  if classify_button:
82
  # Check if an image is selected before running prediction
83
  if uploaded_image is not None:
84
+ # st.session_state["selected_image"] = uploaded_image
85
+ # Use Image.open() to convert the UploadedFile object into a PIL.Image object
86
+ st.session_state["selected_image"] = Image.open(uploaded_image)
87
+ st.session_state["uploaded_file"] = uploaded_image
88
+
89
  elif selected_example:
90
  # Load the selected example image
91
  try:
 
102
  st.session_state["selected_image"],
103
  caption="Image to be classified",
104
  )
 
 
105
 
106
  # Call the prediction function and display results
107
  with st.spinner("Analyzing image..."):
 
113
 
114
  st.metric(
115
  label="Prediction",
116
+ value=f"{predicted_label.replace('_', ' ').title()}",
117
  delta=f"{predicted_score * 100:.2f}%",
118
  help="The predicted category and its confidence score.",
119
  delta_color="normal",