import io
import json
import mimetypes
import os
import tempfile
import uuid
from io import BytesIO

import firebase_admin
import requests
import streamlit as st
from firebase_admin import auth, credentials, db, storage
from PIL import Image

# Load Firebase credentials from Hugging Face Secrets
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
FIREBASE_API_KEY = os.getenv("FIREBASE_API_KEY")
FIREBASE_STORAGE_BUCKET = os.getenv("FIREBASE_STORAGE_BUCKET")

if firebase_creds:
    firebase_creds = json.loads(firebase_creds)
else:
    st.error("Firebase credentials not found. Please check your secrets.")
    # Without credentials the Firebase initialisation below cannot succeed,
    # so end this script run here instead of crashing with a traceback.
    st.stop()

# Initialize Firebase (only once): Streamlit re-executes this script on every
# interaction, and initialising the default app twice raises an error.
if not firebase_admin._apps:
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://creative-623ef-default-rtdb.firebaseio.com/',
        'storageBucket': FIREBASE_STORAGE_BUCKET,
    })

# Initialize session state: each key is only set when it is not present yet,
# so values survive across reruns.
_SESSION_DEFAULTS = {
    "logged_in": False,
    "current_user": None,
    "display_name": None,
    "window_size": 5,
    "current_window_start": 0,
    "selected_image": None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Image-generation API settings. generate_image() rotates through the
# TOKEN0..TOKEN{no_of_accounts - 1} environment variables via these globals.
TOKEN = os.getenv("TOKEN0")
API_URL = os.getenv("API_URL")
token_id = 0
tokens_tried = 0
no_of_accounts = 7
model_id = os.getenv("MODEL_ID")


def send_verification_email(id_token, timeout=30):
    """Ask the Identity Toolkit REST API to send a verification email.

    Args:
        id_token: ID token of the signed-in user who should receive the email.
        timeout: Seconds to wait for the HTTP request before `requests`
            raises a timeout error.

    Returns:
        ``{'status': 'success', 'email': ...}`` when the response carries no
        ``error`` key, otherwise ``{'status': 'error', 'message': ...}``.

    Raises:
        requests.exceptions.RequestException: If the request itself fails.
        ValueError: If the response body is not valid JSON.
    """
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:sendOobCode?key={FIREBASE_API_KEY}'
    headers = {'Content-Type': 'application/json'}
    data = {
        'requestType': 'VERIFY_EMAIL',
        'idToken': id_token,
    }

    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    result = response.json()

    if 'error' in result:
        return {'status': 'error', 'message': result['error']['message']}
    # .get() so a success payload without an 'email' field does not raise.
    return {'status': 'success', 'email': result.get('email')}
# Seconds to wait on any outbound HTTP call before giving up, so a stalled
# endpoint cannot block a Streamlit script run indefinitely.
_HTTP_TIMEOUT = 30

# Side length, in pixels, of the square a freshly generated image is scaled
# to fit for its on-page preview.
_PREVIEW_SIZE = 400


def _sign_in_with_password(email, password):
    """POST to the Identity Toolkit password sign-in endpoint.

    Returns the decoded JSON response; callers look for an ``idToken`` key
    (success) or an ``error`` key (failure).
    """
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'
    data = {
        'email': email,
        'password': password,
        'returnSecureToken': True,
    }
    return requests.post(url, json=data, timeout=_HTTP_TIMEOUT).json()


# Callback for registration
def register_callback():
    """Create the account from the registration form and send a verification email.

    Reads ``reg_email``, ``reg_password`` and ``reg_display_name`` from the
    session state and reports every outcome through Streamlit messages.
    """
    email = st.session_state.reg_email
    password = st.session_state.reg_password
    display_name = st.session_state.reg_display_name

    try:
        # Create the user together with the display name in one call, so a
        # rejected display name cannot leave an account without one behind.
        auth.create_user(email=email, password=password, display_name=display_name)
        st.success("Registration successful! Sending verification email...")

        # Sign in programmatically: the verification request needs an id_token.
        result = _sign_in_with_password(email, password)
        if 'idToken' not in result:
            st.error(f"Failed to retrieve id_token: {result['error']['message']}")
            return

        id_token = result['idToken']
        st.session_state.id_token = id_token

        verification_result = send_verification_email(id_token)
        if verification_result['status'] == 'success':
            st.success(f"Verification email sent to {email}.")
        else:
            st.error(f"Failed to send verification email: {verification_result['message']}")
    except Exception as e:
        st.error(f"Registration failed: {e}")


def _find_user_by_display_name(display_name, password):
    """Return the first user with this display name whose password sign-in works.

    Walks every page of the user listing (not only the first one) and returns
    ``None`` when no matching user can sign in with ``password``.
    """
    for candidate in auth.list_users().iterate_all():
        if candidate.display_name != display_name:
            continue
        result = _sign_in_with_password(candidate.email, password)
        if 'idToken' in result:
            return candidate
    return None


# Callback for login
def login_callback():
    """Log the user in by email, falling back to a display-name lookup.

    Reads ``login_identifier`` and ``login_password`` from the session state.
    On success stores the uid and display name in the session state; on
    failure shows a "Login failed" message.
    """
    login_identifier = st.session_state.login_identifier
    password = st.session_state.login_password

    try:
        # First treat the identifier as an email address.
        result = _sign_in_with_password(login_identifier, password)
        if 'idToken' in result:
            user = auth.get_user_by_email(login_identifier)
        elif 'error' in result:
            # Sign-in failed: the identifier may be a display name instead.
            user = _find_user_by_display_name(login_identifier, password)
            if user is None:
                raise ValueError("User not found with provided credentials.")
        else:
            # Neither an id token nor an error came back from the endpoint.
            raise RuntimeError("Error with sign-in endpoint")
    except Exception as e:
        st.error(f"Login failed: {e}")
        return

    st.session_state.logged_in = True
    st.session_state.current_user = user.uid
    st.session_state.display_name = user.display_name  # Store the display name
    st.success("Logged in successfully!")


# Callback for logout
def logout_callback():
    """Clear the login-related session state and confirm the logout."""
    st.session_state.logged_in = False
    st.session_state.current_user = None
    st.session_state.display_name = None
    st.session_state.selected_image = None
    st.info("Logged out successfully!")


# Function to get image from url
def get_image_from_url(url):
    """Fetch ``url`` and open the response body as a PIL image.

    Returns:
        ``(image, url)`` on success, or ``(error_message, None)`` when the
        download or the image decoding fails.
    """
    try:
        response = requests.get(url, timeout=_HTTP_TIMEOUT)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image, url  # Return the image and the URL
    except requests.exceptions.RequestException as e:
        return f"Error fetching image: {e}", None
    except Exception as e:
        return f"Error processing image: {e}", None


# Function to generate image
def generate_image(prompt, aspect_ratio, realism):
    """Request an image from the generation API, rotating tokens on "error 429".

    Args:
        prompt: Text description sent to the API.
        aspect_ratio: Aspect-ratio string sent to the API.
        realism: Flag sent to the API as a lower-cased string.

    Returns:
        ``(image, url)`` on success, otherwise ``(error_message, None)``.
    """
    global token_id, TOKEN, tokens_tried

    payload = {
        "id": model_id,
        "inputs": [prompt, aspect_ratio, str(realism).lower()],
    }

    try:
        while True:
            headers = {"Authorization": f"Bearer {TOKEN}"}
            response_data = requests.post(
                API_URL, json=payload, headers=headers, timeout=_HTTP_TIMEOUT
            ).json()

            if "error" not in response_data:
                break

            error = response_data["error"]
            if 'error 429' not in str(error):
                # Any other API error: hand the message back to the caller.
                return str(error), None
            if tokens_tried >= no_of_accounts:
                return "No credits available", None

            # This token is rate limited: switch to the next one and retry.
            token_id = (token_id + 1) % no_of_accounts
            tokens_tried += 1
            TOKEN = os.getenv(f"TOKEN{token_id}")

        if "output" in response_data:
            return get_image_from_url(response_data['output'])
        return "Error: Unexpected response from server", None
    except Exception as e:
        return f"Error: {e}", None
    finally:
        # Start the rotation count afresh for the next call.
        tokens_tried = 0


def download_image(image_url):
    """Download ``image_url`` into a temporary file and return its path.

    The file is created with ``delete=False``, so the caller is responsible
    for removing it. Returns ``None`` for an empty URL or when the download
    fails.
    """
    if not image_url:
        return None  # Return None if image_url is empty

    try:
        with requests.get(image_url, stream=True, timeout=_HTTP_TIMEOUT) as response:
            response.raise_for_status()

            # The header may be absent or carry parameters such as
            # "; charset=..."; only the bare MIME type maps to an extension.
            content_type = response.headers.get('content-type') or ''
            mime_type = content_type.split(';', 1)[0].strip()
            # Default to .png if the extension can't be determined.
            extension = mimetypes.guess_extension(mime_type) or ".png"

            # Create a temporary file with the correct extension
            with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as tmp_file:
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)
            return tmp_file.name
    except Exception:
        # Best effort: callers treat None as "no download available".
        return None


