# braedenb's picture
# Update app.py
# 9fa89e8 verified
# Run the Hugging Face Space setup hook before importing anything heavy.
from hf_setup import setup
setup()
import requests
import json
import streamlit as st
import numpy as np
import pandas as pd
from PIL import Image
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import tensorflow as tf
import os
from google.cloud import storage
import uuid
from datetime import datetime
import io
from datetime import date
from google.oauth2 import service_account
import gspread

# User email logged with each classification; the input widget is currently
# commented out below, so this stays "".
email=""
st.set_page_config(layout="wide")
# Bird metadata table: scientific_name, common_name, wikipedia_link per row.
@st.cache_data
def load_bird_info():
    """Read bird_info.csv once and cache the resulting DataFrame."""
    birds = pd.read_csv("bird_info.csv")
    return birds

# Module-level table used by the prediction display and sheet-logging code.
bird_info_df = load_bird_info()
# Download the model from a private GitHub repo; cached for the session.
@st.cache_resource()
def download_and_load_model():
    """Fetch the .h5 model via the GitHub contents API and load it with Keras.

    Reads GITHUB_API_URL and GITHUB_PAT from Streamlit secrets.

    Returns:
        The loaded Keras model.

    Raises:
        requests.HTTPError: if either GitHub request fails (bad PAT, wrong
            URL) — failing fast here beats a confusing KeyError or a
            corrupt model file later.
    """
    url = st.secrets["GITHUB_API_URL"]
    headers = {"Authorization": "Bearer " + st.secrets["GITHUB_PAT"]}
    # Ask the contents API where the raw model file lives.
    r = requests.get(url, headers=headers, timeout=30)
    r.raise_for_status()
    download_url = r.json()["download_url"]
    # Stream the (large) model file to disk in chunks; the original used
    # `.content`, which buffered the entire file in memory despite stream=True.
    model_response = requests.get(download_url, headers=headers, stream=True, timeout=300)
    model_response.raise_for_status()
    with open("/tmp/model.h5", "wb") as f:
        for chunk in model_response.iter_content(chunk_size=1 << 20):
            f.write(chunk)
    # Load the model from the temporary location.
    return load_model("/tmp/model.h5")

# Load the saved model once per session.
model = download_and_load_model()
# OLD METHOD OF LOAD/SAVE
# Use a cache so app doesn't have to reload model everytime
# @st.cache_resource()
# def load_cached_model(model_path):
#     return load_model(model_path)
# model = load_cached_model('...')  # Update the model path as needed

# Class-name -> integer-index mapping produced at training time.
# Remake the .json file everytime you replace the model !!!
with open('class_indices_c271.json', 'r') as f:
    class_indices = json.load(f)
def preprocess_and_predict(image, top_n=3):
    """Classify a PIL image and return the top-N species with confidences.

    Args:
        image: PIL Image of a bird (any mode; converted to RGB here — the
            uploader accepts PNGs, whose alpha channel would otherwise feed a
            4-channel array into the 3-channel model).
        top_n: number of predictions to return (clamped to the class count).

    Returns:
        List of (class_name, confidence_str) tuples, best first. Confidence
        is formatted as "NN.NN%" or "<0.1%"; class_name is None if an index
        has no entry in class_indices.
    """
    image = image.convert("RGB").resize((299, 299))
    img_array = img_to_array(image) / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    preds = model.predict(img_array)[0]
    top_n = min(top_n, preds.shape[0])
    # Top-N indices via argpartition, then sorted descending by score.
    predicted_indices = np.argpartition(preds, -top_n)[-top_n:]
    predicted_indices = predicted_indices[np.argsort(preds[predicted_indices])][::-1]
    # Invert the name->index mapping once instead of scanning it per prediction.
    index_to_class = {index: name for name, index in class_indices.items()}
    predicted_classes_and_confidences = []
    for predicted_index in predicted_indices:
        confidence = preds[predicted_index] * 100
        if confidence < 0.1:
            confidence_str = "<0.1%"
        else:
            confidence_str = f"{confidence:.2f}%"
        predicted_classes_and_confidences.append((index_to_class.get(predicted_index), confidence_str))
    return predicted_classes_and_confidences
