Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import firebase_admin | |
| from firebase_admin import credentials, auth, db, storage | |
| import os | |
| import json | |
| import requests | |
| from io import BytesIO | |
| from PIL import Image | |
| import tempfile | |
| import mimetypes | |
| import uuid | |
| import io | |
# Load Firebase credentials from Hugging Face Secrets
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
FIREBASE_API_KEY = os.getenv("FIREBASE_API_KEY")
FIREBASE_STORAGE_BUCKET = os.getenv("FIREBASE_STORAGE_BUCKET")
# Realtime Database URL; overridable through the environment, defaults to the original project URL.
FIREBASE_DATABASE_URL = os.getenv(
    "FIREBASE_DATABASE_URL", "https://creative-623ef-default-rtdb.firebaseio.com/"
)

if not firebase_creds:
    st.error("Firebase credentials not found. Please check your secrets.")
    # Nothing below can work without credentials; halt this run instead of
    # crashing inside credentials.Certificate(None) with a raw traceback.
    st.stop()
try:
    firebase_creds = json.loads(firebase_creds)
except json.JSONDecodeError as e:
    st.error(f"Firebase credentials are not valid JSON: {e}")
    st.stop()

# Initialize Firebase (only once): get_app() raises ValueError while the
# default app has not been initialized yet in this process.
try:
    firebase_admin.get_app()
except ValueError:
    firebase_admin.initialize_app(
        credentials.Certificate(firebase_creds),
        {
            'databaseURL': FIREBASE_DATABASE_URL,
            'storageBucket': FIREBASE_STORAGE_BUCKET,
        },
    )