# Function to store image and related data in Firebase
def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url, thumbnail_url):
    """Push one image record under ``users/{user_id}/images``.

    Returns:
        ``True`` when the record was written, ``False`` when writing failed
        (an error message is shown in that case).
    """
    try:
        ref = db.reference(f'users/{user_id}/images')
        new_image_ref = ref.push()
        new_image_ref.set({
            'prompt': prompt,
            'aspect_ratio': aspect_ratio,
            'realism': realism,
            'image_url': image_url,
            'thumbnail_url': thumbnail_url,
            # Server-value placeholder: the database fills in the timestamp.
            'timestamp': {'.sv': 'timestamp'},
        })
    except Exception as e:
        st.error(f"Failed to save image data: {e}")
        return False

    st.success("Image data saved successfully!")
    return True


# Function to upload image to cloud storage
def upload_image_to_storage(image, user_id, is_thumbnail=False):
    """Upload a PIL image as a public PNG blob and return its public URL.

    Args:
        image: PIL image to upload.
        user_id: Owner of the image; used as a path component.
        is_thumbnail: Store the blob under the ``thumbnails`` sub-folder.

    Returns:
        The blob's public URL, or ``None`` when the upload fails (an error
        message is shown in that case).
    """
    try:
        bucket = storage.bucket()
        image_id = str(uuid.uuid4())
        if is_thumbnail:
            folder = f"user_images/{user_id}/thumbnails"  # path for thumbnail
        else:
            folder = f"user_images/{user_id}"  # path for high resolution images
        blob = bucket.blob(f"{folder}/{image_id}.png")

        # Serialise the PIL image to PNG bytes in memory.
        buffer = BytesIO()
        image.save(buffer, format='PNG')

        blob.upload_from_string(buffer.getvalue(), content_type='image/png')
        blob.make_public()
        return blob.public_url
    except Exception as e:
        st.error(f"Failed to upload image to cloud storage: {e}")
        return None


