bhuvaneshprasad commited on
Commit
631d6c6
1 Parent(s): 6e70d14

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +28 -0
  3. model.keras +3 -0
  4. prediction.py +47 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model.keras filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # streamlit_app/app.py
2
+
3
+ import streamlit as st
4
+ import requests
5
+ from PIL import Image
6
+
7
+ # Define FastAPI backend URL
8
+ backend_url = "http://localhost:7384"
9
+
10
+ def main():
11
+ st.title("SETI Signals Classifier")
12
+
13
+ # Example: Upload file and send POST request to FastAPI endpoint
14
+ uploaded_file = st.file_uploader("Choose an image to predict...", type=["png"])
15
+ st.markdown("You can get sample images from [here](https://github.com/bhuvaneshprasad/End-to-End-SETI-Classification-using-CNN-MLFlow-DVC/tree/main/assets/test_images) to predict.")
16
+ if uploaded_file is not None:
17
+ with st.spinner('Predicting...'):
18
+ files = {"file": uploaded_file}
19
+ response = requests.post(f"{backend_url}/predict", files=files)
20
+ if response.status_code == 200:
21
+ st.json(response.json())
22
+ image = Image.open(uploaded_file)
23
+ st.image(image, caption="Uploaded Image", use_column_width=True)
24
+ else:
25
+ st.error("Failed to predict")
26
+
27
+ if __name__ == "__main__":
28
+ main()
model.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2bb6494fb9ef5a65353aafea4100b93620a51671b2430fc7b6fc17f9145ff31
3
+ size 658923582
prediction.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from dotenv import load_dotenv
4
+ import numpy as np
5
+ import tensorflow as tf
6
+
7
+ load_dotenv()
8
+
9
+ class PredictionPipeline:
10
+ """
11
+ A class representing a pipeline for making predictions using a pre-trained model.
12
+
13
+ Attributes:
14
+ filename (str): The filename of the image to predict.
15
+
16
+ Methods:
17
+ predict() -> int:
18
+ Loads a pre-trained model, processes an image, and predicts its class.
19
+ """
20
+ def __init__(self,filename):
21
+ """
22
+ Initialize the PredictionPipeline class.
23
+
24
+ Args:
25
+ filename (str): The filename of the image to predict.
26
+ """
27
+ self.filename =filename
28
+
29
+ def predict(self) -> int:
30
+ """
31
+ Perform prediction on the image specified by the filename.
32
+
33
+ Returns:
34
+ int: The predicted class label.
35
+ """
36
+ model = tf.keras.models.load_model(os.path.join(os.getcwd(),Path(os.getenv('MODEL_URI'))))
37
+ class_labels = ['brightpixel','narrowband',
38
+ 'narrowbanddrd','noise',
39
+ 'squarepulsednarrowband','squiggle',
40
+ 'squigglesquarepulsednarrowband']
41
+
42
+ imagename = self.filename
43
+ test_image = tf.keras.preprocessing.image.load_img(imagename, target_size = (256,256))
44
+ test_image = tf.keras.preprocessing.image.img_to_array(test_image)
45
+ test_image = np.expand_dims(test_image, axis = 0)
46
+ result = np.argmax(model.predict(test_image), axis=1)
47
+ return class_labels[int(result)]