iBrokeTheCode commited on
Commit
a2c2878
Β·
1 Parent(s): 4edfc37

fix: Clear session state after predicting images

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. src/streamlit_app.py +58 -59
Dockerfile CHANGED
@@ -14,7 +14,7 @@ COPY src/ ./src/
14
  RUN pip3 install -r requirements.txt
15
 
16
  # Set Streamlit config/data folder to a writable path
17
- ENV STREAMLIT_HOME=/app/.streamlit
18
  RUN mkdir -p /app/.streamlit
19
 
20
  EXPOSE 8501
 
14
  RUN pip3 install -r requirements.txt
15
 
16
  # Set Streamlit config/data folder to a writable path
17
+ ENV HOME=/app
18
  RUN mkdir -p /app/.streamlit
19
 
20
  EXPOSE 8501
src/streamlit_app.py CHANGED
@@ -19,7 +19,6 @@ st.html("""
19
  """)
20
 
21
  # πŸ“Œ INITIALIZE SESSION STATE
22
- # We initialize session state variables to manage app state
23
  if "uploaded_image" not in st.session_state:
24
  st.session_state["uploaded_image"] = None
25
  if "example_selected" not in st.session_state:
@@ -55,12 +54,6 @@ with st.container():
55
  key="image_uploader",
56
  )
57
 
58
- # Update state when a new file is uploaded
59
- if uploaded_file is not st.session_state.uploaded_image:
60
- st.session_state.uploaded_image = uploaded_file
61
- st.session_state.example_selected = False
62
- st.session_state.prediction_result = None
63
-
64
  st.html("<br>")
65
  st.subheader("Or Try an Example", divider=True)
66
 
@@ -68,7 +61,7 @@ with st.container():
68
  selected_example = st.segmented_control(
69
  label="Categories",
70
  options=["Animal", "Vehicle", "Object", "Building"],
71
- default="Animal",
72
  help="Select one of the pre-loaded examples",
73
  )
74
 
@@ -82,64 +75,70 @@ with st.container():
82
  icon="✨",
83
  )
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  # πŸ“Œ PREDICTION RESULTS
86
  with col_results:
87
  st.header("Results", divider=True)
88
 
89
- image_to_process = None
90
-
91
- # Logic to handle which image to display
92
- if st.session_state.uploaded_image:
93
- # Get the image from the uploaded file
94
- image_to_process = Image.open(st.session_state.uploaded_image)
95
- elif selected_example:
96
- # Load the selected example image using a robust path
97
- try:
98
- img_path = os.path.join(
99
- ASSETS_DIR, f"{selected_example.lower()}.jpg"
100
- )
101
- image_to_process = Image.open(img_path)
102
- except FileNotFoundError:
103
- st.error(
104
- f"Error: The example image '{selected_example.lower()}.jpg' was not found."
105
- )
106
- st.stop()
107
-
108
- # Display image and run prediction when button is clicked
109
  if image_to_process:
110
  st.image(image_to_process, caption="Image to be classified")
111
 
112
- if classify_button:
113
- # Run the prediction logic
114
- with st.spinner("Analyzing image..."):
115
- try:
116
- # πŸ“Œ Prediction function call πŸ“Œ
117
- from predictor import predict_image
118
-
119
- predicted_label, predicted_score = predict_image(
120
- image_to_process
121
- )
122
- st.session_state.prediction_result = {
123
- "label": predicted_label.replace("_", " ").title(),
124
- "score": predicted_score,
125
- }
126
- except Exception as e:
127
- st.error(f"An error occurred during prediction: {e}")
128
-
129
- # Display the prediction result if available
130
- if st.session_state.prediction_result:
131
- st.metric(
132
- label="Prediction",
133
- value=st.session_state.prediction_result["label"],
134
- delta=f"{st.session_state.prediction_result['score'] * 100:.2f}%",
135
- help="The predicted category and its confidence score.",
136
- delta_color="normal",
137
- )
138
- st.balloons()
139
- else:
140
- st.info("Click 'Classify Image' to see the prediction.")
141
- else:
142
- st.info("Choose an image to get a prediction.")
143
 
144
  # πŸ“Œ DESCRIPTION TAB
145
  with tab_description:
 
19
  """)
20
 
21
  # πŸ“Œ INITIALIZE SESSION STATE
 
22
  if "uploaded_image" not in st.session_state:
23
  st.session_state["uploaded_image"] = None
24
  if "example_selected" not in st.session_state:
 
54
  key="image_uploader",
55
  )
56
 
 
 
 
 
 
 
57
  st.html("<br>")
58
  st.subheader("Or Try an Example", divider=True)
59
 
 
61
  selected_example = st.segmented_control(
62
  label="Categories",
63
  options=["Animal", "Vehicle", "Object", "Building"],
64
+ default=None,
65
  help="Select one of the pre-loaded examples",
66
  )
67
 
 
75
  icon="✨",
76
  )
77
 
78
+ # --- LOGIC FOR IMAGE SELECTION & PREDICTION ---
79
+ # Clear the previous prediction result if a new input is selected
80
+ if uploaded_file or selected_example:
81
+ st.session_state.prediction_result = None
82
+
83
+ image_to_process = None
84
+
85
+ if uploaded_file:
86
+ image_to_process = Image.open(uploaded_file)
87
+
88
+ elif selected_example:
89
+ try:
90
+ img_path = os.path.join(
91
+ APP_DIR, "assets", f"{selected_example.lower()}.jpg"
92
+ )
93
+ image_to_process = Image.open(img_path)
94
+ except FileNotFoundError:
95
+ st.error(
96
+ f"Error: The example image '{selected_example.lower()}.jpg' was not found."
97
+ )
98
+ st.stop()
99
+
100
  # πŸ“Œ PREDICTION RESULTS
101
  with col_results:
102
  st.header("Results", divider=True)
103
 
104
+ # Display a "get started" message if no image is selected
105
+ if not image_to_process and not st.session_state.prediction_result:
106
+ st.info("Choose an image or an example to get a prediction.")
107
+
108
+ # Display the image if one is selected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  if image_to_process:
110
  st.image(image_to_process, caption="Image to be classified")
111
 
112
+ # If the button is clicked, run the prediction logic
113
+ if classify_button and image_to_process:
114
+ with st.spinner("Analyzing image..."):
115
+ try:
116
+ from predictor import predict_image
117
+
118
+ predicted_label, predicted_score = predict_image(
119
+ image_to_process
120
+ )
121
+ st.session_state.prediction_result = {
122
+ "label": predicted_label.replace("_", " ").title(),
123
+ "score": predicted_score,
124
+ }
125
+ except Exception as e:
126
+ st.error(f"An error occurred during prediction: {e}")
127
+
128
+ # Display the prediction result if available in session state
129
+ if st.session_state.prediction_result:
130
+ st.metric(
131
+ label="Prediction",
132
+ value=st.session_state.prediction_result["label"],
133
+ delta=f"{st.session_state.prediction_result['score'] * 100:.2f}%",
134
+ help="The predicted category and its confidence score.",
135
+ delta_color="normal",
136
+ )
137
+ st.balloons()
138
+
139
+ elif image_to_process:
140
+ st.info("Click 'Classify Image' to see the prediction.")
141
+
 
142
 
143
  # πŸ“Œ DESCRIPTION TAB
144
  with tab_description: