Spaces:
Running
Running
from hf_setup import setup | |
setup() | |
import requests | |
import json | |
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
from PIL import Image | |
from tensorflow.keras.preprocessing.image import img_to_array | |
from tensorflow.keras.models import load_model | |
import tensorflow as tf | |
import os | |
from google.cloud import storage | |
import uuid | |
from datetime import datetime | |
import io | |
from datetime import date | |
from google.oauth2 import service_account | |
import gspread | |
# Email recorded alongside each prediction in the usage sheet. The input
# widget that would set it is commented out, so it normally stays empty.
email = ""
# Use the full browser width for the page layout.
st.set_page_config(layout="wide")


def load_bird_info():
    """Return the species lookup table read from ``bird_info.csv``.

    Per the original note, the CSV holds ``scientific_name``,
    ``common_name`` and ``wikipedia_link`` columns; the app looks rows up
    by ``scientific_name``.
    """
    info_path = "bird_info.csv"
    return pd.read_csv(info_path)


# Species table used to turn model predictions into display names and links.
bird_info_df = load_bird_info()
# Download model from private GitHub repo
@st.cache_resource
def download_and_load_model():
    """Fetch the classifier from a private GitHub repo and load it with Keras.

    The GitHub API URL in ``st.secrets["GITHUB_API_URL"]`` is queried with
    the ``GITHUB_PAT`` token. The ``download_url`` field of its JSON
    response is then streamed to ``/tmp/model.h5`` and loaded with
    ``load_model``.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the
    whole script on every widget interaction. Without the cache, the model
    would be downloaded and reloaded on every click.

    Returns:
        The loaded Keras model.

    Raises:
        requests.HTTPError: if either GitHub request returns an error status.
        KeyError: if the API response has no ``download_url`` field.
    """
    headers = {"Authorization": "Bearer " + st.secrets["GITHUB_PAT"]}

    # Ask the GitHub API for the file's metadata to obtain its raw download URL.
    meta_response = requests.get(st.secrets["GITHUB_API_URL"], headers=headers, timeout=30)
    # Fail loudly on auth/404 errors instead of a confusing KeyError below.
    meta_response.raise_for_status()
    download_url = meta_response.json()["download_url"]

    # Stream the weights to disk in chunks rather than buffering the whole
    # file in memory (the original used stream=True but then read .content).
    model_path = "/tmp/model.h5"
    with requests.get(download_url, headers=headers, stream=True, timeout=60) as model_response:
        model_response.raise_for_status()
        with open(model_path, "wb") as f:
            for chunk in model_response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)

    return load_model(model_path)


# Load the saved model (cached across reruns by st.cache_resource)
model = download_and_load_model()
# Mapping of class name -> model output index, used to translate the
# model's predicted indices back into species names.
# This file must be regenerated every time the model is replaced !!!
# (An older version loaded a local model file through an
# @st.cache_resource-wrapped load_model call; the model now comes from
# download_and_load_model().)
with open('class_indices_c271.json', 'r') as indices_fp:
    class_indices = json.loads(indices_fp.read())
def preprocess_and_predict(image, top_n=3):
    """Classify a bird image and return the ``top_n`` most likely classes.

    Args:
        image: PIL image uploaded by the user. It is converted to RGB first,
            because PNG uploads are often RGBA or palette images whose
            channel count would differ from the 3-channel RGB input the
            model presumably expects (TODO confirm against the model).
        top_n: number of ranked predictions to return. Values larger than
            the number of classes are clamped; values < 1 return ``[]``.

    Returns:
        List of ``(class_name, confidence_str)`` tuples ordered from most to
        least likely. ``class_name`` is the ``class_indices`` key mapped to
        the output index, or ``None`` if no key maps to it.
        ``confidence_str`` is a percentage such as ``"87.12%"``, or
        ``"<0.1%"`` for values below 0.1%.
    """
    if top_n < 1:
        return []

    # Force 3 channels, resize to the 299x299 model input, scale pixels to
    # [0, 1], and add a leading batch axis.
    rgb_image = image.convert("RGB").resize((299, 299))
    img_array = img_to_array(rgb_image) / 255.0
    img_array = np.expand_dims(img_array, axis=0)

    probs = model.predict(img_array)[0]

    # argpartition needs kth within range, so never request more than exist.
    top_n = min(top_n, len(probs))
    top_indices = np.argpartition(probs, -top_n)[-top_n:]
    # argpartition leaves the top-n unordered; sort them by descending score.
    top_indices = top_indices[np.argsort(probs[top_indices])][::-1]

    # Invert class_indices (name -> index) once instead of scanning it for
    # every predicted index.
    index_to_class = {int(idx): name for name, idx in class_indices.items()}

    results = []
    for idx in top_indices:
        confidence = probs[idx] * 100
        confidence_str = "<0.1%" if confidence < 0.1 else f"{confidence:.2f}%"
        results.append((index_to_class.get(int(idx)), confidence_str))
    return results
