|
|
|
from prediction import PredictionPipeline |
|
import streamlit as st |
|
import requests |
|
from PIL import Image |
|
import os |
|
|
|
def _predict_image_bytes(data):
    """Classify raw image bytes with ``PredictionPipeline``.

    ``PredictionPipeline`` is constructed from a file path, so the bytes are
    written to a uniquely named temporary ``.png`` file, which is always
    removed afterwards. A unique name (rather than a fixed ``temp_image.png``
    in the working directory) keeps concurrent Streamlit sessions from
    overwriting or deleting each other's upload mid-prediction.

    Args:
        data: The uploaded image's content (any bytes-like object).

    Returns:
        Whatever ``PredictionPipeline(path).predict()`` returns.

    Raises:
        Any exception raised while writing the file or running the pipeline.
    """
    import tempfile

    fd, path = tempfile.mkstemp(suffix=".png")
    try:
        # Close the handle before predicting so the bytes are fully flushed
        # to disk before the pipeline reads the path.
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(data)
        return PredictionPipeline(path).predict()
    finally:
        os.remove(path)


def main():
    """Render the SETI signals classifier page.

    Shows a PNG uploader. When a file is uploaded, runs it through the
    prediction pipeline and shows the result as JSON next to the uploaded
    image. If the pipeline raises or returns a non-string result, an error
    message is shown instead.
    """
    st.title("SETI Signals Classifier")

    uploaded_file = st.file_uploader("Choose an image to predict...", type=["png"])

    st.markdown("You can get sample images from [here](https://github.com/bhuvaneshprasad/End-to-End-SETI-Classification-using-CNN-MLFlow-DVC/tree/main/assets/test_images) to predict.")

    if uploaded_file is None:
        return

    with st.spinner('Predicting...'):
        try:
            prediction = _predict_image_bytes(uploaded_file.getbuffer())
        except Exception as err:  # top-level UI boundary: report, don't crash the page
            st.error(f"Failed to predict: {err}")
            return

    if isinstance(prediction, str):
        st.json({'prediction': prediction})
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)
    else:
        st.error("Failed to predict")
|
|
|
# Start the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|