# Function to load image data from the database
def load_image_data(user_id, start_index, batch_size):
    """Return the user's image records, newest first, skipping ``start_index``.

    Queries the last ``start_index + batch_size`` records ordered by
    ``timestamp``. Returns an empty list when there are none or when the
    query fails (an error message is shown in that case).
    """
    try:
        ref = db.reference(f'users/{user_id}/images')
        snapshot = ref.order_by_child('timestamp').limit_to_last(start_index + batch_size).get()
        if not snapshot:
            return []
        # Reverse to show latest first, then drop the records before start_index.
        newest_first = list(snapshot.values())[::-1]
        return newest_first[start_index:]
    except Exception as e:
        st.error(f"Failed to fetch image data from database: {e}")
        return []


# Function to create low resolution thumbnail
def create_thumbnail(image, thumbnail_size=(150, 150)):
    """Return a down-scaled copy of ``image`` that fits in ``thumbnail_size``.

    Returns ``None`` when the thumbnail cannot be created (an error message
    is shown in that case).
    """
    try:
        # Image.thumbnail() resizes in place, so work on a copy and leave the
        # caller's full-resolution image untouched.
        thumbnail = image.copy()
        thumbnail.thumbnail(thumbnail_size)
        return thumbnail
    except Exception as e:
        st.error(f"Failed to create thumbnail: {e}")
        return None


