# creative2 / app.py
# Author: artintel235
# Last commit: c62b281 "Update app.py" (verified)
import streamlit as st
import firebase_admin
from firebase_admin import credentials, auth, db, storage
import os
import json
import requests
from io import BytesIO
from PIL import Image
import tempfile
import mimetypes
import uuid
import io
# Load Firebase credentials and settings from Hugging Face Secrets (environment variables).
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
FIREBASE_API_KEY = os.getenv("FIREBASE_API_KEY")
FIREBASE_STORAGE_BUCKET = os.getenv("FIREBASE_STORAGE_BUCKET")
if firebase_creds:
    # The secret holds the service-account JSON as a string.
    firebase_creds = json.loads(firebase_creds)
else:
    st.error("Firebase credentials not found. Please check your secrets.")
    # Nothing below can work without credentials; halt this script run here
    # instead of crashing inside credentials.Certificate(None).
    st.stop()
# Initialize the default Firebase app only once per process; later script
# runs reuse the app that already exists (get_app raises ValueError if none).
try:
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        # Overridable via env var; defaults to the URL this app was built against.
        'databaseURL': os.getenv(
            "FIREBASE_DATABASE_URL",
            'https://creative-623ef-default-rtdb.firebaseio.com/',
        ),
        'storageBucket': FIREBASE_STORAGE_BUCKET,
    })