def load_css(file_name: str):
    """Read a stylesheet and return it wrapped in an HTML ``<style>`` tag."""
    with open(file_name, "r") as stylesheet:
        rules = stylesheet.read()
    return "<style>" + rules + "</style>"
# Inline <style> block that the UI injects into the page via st.markdown.
css_styles = load_css("styles.css")
# Every scientific name listed in bird_info.csv, in file order.
class_names = bird_info_df["scientific_name"].tolist()
def save_to_google_sheet(email, predictions, sheet_id, sheet_name, gcs_url):
    """Append one usage-tracking row for a classification to a Google Sheet.

    The row is laid out as::

        email, name_1, confidence_1, ..., name_k, confidence_k, gcs_url, MM/DD/YYYY

    with one name/confidence pair per entry of ``predictions`` (in order)
    and today's date.

    Args:
        email: user email ("" when not collected).
        predictions: ``(scientific_name, confidence_str)`` tuples, as
            returned by ``preprocess_and_predict``.
        sheet_id: spreadsheet key passed to ``open_by_key``.
        sheet_name: worksheet (tab) title inside that spreadsheet.
        gcs_url: image URL to record with the row.
    """
    # Service-account credentials are stored as a JSON string in Streamlit secrets.
    json_creds_dict = json.loads(st.secrets["GOOGLE_SHEETS_JSON_CREDS"])
    credentials = service_account.Credentials.from_service_account_info(
        json_creds_dict,
        scopes=['https://www.googleapis.com/auth/spreadsheets'],
    )
    client = gspread.authorize(credentials)
    sheet = client.open_by_key(sheet_id).worksheet(sheet_name)

    data = [email]
    for prediction, confidence in predictions:
        # Prefer the common name. Fall back to the scientific name when the
        # species is missing from bird_info.csv (or the index had no name)
        # instead of raising IndexError and losing the whole row.
        common_name = prediction if prediction is not None else ""
        if prediction is not None:
            matches = bird_info_df.loc[
                bird_info_df["scientific_name"].str.lower() == prediction.lower()
            ]
            if not matches.empty:
                common_name = matches.iloc[0]["common_name"]
        data.extend([common_name, confidence])

    data.append(gcs_url)
    data.append(date.today().strftime('%m/%d/%Y'))

    # A single append call. This avoids downloading the entire sheet with
    # get_all_values() on every prediction just to compute the next row
    # index. INSERT_ROWS inserts new rows rather than overwriting cells
    # below the detected table.
    sheet.append_row(data, insert_data_option="INSERT_ROWS")
def upload_to_gcs(bucket_name, source_file, destination_blob_name, content_type=None, prefix="testing-phase"):
    """Upload an image file object to Google Cloud Storage under a unique name.

    Only the extension of ``destination_blob_name`` is used. The object is
    stored as ``{prefix}/{YYYYMMDD}-{uuid4}{ext}`` so uploads never collide.

    Args:
        bucket_name: GCS bucket to upload into.
        source_file: binary file-like object holding the image bytes; it is
            rewound before uploading.
        destination_blob_name: original file name/path, used only for its
            extension (and to guess the MIME type).
        content_type: MIME type stored with the object. When ``None`` it is
            guessed from the extension (e.g. ``image/jpeg``, ``image/png``).
        prefix: folder-like prefix for the stored object name.

    Returns:
        The object's ``https://storage.googleapis.com/...`` URL (readable
        only if the bucket grants public access).
    """
    import mimetypes

    # Build a client from the service-account JSON stored in Streamlit secrets.
    json_creds_dict = json.loads(st.secrets["GCS_JSON_CREDS"])
    credentials = service_account.Credentials.from_service_account_info(json_creds_dict)
    client = storage.Client(credentials=credentials, project=st.secrets['GCS_PROJECT_ID'])
    bucket = client.get_bucket(bucket_name)

    # Make the object name unique, keeping only the original extension.
    extension = os.path.splitext(destination_blob_name)[1]
    new_file_name = f"{datetime.now().strftime('%Y%m%d')}-{uuid.uuid4()}{extension}"
    blob = bucket.blob(f"{prefix}/{new_file_name}")

    # Previously this read the module-level `uploaded_file` global; derive it
    # from the caller or the file extension instead.
    if content_type is None:
        content_type = mimetypes.guess_type(destination_blob_name)[0]
    blob.upload_from_file(source_file, content_type=content_type, rewind=True)

    return f"https://storage.googleapis.com/{bucket.name}/{blob.name}"