def load_css(file_name: str):
    """Read a CSS file and wrap its contents in a <style> tag for st.markdown."""
    with open(file_name, "r") as stylesheet:
        return "<style>" + stylesheet.read() + "</style>"
# Load the app stylesheet, pre-wrapped in a <style> tag by load_css.
css_styles = load_css("styles.css")
# Scientific names in CSV order; not referenced elsewhere in this view —
# presumably kept for debugging/reference alongside class_indices.
class_names = list(bird_info_df["scientific_name"].values)
def save_to_google_sheet(email, predictions, sheet_id, sheet_name, gcs_url):
    """Append one usage row (email, predictions, image URL, date) to a Sheet.

    Args:
        email: user email string (may be "").
        predictions: list of (scientific_name, confidence_str) tuples.
        sheet_id: Google Sheet document key.
        sheet_name: worksheet (tab) name within the document.
        gcs_url: URL of the uploaded (or placeholder) image.

    Row layout: [email, common_name, confidence, ... , gcs_url, MM/DD/YYYY].
    """
    json_creds_str = st.secrets["GOOGLE_SHEETS_JSON_CREDS"]
    json_creds_dict = json.loads(json_creds_str)
    # Service-account credentials come from a JSON string stored in secrets.
    credentials = service_account.Credentials.from_service_account_info(
        json_creds_dict,
        scopes=['https://www.googleapis.com/auth/spreadsheets']
    )
    client = gspread.authorize(credentials)
    # Open the target worksheet and find the first free row.
    sheet = client.open_by_key(sheet_id).worksheet(sheet_name)
    next_row = len(sheet.get_all_values()) + 1
    data = [email]
    for prediction, confidence in predictions:
        bird_info_filtered = bird_info_df.loc[bird_info_df["scientific_name"].str.lower() == prediction.lower()]
        if bird_info_filtered.empty:
            # Species missing from bird_info.csv: log the scientific name
            # instead of crashing on .iloc[0] (the display code guards this
            # same case; the sheet writer must too).
            common_name = prediction
        else:
            common_name = bird_info_filtered.iloc[0]["common_name"]
        data.append(common_name)
        data.append(confidence)
    # Image URL, then the submission date.
    data.append(gcs_url)
    current_date = date.today().strftime('%m/%d/%Y')
    data.append(current_date)
    # Send data to Sheet
    sheet.insert_row(data, next_row)
def upload_to_gcs(bucket_name, source_file, destination_blob_name, content_type=None):
    """Upload a file-like object to GCS and return its public URL.

    Args:
        bucket_name: target GCS bucket name.
        source_file: file-like object positioned at the data to upload.
        destination_blob_name: requested blob name; only its extension is
            kept — the stored name is 'testing-phase/YYYYMMDD-<uuid4><ext>'.
        content_type: optional MIME type. When omitted it is guessed from the
            extension. (The original read the module global
            `uploaded_file.type`, which broke any caller outside the script's
            upload flow.)

    Returns:
        Public https URL of the uploaded blob.
    """
    import mimetypes  # local import: only needed by this helper

    # Client built from the service-account JSON stored in Streamlit secrets.
    json_creds_dict = json.loads(st.secrets["GCS_JSON_CREDS"])
    credentials = service_account.Credentials.from_service_account_info(json_creds_dict)
    client = storage.Client(credentials=credentials, project=st.secrets['GCS_PROJECT_ID'])
    bucket = client.get_bucket(bucket_name)
    # Make the stored filename unique: date + uuid4 + original extension.
    extension = os.path.splitext(destination_blob_name)[1]
    new_file_name = datetime.now().strftime('%Y%m%d') + '-' + str(uuid.uuid4()) + extension
    destination_blob_name = f'testing-phase/{new_file_name}'
    if content_type is None:
        content_type = mimetypes.guess_type(destination_blob_name)[0]
    # Create the blob and upload the file's content.
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_file(source_file, content_type=content_type)
    # Construct and return the public URL.
    public_url = f"https://storage.googleapis.com/{bucket.name}/{blob.name}"
    return public_url