# Initialize session state: each key is seeded with its default only when it is
# missing, so values already stored in st.session_state are left as they are.
_SESSION_DEFAULTS = (
    ("logged_in", False),
    ("current_user", None),
    ("display_name", None),
    ("window_size", 5),
    ("current_window_start", 0),
    ("selected_image", None),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Image-generation endpoint configuration.
API_URL = os.getenv("API_URL")
model_id = os.getenv("MODEL_ID")
# Bearer-token rotation state: generate_image() switches between the env vars
# TOKEN0 .. TOKEN{no_of_accounts - 1} when the API reports a 429 error.
no_of_accounts = 7
token_id = tokens_tried = 0
TOKEN = os.getenv(f"TOKEN{token_id}")
def send_verification_email(id_token):
    """Ask Firebase Auth (sendOobCode, VERIFY_EMAIL) to e-mail a verification link.

    Args:
        id_token: ID token of the signed-in user, as returned by signInWithPassword.

    Returns:
        {'status': 'success', 'email': <address or None>} when the request is accepted,
        otherwise {'status': 'error', 'message': <reason>}. Network failures and
        non-JSON responses are reported as errors instead of raising.
    """
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:sendOobCode?key={FIREBASE_API_KEY}'
    headers = {'Content-Type': 'application/json'}
    data = {
        'requestType': 'VERIFY_EMAIL',
        'idToken': id_token,
    }
    try:
        # A timeout keeps a stalled auth endpoint from hanging the Streamlit run forever.
        response = requests.post(url, headers=headers, json=data, timeout=30)
        result = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # requests' JSON decode error subclasses ValueError.
        return {'status': 'error', 'message': str(e)}
    if 'error' in result:
        return {'status': 'error', 'message': result['error'].get('message', 'Unknown error')}
    return {'status': 'success', 'email': result.get('email')}
# Callback for registration
def register_callback():
    """Create a Firebase account from the registration form and send a verification e-mail.

    Reads reg_email, reg_password and reg_display_name from st.session_state.
    After the account is created, the user is signed in through the Identity
    Toolkit REST API and the resulting ID token is stored in
    st.session_state.id_token. Every outcome is reported in the UI; nothing is returned.
    """
    email = st.session_state.reg_email
    password = st.session_state.reg_password
    display_name = st.session_state.reg_display_name
    if not (email and password and display_name):
        st.error("Registration failed: email, display name and password are all required.")
        return
    try:
        # Create the user and set the display name in one call, so a rejected
        # display name cannot leave behind a half-configured account.
        auth.create_user(email=email, password=password, display_name=display_name)
        st.success("Registration successful! Sending verification email...")
        # sendOobCode needs an ID token, so sign the new user in programmatically.
        url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'
        data = {
            'email': email,
            'password': password,
            'returnSecureToken': True,
        }
        result = requests.post(url, json=data, timeout=30).json()
        if 'idToken' not in result:
            message = result.get('error', {}).get('message', 'no idToken in response')
            st.error(f"Failed to retrieve id_token: {message}")
            return
        id_token = result['idToken']
        st.session_state.id_token = id_token
        verification_result = send_verification_email(id_token)
        if verification_result['status'] == 'success':
            st.success(f"Verification email sent to {email}.")
        else:
            st.error(f"Failed to send verification email: {verification_result['message']}")
    except Exception as e:
        st.error(f"Registration failed: {e}")
# Callback for login
def _sign_in_with_password(email, password):
    """POST credentials to the Identity Toolkit signInWithPassword endpoint and return the decoded JSON."""
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'
    data = {
        'email': email,
        'password': password,
        'returnSecureToken': True,
    }
    return requests.post(url, json=data, timeout=30).json()


def _mark_logged_in(user):
    """Store the authenticated Firebase user in the session and confirm in the UI."""
    st.session_state.logged_in = True
    st.session_state.current_user = user.uid
    st.session_state.display_name = user.display_name
    st.success("Logged in successfully!")


def login_callback():
    """Log a user in with either their e-mail address or their display name.

    Reads login_identifier and login_password from st.session_state. First the
    identifier is tried as an e-mail. If Firebase rejects that, every user whose
    display name equals the identifier (and who has an e-mail) is tried with the
    same password. Failures are shown as "Login failed: ..." messages.
    """
    login_identifier = st.session_state.login_identifier
    password = st.session_state.login_password
    try:
        result = _sign_in_with_password(login_identifier, password)
        if 'idToken' in result:
            _mark_logged_in(auth.get_user_by_email(login_identifier))
            return
        if 'error' not in result:
            st.error("Login failed: Error with sign-in endpoint")
            return
        # Fall back to a display-name lookup. iterate_all() walks every page;
        # list_users() alone only returns the first page (at most 1000 users).
        for user in auth.list_users().iterate_all():
            if user.display_name != login_identifier or not user.email:
                continue
            if 'idToken' in _sign_in_with_password(user.email, password):
                _mark_logged_in(user)
                return
        st.error("Login failed: User not found with provided credentials.")
    except Exception as e:
        st.error(f"Login failed: {e}")
# Callback for logout
def logout_callback():
    """Forget the signed-in user and any selected image, then confirm in the UI."""
    session = st.session_state
    session.logged_in = False
    for field in ("current_user", "display_name", "selected_image"):
        setattr(session, field, None)
    st.info("Logged out successfully!")
# Function to get image from url
def get_image_from_url(url):
    """
    Fetch and decode the image at `url`.

    Returns:
        (PIL.Image.Image, url) on success; on failure (error_message, None),
        where the message starts with "Error fetching image" for HTTP/network
        problems and "Error processing image" for undecodable data.
    """
    try:
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt or truncated data is reported
        # here rather than surfacing later in an unrelated call.
        image.load()
        return image, url
    except requests.exceptions.RequestException as e:
        return f"Error fetching image: {e}", None
    except Exception as e:
        return f"Error processing image: {e}", None
# Function to generate image
def generate_image(prompt, aspect_ratio, realism):
    """Request an image from API_URL, rotating through the TOKEN<n> env vars on rate limits.

    Args:
        prompt: Text description sent to the model.
        aspect_ratio: Aspect-ratio string such as "16:9".
        realism: Flag sent as the lowercase string form of its value ("true"/"false" for bools).

    Returns:
        (PIL image, image URL) on success. Otherwise (error, None), where `error` is a
        message string or, for non-429 API errors, the raw response payload.

    Side effects:
        Updates the module globals TOKEN and token_id when a 429 forces a switch,
        so later calls start from the token that last worked.
    """
    global token_id, TOKEN, tokens_tried
    payload = {
        "id": model_id,
        "inputs": [prompt, aspect_ratio, str(realism).lower()],
    }
    try:
        while True:
            headers = {"Authorization": f"Bearer {TOKEN}"}
            response_data = requests.post(API_URL, json=payload, headers=headers).json()
            if "error" not in response_data:
                break
            if 'error 429' not in response_data['error']:
                return response_data, None
            # Rate limited: switch to the next account's token, giving up after
            # no_of_accounts switches within this call.
            if tokens_tried >= no_of_accounts:
                return "No credits available", None
            token_id = (token_id + 1) % no_of_accounts
            tokens_tried += 1
            TOKEN = os.getenv(f"TOKEN{token_id}")
        if "output" in response_data:
            return get_image_from_url(response_data['output'])
        return "Error: Unexpected response from server", None
    except Exception as e:
        # Keep the cause so the UI's "Image generation failed: ..." message is actionable.
        return f"Error: {e}", None
    finally:
        # Every call starts with a fresh rotation budget.
        tokens_tried = 0
def download_image(image_url):
    """Download `image_url` into a named temporary file and return the file's path.

    The file extension is guessed from the Content-Type header (ignoring any
    parameters such as "; charset=..."), falling back to ".png". The temporary
    file is created with delete=False, so the caller owns it and should remove it.

    Returns:
        The temp-file path, or None when `image_url` is empty or the download fails.
        On failure no partial file is left behind.
    """
    if not image_url:
        return None
    temp_file_path = None
    try:
        with requests.get(image_url, stream=True, timeout=60) as response:
            response.raise_for_status()
            # guess_extension(None) raises, and "image/png; charset=x" is unknown,
            # so strip parameters and tolerate a missing header.
            content_type = (response.headers.get('content-type') or '').split(';')[0].strip()
            extension = mimetypes.guess_extension(content_type) if content_type else None
            if not extension:
                extension = ".png"
            with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as tmp_file:
                temp_file_path = tmp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)
        return temp_file_path
    except Exception:
        # Best effort: the caller only needs None on failure, but clean up any partial file.
        if temp_file_path:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        return None