# Registration form
def registration_form():
    """Render the registration form; submitting runs ``register_callback``."""
    with st.form("Registration"):
        st.subheader("Register")
        st.text_input("Email", key="reg_email")
        st.text_input("Display Name", key="reg_display_name")
        st.text_input("Password (min 6 characters)", type="password", key="reg_password")
        st.form_submit_button("Register", on_click=register_callback)


# Login form
def login_form():
    """Render the login form; submitting runs ``login_callback``."""
    with st.form("Login"):
        st.subheader("Login")
        st.text_input("Email or Username", key="login_identifier")
        st.text_input("Password", type="password", key="login_password")
        st.form_submit_button("Login", on_click=login_callback)


def _read_and_remove(path):
    """Return the bytes stored at ``path`` and delete the file afterwards."""
    try:
        with open(path, "rb") as handle:
            return handle.read()
    finally:
        try:
            os.remove(path)
        except OSError:
            pass  # A leftover temp file is harmless; never fail the page for it.


def _offer_download(image_url, key_prefix):
    """Show a download button for ``image_url`` when it can be downloaded."""
    download_path = download_image(image_url)
    if not download_path:
        return
    # Name the file after the extension that was actually detected.
    extension = os.path.splitext(download_path)[1] or ".png"
    st.download_button(
        label="Download Image",
        data=_read_and_remove(download_path),
        file_name=f"image{extension}",
        key=f"{key_prefix}_{uuid.uuid4()}",
    )


def _show_image_details(prompt, aspect_ratio, realism):
    """Write the prompt / aspect ratio / realism lines shown under an image."""
    st.write(f"**Prompt:** {prompt}")
    st.write(f"**Aspect Ratio:** {aspect_ratio}")
    st.write(f"**Realism:** {realism}")


def _generate_and_store(prompt, aspect_ratio, realism):
    """Generate an image, upload it with a thumbnail, record it and preview it."""
    image_result = generate_image(prompt, aspect_ratio, realism)
    if not (isinstance(image_result, tuple) and len(image_result) == 2):
        st.error(f"Image generation failed: {image_result}")
        return

    image, image_url = image_result
    if not isinstance(image, Image.Image):
        # On failure the first element carries the error message.
        st.error(f"Image generation failed: {image}")
        return

    # Scale the preview so it fits inside a _PREVIEW_SIZE square.
    original_width, original_height = image.size
    scaling_factor = min(_PREVIEW_SIZE / original_width, _PREVIEW_SIZE / original_height)
    preview_dimensions = (
        int(original_width * scaling_factor),
        int(original_height * scaling_factor),
    )
    resized_image = image.resize(preview_dimensions, Image.LANCZOS)

    user_id = st.session_state.current_user

    # Upload the high-resolution image
    cloud_storage_url = upload_image_to_storage(image, user_id, is_thumbnail=False)
    if not cloud_storage_url:
        st.error("Failed to upload image to cloud storage.")
        return

    # Create thumbnail from the high-resolution image
    thumbnail = create_thumbnail(image)
    if thumbnail is None:
        st.error("Failed to create thumbnail")
        return

    # Upload thumbnail to cloud storage and store url
    thumbnail_url = upload_image_to_storage(thumbnail, user_id, is_thumbnail=True)
    if not thumbnail_url:
        st.error("Failed to upload thumbnail to cloud storage.")
        return

    # Only announce success when the database write really went through.
    if store_image_data_in_db(user_id, prompt, aspect_ratio, realism,
                              cloud_storage_url, thumbnail_url):
        st.success("Image stored to database successfully!")

    with st.container(border=True):
        st.image(resized_image, use_column_width=False)  # Display the resized image
        _show_image_details(prompt, aspect_ratio, realism)
        _offer_download(image_url, "download_high_res")