# Seed per-session state. Each key gets its default only when it is absent,
# so values written during earlier reruns of the script are preserved.
_SESSION_DEFAULTS = {
    "logged_in": False,
    "current_user": None,
    "display_name": None,
    "window_size": 5,            # thumbnails shown per gallery page
    "current_window_start": 0,   # index of the first thumbnail on the current page
    "selected_image": None,      # image record opened with "More"
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Image-generation API endpoint and the model identifier sent with every request.
API_URL = os.getenv("API_URL")
model_id = os.getenv("MODEL_ID")
# Token rotation state: generate_image cycles through TOKEN0..TOKEN{n-1}
# (n = no_of_accounts) when the API reports rate limiting.
no_of_accounts = 7
token_id, tokens_tried = 0, 0
TOKEN = os.getenv("TOKEN0")
def send_verification_email(id_token, timeout=30):
    """Ask Firebase Auth to email a verification link to the account behind id_token.

    Args:
        id_token: Firebase ID token of the freshly signed-in user.
        timeout: Seconds to wait for the Identity Toolkit endpoint.

    Returns:
        {'status': 'success', 'email': <address or None>} on success, or
        {'status': 'error', 'message': <reason>} on an API, network or
        response-decoding failure. Errors are reported in the dict, never raised.
    """
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:sendOobCode?key={FIREBASE_API_KEY}'
    headers = {'Content-Type': 'application/json'}
    data = {
        'requestType': 'VERIFY_EMAIL',
        'idToken': id_token
    }
    try:
        response = requests.post(url, headers=headers, json=data, timeout=timeout)
        result = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON response body.
        return {'status': 'error', 'message': str(e)}
    if 'error' in result:
        error = result['error']
        message = error.get('message', str(error)) if isinstance(error, dict) else str(error)
        return {'status': 'error', 'message': message}
    return {'status': 'success', 'email': result.get('email')}
# Callback for registration
def register_callback():
    """Create a Firebase account from the registration form and send a verification email.

    Reads reg_email, reg_password and reg_display_name from session state,
    stores the new user's ID token in st.session_state.id_token, and reports
    progress/failures with st.success / st.error (nothing is raised).
    """
    email = st.session_state.reg_email
    password = st.session_state.reg_password
    display_name = st.session_state.reg_display_name
    if not display_name:
        # Display names can be used to log in, so an account must have one.
        st.error("Registration failed: a display name is required.")
        return
    try:
        # Step 1: Create the user with its display name in a single call, so a
        # failure cannot leave behind an account without a display name.
        auth.create_user(email=email, password=password, display_name=display_name)
        st.success("Registration successful! Sending verification email...")
        # Step 2: Sign in programmatically to obtain the id_token that the
        # verification-email endpoint requires.
        url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'
        data = {
            'email': email,
            'password': password,
            'returnSecureToken': True
        }
        response = requests.post(url, json=data, timeout=30)
        result = response.json()
        id_token = result.get('idToken')
        if id_token:
            st.session_state.id_token = id_token
            # Step 3: Send the verification email.
            verification_result = send_verification_email(id_token)
            if verification_result['status'] == 'success':
                st.success(f"Verification email sent to {email}.")
            else:
                st.error(f"Failed to send verification email: {verification_result['message']}")
        else:
            error = result.get('error')
            message = error.get('message', str(error)) if isinstance(error, dict) else str(error)
            st.error(f"Failed to retrieve id_token: {message}")
    except Exception as e:
        st.error(f"Registration failed: {e}")
# Callback for login
def _sign_in_with_password(email, password, timeout=30):
    """Check an email/password pair against the Identity Toolkit REST API.

    Returns the decoded JSON response: it contains 'idToken' on success and
    'error' on failure.
    """
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'
    data = {
        'email': email,
        'password': password,
        'returnSecureToken': True
    }
    response = requests.post(url, json=data, timeout=timeout)
    return response.json()


def _mark_logged_in(user):
    """Record `user` (a firebase_admin user record) as this session's signed-in user."""
    st.session_state.logged_in = True
    st.session_state.current_user = user.uid
    st.session_state.display_name = user.display_name  # Store the display name
    st.success("Logged in successfully!")


def login_callback():
    """Log in using either an email address or a display name plus password.

    First tries login_identifier as an email. If the sign-in endpoint reports
    an error, it looks for users whose display name equals login_identifier
    and tries the password against each one's email. Every failure is shown as
    "Login failed: ..." with st.error.
    """
    login_identifier = st.session_state.login_identifier
    password = st.session_state.login_password
    try:
        result = _sign_in_with_password(login_identifier, password)
        if 'idToken' in result:
            # Sign-in succeeded, so the identifier is the account's email.
            _mark_logged_in(auth.get_user_by_email(login_identifier))
            return
        if 'error' not in result:
            # Neither a token nor an error: the endpoint answered unexpectedly.
            raise Exception("Error with sign-in endpoint")
        # The identifier did not work as an email; treat it as a display name.
        # list_users() returns a single page only; iterate_all() walks every page.
        for user in auth.list_users().iterate_all():
            if user.display_name != login_identifier or not user.email:
                continue
            if 'idToken' in _sign_in_with_password(user.email, password):
                _mark_logged_in(user)
                return
        raise Exception("User not found with provided credentials.")
    except Exception as e:
        st.error(f"Login failed: {e}")
# Callback for logout
def logout_callback():
    """Forget the signed-in user and any open image, then confirm the logout."""
    st.session_state.logged_in = False
    for state_key in ("current_user", "display_name", "selected_image"):
        st.session_state[state_key] = None
    st.info("Logged out successfully!")
# Function to get image from url
def get_image_from_url(url, timeout=30):
    """Download `url` and decode it as a PIL image (no format conversion is done).

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        (image, url) on success. On failure, (error_message, None), where the
        message begins with "Error fetching image" for HTTP/network problems
        or "Error processing image" for undecodable data.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data is reported here
        # rather than later, when the caller first touches the pixels.
        image.load()
        return image, url  # Return the image and the URL
    except requests.exceptions.RequestException as e:
        return f"Error fetching image: {e}", None
    except Exception as e:
        return f"Error processing image: {e}", None
# Function to generate image
def generate_image(prompt, aspect_ratio, realism, timeout=300):
    """Request an image from the generation API, rotating API tokens on rate limits.

    When the API answers with an error containing 'error 429', the module
    globals token_id/TOKEN move to the next account (TOKEN0..TOKEN{n-1},
    n = no_of_accounts) and the request is retried. At most no_of_accounts
    rotations happen per call.

    Args:
        prompt: Text description of the image.
        aspect_ratio: Aspect-ratio string such as "16:9".
        realism: Truthy/falsy flag, sent to the API as "true"/"false".
        timeout: Seconds to wait for each API request.

    Returns:
        Whatever get_image_from_url returns for the API's 'output' URL
        ((image, url) on success). On failure, (error_message, None).
    """
    global token_id
    global TOKEN
    global tokens_tried
    payload = {
        "id": model_id,
        "inputs": [prompt, aspect_ratio, str(realism).lower()],
    }
    try:
        while True:
            headers = {"Authorization": f"Bearer {TOKEN}"}
            response_data = requests.post(API_URL, json=payload, headers=headers, timeout=timeout).json()
            if not isinstance(response_data, dict):
                return "Error: Unexpected response from server", None
            if "error" in response_data:
                error = response_data["error"]
                # str() guards against non-string error payloads.
                if 'error 429' not in str(error):
                    return f"Error: {error}", None
                if tokens_tried >= no_of_accounts:
                    return "No credits available", None
                # Rate-limited: switch to the next account's token and retry.
                token_id = (token_id + 1) % no_of_accounts
                tokens_tried += 1
                TOKEN = os.getenv(f"TOKEN{token_id}")
                continue
            if "output" in response_data:
                return get_image_from_url(response_data['output'])  # (image, url)
            return "Error: Unexpected response from server", None
    except Exception as e:
        return f"Error: {e}", None
    finally:
        # Every call starts with a fresh rotation budget.
        tokens_tried = 0
def download_image(image_url, timeout=30):
    """Download `image_url` into a new temporary file and return that file's path.

    The file extension is guessed from the response's Content-Type header,
    falling back to ".png". The caller owns the file and should delete it.

    Args:
        image_url: URL to fetch. Falsy values return None immediately.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The temporary file path, or None if the URL is empty or anything
        fails (download is best-effort; errors are not raised).
    """
    if not image_url:
        return None  # Return None if image_url is empty
    tmp_path = None
    try:
        response = requests.get(image_url, stream=True, timeout=timeout)
        response.raise_for_status()
        # The header may be missing or carry parameters such as
        # "image/png; charset=binary"; guess_extension needs a bare MIME type
        # and raises on None.
        content_type = (response.headers.get('content-type') or '').split(';', 1)[0].strip().lower()
        extension = mimetypes.guess_extension(content_type) if content_type else None
        if not extension:
            extension = ".png"  # Default to .png if the extension can't be determined
        with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as tmp_file:
            tmp_path = tmp_file.name
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
        return tmp_path
    except Exception:
        # Don't leave a partially written file behind.
        if tmp_path:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return None
# Function to store image and related data in Firebase
def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url, thumbnail_url):
    """Append one image record under users/<user_id>/images in the Realtime Database.

    The record keeps the generation inputs and both storage URLs. Its
    'timestamp' is a server-value placeholder that Firebase replaces with the
    server time. The outcome is reported with st.success / st.error.
    """
    record = {
        'prompt': prompt,
        'aspect_ratio': aspect_ratio,
        'realism': realism,
        'image_url': image_url,
        'thumbnail_url' : thumbnail_url,
        'timestamp': {'.sv': 'timestamp'}
    }
    try:
        db.reference(f'users/{user_id}/images').push().set(record)
        st.success("Image data saved successfully!")
    except Exception as e:
        st.error(f"Failed to save image data: {e}")
#Function to upload image to cloud storage
def upload_image_to_storage(image, user_id, is_thumbnail = False):
    """Save a PIL image as PNG to Cloud Storage and return its public URL.

    Files go to user_images/<user_id>/<uuid>.png, or to
    user_images/<user_id>/thumbnails/<uuid>.png when is_thumbnail is true.
    Returns None (after showing st.error) if anything fails.
    """
    try:
        bucket = storage.bucket()
        folder = f"user_images/{user_id}/thumbnails" if is_thumbnail else f"user_images/{user_id}"
        blob = bucket.blob(f"{folder}/{uuid.uuid4()}.png")
        # Encode in memory; no temporary file is needed.
        png_buffer = BytesIO()
        image.save(png_buffer, format='PNG')
        blob.upload_from_string(png_buffer.getvalue(), content_type='image/png')
        blob.make_public()
        return blob.public_url
    except Exception as e:
        st.error(f"Failed to upload image to cloud storage: {e}")
        return None
#Function to load image data from the database
def load_image_data(user_id, start_index, batch_size):
    """Return up to batch_size image records for user_id, newest first.

    Fetches the last start_index + batch_size records ordered by 'timestamp',
    puts them in newest-first order, and drops the first start_index of them.
    Returns [] when there are no records or the query fails (after st.error).
    """
    try:
        query = db.reference(f'users/{user_id}/images').order_by_child('timestamp')
        snapshot = query.limit_to_last(start_index + batch_size).get()
        if not snapshot:
            return []
        newest_first = list(snapshot.items())[::-1]
        return [record for _key, record in newest_first[start_index:]]
    except Exception as e:
        st.error(f"Failed to fetch image data from database: {e}")
        return []
# Function to create low resolution thumbnail
def create_thumbnail(image, thumbnail_size = (150,150)):
    """Return a small PNG-decoded copy of `image` that fits within thumbnail_size.

    The caller's image is left untouched. Image.thumbnail resizes in place, so
    it is applied to a copy.

    Returns:
        The thumbnail as a PIL image, or None (after st.error) on failure.
    """
    try:
        thumb = image.copy()
        thumb.thumbnail(thumbnail_size)
        # Round-trip through PNG so the result is a PNG-backed image.
        png_buffer = BytesIO()
        thumb.save(png_buffer, format='PNG')
        return Image.open(BytesIO(png_buffer.getvalue()))
    except Exception as e:
        st.error(f"Failed to create thumbnail: {e}")
        return None
# Registration form
def registration_form():
    """Render the sign-up form; register_callback handles submission.

    The inputs are keyed, so the callback reads their values from session state.
    """
    with st.form("Registration"):
        st.subheader("Register")
        st.text_input("Email", key="reg_email")
        st.text_input("Display Name", key="reg_display_name")
        st.text_input("Password (min 6 characters)", type="password", key="reg_password")
        st.form_submit_button("Register", on_click=register_callback)
# Login form
def login_form():
    """Render the login form; login_callback handles submission.

    The identifier may be an email address or a display name.
    """
    with st.form("Login"):
        st.subheader("Login")
        st.text_input("Email or Username", key="login_identifier")
        st.text_input("Password", type="password", key="login_password")
        st.form_submit_button("Login", on_click=login_callback)
def _fit_within_preview(image, preview_size=400):
    """Return a resized copy of `image` scaled to fit a preview_size x preview_size box."""
    original_width, original_height = image.size
    scaling_factor = min(preview_size / original_width, preview_size / original_height)
    new_size = (int(original_width * scaling_factor), int(original_height * scaling_factor))
    return image.resize(new_size, Image.LANCZOS)


def _read_and_remove(path):
    """Return the bytes of the file at `path`, then delete the file (best effort)."""
    try:
        with open(path, "rb") as fh:
            return fh.read()
    finally:
        try:
            os.remove(path)
        except OSError:
            pass


def _offer_download(image_url, key_prefix):
    """Show a "Download Image" button for image_url; show nothing if the fetch fails."""
    download_path = download_image(image_url)
    if download_path:
        st.download_button(
            label="Download Image",
            data=_read_and_remove(download_path),  # bytes, so no file handle stays open
            file_name="image.png",
            key=f"{key_prefix}_{uuid.uuid4()}",
        )


def _generate_and_store(prompt, aspect_ratio, realism):
    """Generate an image, persist it (full image, thumbnail, DB record) and display it.

    Each step reports its own failure with st.error and skips the remaining steps.
    """
    with st.spinner("Generating Image..."):
        image_result = generate_image(prompt, aspect_ratio, realism)
        if not (isinstance(image_result, tuple) and len(image_result) == 2):
            st.error(f"Image generation failed: {image_result}")
            return
        image, image_url = image_result
        if not isinstance(image, Image.Image):
            st.error(f"Image generation failed: {image}")
            return
        # Build the preview first: create_thumbnail may resize its argument in place.
        resized_image = _fit_within_preview(image)
        user_id = st.session_state.current_user
        cloud_storage_url = upload_image_to_storage(image, user_id, is_thumbnail=False)
        if not cloud_storage_url:
            st.error("Failed to upload image to cloud storage.")
            return
        thumbnail = create_thumbnail(image)
        if thumbnail is None:
            st.error("Failed to create thumbnail")
            return
        thumbnail_url = upload_image_to_storage(thumbnail, user_id, is_thumbnail=True)
        if not thumbnail_url:
            st.error("Failed to upload thumbnail to cloud storage.")
            return
        # store_image_data_in_db reports its own success/failure message.
        store_image_data_in_db(user_id, prompt, aspect_ratio, realism, cloud_storage_url, thumbnail_url)
        with st.container(border=True):
            st.image(resized_image, use_column_width=False)  # Display the resized image
            st.write(f"**Prompt:** {prompt}")
            st.write(f"**Aspect Ratio:** {aspect_ratio}")
            st.write(f"**Realism:** {realism}")
            _offer_download(image_url, "download_high_res")


def _render_gallery():
    """Show the user's saved images as a paged row of thumbnails with arrow buttons."""
    st.header("Your Generated Images")
    window_size = max(1, st.session_state.window_size)
    col_left, _col_center, col_right = st.columns([1, 8, 1])
    with col_left:
        if st.button("◀️"):
            st.session_state.current_window_start = max(0, st.session_state.current_window_start - window_size)
    with col_right:
        if st.button("▶️"):
            st.session_state.current_window_start += window_size
    all_images = load_image_data(st.session_state.current_user, 0, 1000)  # newest first
    if not all_images:
        st.write("No image generated yet!")
        return
    # Clamp to the last page so paging past the end never yields an empty
    # window (st.columns(0) would raise).
    last_page_start = (len(all_images) - 1) // window_size * window_size
    start_index = min(st.session_state.current_window_start, last_page_start)
    st.session_state.current_window_start = start_index
    images_for_window = all_images[start_index:start_index + window_size]
    cols = st.columns(len(images_for_window))
    for i, (col, image_data) in enumerate(zip(cols, images_for_window)):
        with col:
            thumbnail_url = image_data.get('thumbnail_url')
            full_url = image_data.get('image_url')
            if thumbnail_url and full_url:
                if st.button("More", key=f"more_{i}"):
                    st.session_state.selected_image = image_data
                st.image(thumbnail_url, width=150)  # display thumbnail
            elif full_url or thumbnail_url:
                st.image(full_url or thumbnail_url, width=150)
            st.write(f"**Prompt:** {image_data.get('prompt')}")
            st.write(f"**Aspect Ratio:** {image_data.get('aspect_ratio')}")
            st.write(f"**Realism:** {image_data.get('realism')}")
            st.markdown("---")


def _close_selected_image():
    """Button callback: clear the selection before the next script run renders."""
    st.session_state.selected_image = None


def _render_selected_image():
    """Show the full-size view of the image chosen with "More", if any."""
    selected = st.session_state.selected_image
    if not selected:
        return
    with st.container(border=True):
        st.image(selected['image_url'], use_column_width=True)
        st.write(f"**Prompt:** {selected['prompt']}")
        st.write(f"**Aspect Ratio:** {selected['aspect_ratio']}")
        st.write(f"**Realism:** {selected['realism']}")
        _offer_download(selected['image_url'], "download_overlay")
        # on_click runs before the rerun, so the view closes on the first click.
        st.button("Close", on_click=_close_selected_image)


def main_app():
    """Render the signed-in page: generator form, gallery, detail view, logout button."""
    st.subheader(f"Welcome, {st.session_state.display_name}!")
    st.write("Enter a prompt below to generate an image.")
    # Input fields
    prompt = st.text_input("Prompt", key="image_prompt", placeholder="Describe the image you want to generate")
    aspect_ratio = st.radio(
        "Aspect Ratio",
        options=["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
        index=5
    )
    realism = st.checkbox("Realism", value=False)
    if st.button("Generate Image"):
        if prompt:
            _generate_and_store(prompt, aspect_ratio, realism)
        else:
            st.warning("Please enter a prompt to generate an image.")
    _render_gallery()
    _render_selected_image()
    # Logout button
    st.button("Logout", on_click=logout_callback)
# Route the page: visitors see the auth forms, signed-in users see the app.
if not st.session_state.logged_in:
    registration_form()
    login_form()
else:
    main_app()