# Function to store image and related data in Firebase
def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url, thumbnail_url):
    """Append one image record under users/<user_id>/images in the Realtime Database.

    The record's timestamp is the server-value placeholder {'.sv': 'timestamp'}.
    Success or failure is reported through Streamlit; nothing is returned.
    """
    record = {
        'prompt': prompt,
        'aspect_ratio': aspect_ratio,
        'realism': realism,
        'image_url': image_url,
        'thumbnail_url': thumbnail_url,
        'timestamp': {'.sv': 'timestamp'},
    }
    try:
        images_ref = db.reference(f'users/{user_id}/images')
        images_ref.push().set(record)
        st.success("Image data saved successfully!")
    except Exception as exc:
        st.error(f"Failed to save image data: {exc}")
#Function to upload image to cloud storage
def upload_image_to_storage(image, user_id, is_thumbnail=False):
    """Upload a PIL image as a public PNG and return its public URL.

    Images go to user_images/<user_id>/<uuid>.png; thumbnails go to
    user_images/<user_id>/thumbnails/<uuid>.png. On any failure an error is
    shown in the UI and None is returned.
    """
    try:
        bucket = storage.bucket()
        folder = f"user_images/{user_id}/thumbnails" if is_thumbnail else f"user_images/{user_id}"
        blob = bucket.blob(f"{folder}/{uuid.uuid4()}.png")
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        blob.upload_from_string(buffer.getvalue(), content_type='image/png')
        blob.make_public()
        return blob.public_url
    except Exception as e:
        st.error(f"Failed to upload image to cloud storage: {e}")
        return None