def _render_gallery():
    """Render the paged strip of the current user's stored images."""
    window_size = st.session_state.window_size

    # Create left and right arrow buttons
    col_left, _, col_right = st.columns([1, 8, 1])
    with col_left:
        if st.button("◀️"):
            st.session_state.current_window_start = max(
                0, st.session_state.current_window_start - window_size
            )
    with col_right:
        if st.button("▶️"):
            st.session_state.current_window_start += window_size

    all_images = load_image_data(st.session_state.current_user, 0, 1000)
    if not all_images:
        st.write("No image generated yet!")
        return

    # Clamp the offset to the last page: paging past the end would leave
    # nothing to show and st.columns() rejects a column count of zero.
    last_page_start = ((len(all_images) - 1) // window_size) * window_size
    start_index = min(st.session_state.current_window_start, last_page_start)
    st.session_state.current_window_start = start_index

    images_for_window = all_images[start_index:start_index + window_size]

    # Setup columns for horizontal slider layout
    cols = st.columns(len(images_for_window))
    for i, (col, image_data) in enumerate(zip(cols, images_for_window)):
        with col:
            thumbnail_url = image_data.get('thumbnail_url')
            full_url = image_data.get('image_url')
            if thumbnail_url and full_url:
                if st.button("More", key=f"more_{i}"):
                    st.session_state.selected_image = image_data
                st.image(thumbnail_url, width=150)  # display thumbnail
            elif full_url:
                st.image(full_url, width=150)
            _show_image_details(
                image_data.get('prompt'),
                image_data.get('aspect_ratio'),
                image_data.get('realism'),
            )
            st.markdown("---")


def _close_selected_image():
    """Button callback: clear the selection so the detail view is not drawn."""
    st.session_state.selected_image = None


def _render_selected_image():
    """Render the detail view of the selected image, if there is one."""
    selected = st.session_state.selected_image
    if not selected:
        return

    with st.container(border=True):
        st.image(selected['image_url'], use_column_width=True)
        _show_image_details(
            selected['prompt'], selected['aspect_ratio'], selected['realism']
        )
        _offer_download(selected['image_url'], "download_overlay")
        # A callback runs before the rerun, so the view closes on the click
        # itself instead of lingering until the next interaction.
        st.button("Close", on_click=_close_selected_image)


def main_app():
    """Render the signed-in page: generator, image gallery, detail view, logout."""
    st.subheader(f"Welcome, {st.session_state.display_name}!")
    st.write("Enter a prompt below to generate an image.")

    # Input fields
    prompt = st.text_input("Prompt", key="image_prompt",
                           placeholder="Describe the image you want to generate")
    aspect_ratio = st.radio(
        "Aspect Ratio",
        options=["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
        index=5,
    )
    realism = st.checkbox("Realism", value=False)

    if st.button("Generate Image"):
        if prompt:
            with st.spinner("Generating Image..."):
                _generate_and_store(prompt, aspect_ratio, realism)
        else:
            st.warning("Please enter a prompt to generate an image.")

    st.header("Your Generated Images")

    # Make sure the paging state exists before the gallery reads it.
    if "current_window_start" not in st.session_state:
        st.session_state.current_window_start = 0
    if "window_size" not in st.session_state:
        st.session_state.window_size = 5  # The number of images to display at a time
    if "selected_image" not in st.session_state:
        st.session_state.selected_image = None

    _render_gallery()
    _render_selected_image()

    # Logout button
    st.button("Logout", on_click=logout_callback)


if st.session_state.logged_in:
    main_app()
else:
    registration_form()
    login_form()