def _find_bird_info(scientific_name):
    """Return the first bird_info.csv row matching ``scientific_name`` (case-insensitive), or None."""
    if scientific_name is None:
        return None
    matches = bird_info_df.loc[bird_info_df["scientific_name"].str.lower() == scientific_name.lower()]
    return None if matches.empty else matches.iloc[0]


def _load_uploaded_image(uploaded_file):
    """Decode the uploaded file into a PIL image; show an error and return None if it is unreadable."""
    try:
        image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        # Image.open is lazy; load() forces a full decode so truncated files fail here.
        image.load()
    except OSError:
        # PIL's UnidentifiedImageError (unrecognised format) is an OSError subclass.
        st.error("Sorry, that file could not be read as an image. Please upload a JPG or PNG.")
        return None
    return image


def _show_predictions(predictions):
    """Render the top prediction and the alternative species in the current container."""
    top_prediction, top_confidence = predictions[0]
    top_info = _find_bird_info(top_prediction)
    if top_info is None:
        # Species missing from bird_info.csv: show the scientific name instead of crashing.
        st.markdown(f'<p class="big-font">The predicted species is: <b>{top_prediction}</b> (Confidence: {top_confidence})</p>', unsafe_allow_html=True)
        st.write(f"Could not find additional information for {top_prediction}.")
    else:
        common_name = top_info["common_name"]
        wikipedia_link = top_info["wikipedia_link"]
        st.markdown(f'<p class="big-font">The predicted species is: <b>{common_name}</b> (Confidence: {top_confidence})</p>', unsafe_allow_html=True)
        st.markdown(f'<p class="big-font">Scientific name: <i>{top_prediction}</i></p>', unsafe_allow_html=True)
        st.markdown(f'<p class="big-font">More information: <a href="{wikipedia_link}">Wikipedia</a></p>', unsafe_allow_html=True)
    st.write("---")  # Markdown syntax for a horizontal line

    with st.expander("Your species may also be:"):
        for prediction, confidence in predictions[1:]:
            info = _find_bird_info(prediction)
            if info is not None:
                st.write(f"**{info['common_name']}** (*{prediction}*): {confidence}")
            else:
                st.write(f"Could not find additional information for {prediction} ({confidence})")
                st.write(f"**Debug Info**: The bird species '{prediction}' is not present in the DataFrame.")


def _log_prediction(predictions):
    """Record the prediction in the usage-tracking Google Sheet (best effort)."""
    import logging

    # TEMP: uploading real images to GCS is disabled until security is
    # improved, so a fixed placeholder image URL is recorded to track usage.
    gcs_url = "https://inaturalist-open-data.s3.amazonaws.com/photos/363523152/large.jpg"
    # To record the real image instead, replace the placeholder with:
    # gcs_url = upload_to_gcs(st.secrets["GCS_BUCKET_ID"], io.BytesIO(uploaded_file.getvalue()), f'testing-phase/{uploaded_file.name}')
    try:
        save_to_google_sheet(email or "", predictions, st.secrets["SHEET_ID"], st.secrets["SHEET_NAME"], gcs_url)
    except Exception:
        # Usage tracking must not break the results already shown to the
        # user; keep the full traceback in the server log instead.
        logging.getLogger(__name__).exception("Failed to record prediction in Google Sheet")


st.title("Health & Wellness Bird Identifier")
st.write("Upload an image of a bird to classify it!")
# Email collection is disabled; the module-level default "" is recorded instead.
# email = st.text_input("Enter your email *(Optional)*")
uploaded_file = st.file_uploader("Choose an image of a bird...", type=["jpg", "jpeg", "png"])

image = _load_uploaded_image(uploaded_file) if uploaded_file is not None else None
if image is not None:
    st.write("")
    # Two columns: a small preview on the left, the Classify button and results on the right.
    col1, col2 = st.columns(2)
    # Shrink the preview so it doesn't fill the page (prediction uses the full image).
    col1.image(image.resize((224, 224)), use_column_width=True)
    with col2:
        st.write("")  # Add an empty line to create some space
        # Add custom CSS to style the Streamlit app
        st.markdown(css_styles, unsafe_allow_html=True)
        if st.button("Classify"):
            with st.spinner("Classifying..."):
                predictions = preprocess_and_predict(image)
            _show_predictions(predictions)
            _log_prediction(predictions)