#Function to load image data from the database
def load_image_data(user_id, start_index, batch_size):
    """Return a page of the user's image records, latest first.

    Queries the last (start_index + batch_size) records ordered by the
    'timestamp' child, reverses them so the latest comes first, and drops the
    first start_index. Returns [] when there is nothing stored or on error
    (the error is shown in the UI).
    """
    try:
        query = (
            db.reference(f'users/{user_id}/images')
            .order_by_child('timestamp')
            .limit_to_last(start_index + batch_size)
        )
        snapshot = query.get()
        if not snapshot:
            return []
        latest_first = list(snapshot.values())[::-1]
        return latest_first[start_index:]
    except Exception as e:
        st.error(f"Failed to fetch image data from database: {e}")
        return []
# Function to create low resolution thumbnail
def create_thumbnail(image, thumbnail_size=(150, 150)):
    """Return a copy of `image` shrunk to fit within `thumbnail_size`, keeping its aspect ratio.

    PIL's Image.thumbnail() resizes in place, so it is applied to a copy and the
    caller's (high-resolution) image is left untouched.

    Returns:
        The thumbnail image, or None if it could not be created (the error is
        shown in the UI).
    """
    try:
        thumbnail = image.copy()
        thumbnail.thumbnail(thumbnail_size)
        return thumbnail
    except Exception as e:
        st.error(f"Failed to create thumbnail: {e}")
        return None
# Registration form
def registration_form():
    """Render the sign-up form; submitting it runs register_callback.

    The callback reads the field values from st.session_state under the keys
    reg_email, reg_display_name and reg_password.
    """
    with st.form("Registration"):
        st.subheader("Register")
        st.text_input("Email", key="reg_email")
        st.text_input("Display Name", key="reg_display_name")
        st.text_input("Password (min 6 characters)", type="password", key="reg_password")
        st.form_submit_button("Register", on_click=register_callback)
# Login form
def login_form():
    """Render the sign-in form; submitting it runs login_callback.

    The callback reads the field values from st.session_state under the keys
    login_identifier and login_password.
    """
    with st.form("Login"):
        st.subheader("Login")
        st.text_input("Email or Username", key="login_identifier")
        st.text_input("Password", type="password", key="login_password")
        st.form_submit_button("Login", on_click=login_callback)
def _fit_within(image, boundary):
    """Return a LANCZOS-resized copy of `image` scaled to fit a boundary x boundary box."""
    width, height = image.size
    scale = min(boundary / width, boundary / height)
    # max(1, ...) keeps extreme aspect ratios from producing a zero-sized image.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return image.resize(new_size, Image.LANCZOS)


def _offer_download(image_url, key):
    """Render a "Download Image" button for `image_url`, deleting the temporary file it used."""
    download_path = download_image(image_url)
    if not download_path:
        return
    try:
        # Read the bytes and close the handle right away; passing an open file
        # object to download_button leaked one handle per rerun.
        with open(download_path, "rb") as fh:
            data = fh.read()
    finally:
        try:
            os.remove(download_path)
        except OSError:
            pass
    st.download_button(label="Download Image", data=data, file_name="image.png", key=key)


def _generate_store_and_preview(prompt, aspect_ratio, realism):
    """Generate an image, upload it and its thumbnail, record it in the DB, then preview it."""
    image_result = generate_image(prompt, aspect_ratio, realism)
    if not (isinstance(image_result, tuple) and len(image_result) == 2):
        st.error(f"Image generation failed: {image_result}")
        return
    image, image_url = image_result
    if not isinstance(image, Image.Image):
        st.error(f"Image generation failed: {image}")
        return
    # Build the preview (max 400x400) before create_thumbnail() runs, in case
    # that helper shrinks `image` in place.
    resized_image = _fit_within(image, 400)
    user_id = st.session_state.current_user
    cloud_storage_url = upload_image_to_storage(image, user_id, is_thumbnail=False)
    if not cloud_storage_url:
        st.error("Failed to upload image to cloud storage.")
        return
    thumbnail = create_thumbnail(image)
    if thumbnail is None:
        st.error("Failed to create thumbnail")
        return
    thumbnail_url = upload_image_to_storage(thumbnail, user_id, is_thumbnail=True)
    if not thumbnail_url:
        st.error("Failed to upload thumbnail to cloud storage.")
        return
    store_image_data_in_db(user_id, prompt, aspect_ratio, realism, cloud_storage_url, thumbnail_url)
    st.success("Image stored to database successfully!")
    with st.container(border=True):
        st.image(resized_image, use_column_width=False)  # Display the resized image
        st.write(f"**Prompt:** {prompt}")
        st.write(f"**Aspect Ratio:** {aspect_ratio}")
        st.write(f"**Realism:** {realism}")
        _offer_download(image_url, f"download_high_res_{uuid.uuid4()}")