# ---------------------------------------------------------------------------
# Main UI flow. Streamlit re-runs this whole script on every interaction.
# ---------------------------------------------------------------------------
st.title("Health & Wellness Bird Identifier")
st.write("Upload an image of a bird to classify it!")
# Get user email — input currently disabled, so `email` stays "" from the top
# of the file.
# email = st.text_input("Enter your email *(Optional)*")
uploaded_file = st.file_uploader("Choose an image of a bird...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Copy the raw bytes up front so a fresh stream can be rebuilt later,
    # after PIL has consumed the uploaded file object.
    uploaded_file_content = uploaded_file.read()
    image = Image.open(uploaded_file)
    st.write("")
    # Two-column layout: preview on the left, results on the right.
    # Resize the preview so the image doesn't fill the page.
    col1, col2 = st.columns(2)
    resized_image = image.resize((224, 224))
    col1.image(resized_image, use_column_width=True)
    # Move the Classify button and results to the right column.
    with col2:
        st.write("")  # Add an empty line to create some space
        # Inject the custom CSS used by the .big-font result markup below.
        st.markdown(css_styles, unsafe_allow_html=True)
        if st.button("Classify"):
            # Fresh BytesIO over the copied bytes, for the (disabled) GCS upload path.
            uploaded_file_copy = io.BytesIO(uploaded_file_content)
            # Use the copy for uploading
            # gcs_url = upload_to_gcs(st.secrets["GCS_BUCKET_ID"], uploaded_file_copy, f'testing-phase/{uploaded_file.name}')
            with st.spinner("Classifying..."):
                predictions = preprocess_and_predict(image)
                # Top-1 result gets the headline treatment.
                top_prediction, top_confidence = predictions[0]
                bird_info = bird_info_df.loc[bird_info_df["scientific_name"].str.lower() == top_prediction.lower()].iloc[0]
                common_name = bird_info["common_name"]
                wikipedia_link = bird_info["wikipedia_link"]
                st.markdown(f'<p class="big-font">The predicted species is: <b>{common_name}</b> (Confidence: {top_confidence})</p>', unsafe_allow_html=True)
                st.markdown(f'<p class="big-font">Scientific name: <i>{top_prediction}</i></p>', unsafe_allow_html=True)
                st.markdown(f'<p class="big-font">More information: <a href="{wikipedia_link}">Wikipedia</a></p>', unsafe_allow_html=True)
                st.write("---")  # Markdown syntax for a horizontal line
                # Runner-up predictions, tucked behind an expander.
                with st.expander("Your species may also be:"):
                    for prediction, confidence in predictions[1:]:
                        bird_info_filtered = bird_info_df.loc[bird_info_df["scientific_name"].str.lower() == prediction.lower()]
                        if not bird_info_filtered.empty:
                            bird_info = bird_info_filtered.iloc[0]
                            common_name = bird_info["common_name"]
                            st.write(f"**{common_name}** (*{prediction}*): {confidence}")
                        else:
                            st.write(f"Could not find additional information for {prediction} ({confidence})")
                            st.write(f"**Debug Info**: The bird species '{prediction}' is not present in the DataFrame.")
                # Write to Google Sheets to track app usage.
                sheet_id = st.secrets["SHEET_ID"]
                sheet_name = st.secrets["SHEET_NAME"]
                # TEMP UPLOAD PROCEDURE
                # Uploading to GCS avoided until security is better
                # Placeholder image is uploaded to sheet for now to track usage of app
                gcs_url = "https://inaturalist-open-data.s3.amazonaws.com/photos/363523152/large.jpg"
                save_to_google_sheet(email, predictions, sheet_id, sheet_name, gcs_url)
                # NOTE(review): this normalization runs AFTER save_to_google_sheet,
                # so it has no effect on the row just written — it was likely
                # intended to run before the call. Harmless while email is "".
                email = email if email else ""
                # UNCOMMENT THIS LINE TO UPLOAD REAL BIRD PICS; UNCOMMENT THE gcs_url variable with upload_to_gcs() as well
                # save_to_google_sheet(email, predictions, sheet_id, sheet_name, gcs_url)