def _render_gallery():
    """Render the paginated strip of the current user's stored images."""
    st.header("Your Generated Images")
    window = st.session_state.window_size
    col_left, _col_center, col_right = st.columns([1, 8, 1])
    with col_left:
        if st.button("◀️"):
            st.session_state.current_window_start = max(0, st.session_state.current_window_start - window)
    with col_right:
        if st.button("▶️"):
            st.session_state.current_window_start += window

    all_images = load_image_data(st.session_state.current_user, 0, 1000)  # load all images
    if not all_images:
        st.write("No image generated yet!")
        return

    # Clamp to the last page: paging past the end used to produce an empty
    # window, and st.columns(0) raises.
    last_start = (len(all_images) - 1) // window * window
    start_index = min(st.session_state.current_window_start, last_start)
    st.session_state.current_window_start = start_index
    images_for_window = all_images[start_index:start_index + window]

    cols = st.columns(len(images_for_window))
    for i, image_data in enumerate(images_for_window):
        with cols[i]:
            if image_data.get('thumbnail_url') and image_data.get('image_url'):
                if st.button("More", key=f"more_{i}"):
                    st.session_state.selected_image = image_data
                st.image(image_data['thumbnail_url'], width=150)  # display thumbnail
            else:
                st.image(image_data['image_url'], width=150)
            st.write(f"**Prompt:** {image_data['prompt']}")
            st.write(f"**Aspect Ratio:** {image_data['aspect_ratio']}")
            st.write(f"**Realism:** {image_data['realism']}")
            st.markdown("---")


def _close_selected_image():
    """Button callback: hide the enlarged-image panel."""
    st.session_state.selected_image = None


def _render_selected_image():
    """Render the enlarged view of the image chosen with a "More" button, if any."""
    selected = st.session_state.selected_image
    if not selected:
        return
    with st.container(border=True):
        st.image(selected['image_url'], use_column_width=True)
        st.write(f"**Prompt:** {selected['prompt']}")
        st.write(f"**Aspect Ratio:** {selected['aspect_ratio']}")
        st.write(f"**Realism:** {selected['realism']}")
        _offer_download(selected['image_url'], f"download_overlay_{uuid.uuid4()}")
        # on_click runs before the rerun, so the panel closes on the same click
        # (checking the button's return value only took effect one rerun later).
        st.button("Close", on_click=_close_selected_image)


def main_app():
    """Render the signed-in page: generator inputs, gallery, selected-image panel and logout.

    Relies on the session-state keys seeded at module level (window_size,
    current_window_start, selected_image, current_user, display_name).
    """
    st.subheader(f"Welcome, {st.session_state.display_name}!")
    st.write("Enter a prompt below to generate an image.")
    # Input fields
    prompt = st.text_input("Prompt", key="image_prompt", placeholder="Describe the image you want to generate")
    aspect_ratio = st.radio(
        "Aspect Ratio",
        options=["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
        index=5,
    )
    realism = st.checkbox("Realism", value=False)
    if st.button("Generate Image"):
        if prompt:
            with st.spinner("Generating Image..."):
                _generate_store_and_preview(prompt, aspect_ratio, realism)
        else:
            st.warning("Please enter a prompt to generate an image.")

    _render_gallery()
    _render_selected_image()

    # Logout button
    st.button("Logout", on_click=logout_callback)
# Entry point: visitors who are not signed in get the sign-up and login forms;
# signed-in users get the main application.
if not st.session_state.logged_in:
    registration_form()
    login_form()
else:
    main_app()