Spaces:
Sleeping
Sleeping
import streamlit as st | |
import firebase_admin | |
from firebase_admin import credentials, auth, db, storage | |
import os | |
import json | |
import requests | |
from io import BytesIO | |
from PIL import Image | |
import tempfile | |
import mimetypes | |
import uuid | |
import io | |
# Load Firebase credentials and settings from Hugging Face Secrets
# (exposed to the app as environment variables).
firebase_creds = os.getenv("FIREBASE_CREDENTIALS")
FIREBASE_API_KEY = os.getenv("FIREBASE_API_KEY")
FIREBASE_STORAGE_BUCKET = os.getenv("FIREBASE_STORAGE_BUCKET")
if not firebase_creds:
    st.error("Firebase credentials not found. Please check your secrets.")
    # Nothing below can work without credentials (credentials.Certificate
    # rejects None), so end this script run here instead of crashing.
    st.stop()
try:
    firebase_creds = json.loads(firebase_creds)
except json.JSONDecodeError as e:
    st.error(f"Firebase credentials are not valid JSON: {e}")
    st.stop()
# Initialize Firebase only once per process: Streamlit re-executes this script
# on every interaction, and the default app must not be initialized twice.
try:
    firebase_admin.get_app()
except ValueError:
    # get_app() raises ValueError while the default app does not exist yet.
    cred = credentials.Certificate(firebase_creds)
    firebase_admin.initialize_app(cred, {
        'databaseURL': 'https://creative-623ef-default-rtdb.firebaseio.com/',
        'storageBucket': FIREBASE_STORAGE_BUCKET,
    })
# Seed per-session defaults. Streamlit re-runs this script on every
# interaction, so only keys that are still missing are filled in.
_SESSION_DEFAULTS = {
    "logged_in": False,
    "current_user": None,
    "display_name": None,
    "window_size": 5,
    "current_window_start": 0,
    "selected_image": None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Inference API configuration. A pool of API tokens is read from the
# environment variables TOKEN0 ... TOKEN{no_of_accounts - 1}; generate_image()
# moves token_id/TOKEN to the next one when a request is rate-limited (429).
no_of_accounts = 7
token_id = 0          # index of the token currently in use
tokens_tried = 0      # rotations performed during the current request
TOKEN = os.getenv("TOKEN0")
API_URL = os.getenv("API_URL")
model_id = os.getenv("MODEL_ID")
def send_verification_email(id_token):
    """Ask Firebase (Identity Toolkit ``accounts:sendOobCode``) to email a verification link.

    Args:
        id_token: ID token of the signed-in user whose address should be verified.

    Returns:
        ``{'status': 'success', 'email': <address or None>}`` on success, or
        ``{'status': 'error', 'message': <Firebase error message>}``.
        Network errors and non-JSON replies propagate to the caller.
    """
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:sendOobCode?key={FIREBASE_API_KEY}'
    headers = {'Content-Type': 'application/json'}
    data = {
        'requestType': 'VERIFY_EMAIL',
        'idToken': id_token,
    }
    # Without a timeout a stalled request would block this Streamlit session forever.
    response = requests.post(url, headers=headers, json=data, timeout=10)
    result = response.json()
    if 'error' in result:
        error = result['error']
        message = error.get('message', str(error)) if isinstance(error, dict) else error
        return {'status': 'error', 'message': message}
    return {'status': 'success', 'email': result.get('email')}
# Callback for registration
def register_callback():
    """Create a Firebase account from the registration form and send a verification email.

    Reads ``reg_email``, ``reg_password`` and ``reg_display_name`` from
    ``st.session_state`` (the widget keys used by ``registration_form``).
    On success the sign-in id token is stored in ``st.session_state.id_token``
    and a verification email is requested. All failures are reported with
    ``st.error``; nothing is raised to Streamlit.
    """
    email = st.session_state.reg_email
    password = st.session_state.reg_password
    display_name = st.session_state.reg_display_name
    try:
        # Create the account and set its display name in one call. A separate
        # update_user() afterwards could fail (e.g. on an empty display name)
        # after the account already exists, leaving a half-registered user
        # whose email can then no longer be registered.
        auth.create_user(email=email, password=password, display_name=display_name)
        st.success("Registration successful! Sending verification email...")
        # sendOobCode needs the new user's id_token, obtained here by signing
        # in through the Identity Toolkit REST endpoint.
        url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'
        data = {
            'email': email,
            'password': password,
            'returnSecureToken': True,
        }
        result = requests.post(url, json=data, timeout=10).json()
        if 'idToken' not in result:
            error = result.get('error')
            message = error.get('message') if isinstance(error, dict) else error
            st.error(f"Failed to retrieve id_token: {message}")
            return
        id_token = result['idToken']
        st.session_state.id_token = id_token
        verification_result = send_verification_email(id_token)
        if verification_result['status'] == 'success':
            st.success(f"Verification email sent to {email}.")
        else:
            st.error(f"Failed to send verification email: {verification_result['message']}")
    except Exception as e:
        st.error(f"Registration failed: {e}")
# Callback for login
def _sign_in_with_password(email, password):
    """POST *email*/*password* to Identity Toolkit ``signInWithPassword`` and return the decoded JSON reply."""
    url = f'https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={FIREBASE_API_KEY}'
    data = {
        'email': email,
        'password': password,
        'returnSecureToken': True,
    }
    return requests.post(url, json=data, timeout=10).json()


def _start_session(user):
    """Record *user* (a Firebase user record) as the logged-in user and confirm it."""
    st.session_state.logged_in = True
    st.session_state.current_user = user.uid
    st.session_state.display_name = user.display_name
    st.success("Logged in successfully!")


def login_callback():
    """Log in with an email address or a display name plus password.

    Reads ``login_identifier`` and ``login_password`` from ``st.session_state``
    (the widget keys used by ``login_form``). The identifier is first tried as
    an email. If Firebase rejects that, every account whose display name
    equals the identifier (and that has an email) is tried with the password.
    Failures are reported with ``st.error``.

    NOTE: the display-name fallback walks every account in the project on each
    failed email sign-in; with many users this is slow.
    """
    login_identifier = st.session_state.login_identifier
    password = st.session_state.login_password
    try:
        result = _sign_in_with_password(login_identifier, password)
        if 'idToken' in result:
            _start_session(auth.get_user_by_email(login_identifier))
            return
        if 'error' not in result:
            # Neither a token nor an error: the endpoint replied with something unexpected.
            st.error("Login failed: Error with sign-in endpoint")
            return
        # iterate_all() follows every page; list_users().users is only the
        # first page, so users beyond it could never log in by display name.
        for user_info in auth.list_users().iterate_all():
            if user_info.display_name != login_identifier or not user_info.email:
                continue
            if 'idToken' in _sign_in_with_password(user_info.email, password):
                _start_session(user_info)
                return
        st.error("Login failed: User not found with provided credentials.")
    except Exception as e:
        st.error(f"Login failed: {e}")
# Callback for logout
def logout_callback():
    """Reset the session to the logged-out state (``on_click`` of the Logout button).

    Besides the user fields, this discards ``id_token`` (stored in session
    state by ``register_callback``) so no Firebase credential outlives the
    logout. It also rewinds the gallery to its first page for the next user.
    """
    st.session_state.logged_in = False
    st.session_state.current_user = None
    st.session_state.display_name = None
    st.session_state.selected_image = None
    st.session_state.current_window_start = 0
    if "id_token" in st.session_state:
        del st.session_state["id_token"]
    st.info("Logged out successfully!")
# Function to get image from url
def get_image_from_url(url):
    """Download *url* and decode it into a PIL image.

    The image is returned in whatever format the server sent; no conversion
    is performed.

    Returns:
        ``(image, url)`` on success, otherwise ``(error_message, None)``.
        The message starts with "Error fetching image" for HTTP/network
        failures and with "Error processing image" for undecodable data.
    """
    try:
        # Timeout so an unresponsive host cannot hang the session.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only reads the header; decode now so corrupt data is
        # reported here instead of later when the caller resizes or saves it.
        image.load()
        return image, url
    except requests.exceptions.RequestException as e:
        return f"Error fetching image: {e}", None
    except Exception as e:
        return f"Error processing image: {e}", None
# Function to generate image
def generate_image(prompt, aspect_ratio, realism):
    """Request an image from the inference API, rotating API tokens on rate limits.

    POSTs ``{"id": model_id, "inputs": [prompt, aspect_ratio, "true"/"false"]}``
    to ``API_URL`` with the current module-level ``TOKEN``. While the reply's
    ``error`` mentions ``error 429``, the next ``TOKEN{n}`` environment variable
    is switched in and the request repeated. Each other account is tried at
    most once, and the switched token stays active for later calls.

    Returns:
        ``(PIL image, url)`` on success (via ``get_image_from_url``).
        Otherwise a 2-tuple whose first item describes the failure (a message
        string, or the API's error payload) and whose second item is ``None``.
    """
    global token_id
    global TOKEN
    global tokens_tried
    payload = {
        "id": model_id,
        "inputs": [prompt, aspect_ratio, str(realism).lower()],
    }
    tokens_tried = 0
    try:
        while True:
            headers = {"Authorization": f"Bearer {TOKEN}"}
            response_data = requests.post(API_URL, json=payload, headers=headers).json()
            if "error" not in response_data:
                if "output" in response_data:
                    return get_image_from_url(response_data["output"])
                return "Error: Unexpected response from server", None
            if 'error 429' not in response_data['error']:
                return response_data, None
            # Rate-limited. no_of_accounts - 1 rotations visit every other
            # token once; one more would wrap back to the token that just
            # returned 429.
            if tokens_tried >= no_of_accounts - 1:
                return "No credits available", None
            token_id = (token_id + 1) % no_of_accounts
            tokens_tried += 1
            TOKEN = os.getenv(f"TOKEN{token_id}")
    except Exception as e:
        # Surface the cause instead of a bare "Error".
        return f"Error: {e}", None
    finally:
        tokens_tried = 0
def download_image(image_url):
    """Download *image_url* into a new temporary file and return the file's path.

    The file suffix comes from the response's Content-Type, or ``.png`` when
    the header is missing or unrecognised. The file is created with
    ``delete=False``, so removing it is the caller's responsibility.

    Returns:
        The temp-file path, or ``None`` for an empty URL or on any failure
        (best effort: errors are deliberately not reported).
    """
    if not image_url:
        return None  # Return None if image_url is empty
    temp_file_path = None
    try:
        response = requests.get(image_url, stream=True, timeout=60)
        response.raise_for_status()
        # Drop parameters such as "; charset=..." and tolerate a missing header:
        # mimetypes.guess_extension(None) raises, which used to abort the download.
        content_type = (response.headers.get('content-type') or '').split(';', 1)[0].strip()
        extension = mimetypes.guess_extension(content_type) or ".png"
        with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as tmp_file:
            temp_file_path = tmp_file.name
            for chunk in response.iter_content(chunk_size=8192):
                tmp_file.write(chunk)
        return temp_file_path
    except Exception:
        # Don't leave a partially written file behind.
        if temp_file_path:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        return None
# Function to store image and related data in Firebase
def store_image_data_in_db(user_id, prompt, aspect_ratio, realism, image_url, thumbnail_url):
    """Append one generation record under ``users/<user_id>/images`` in the Realtime Database.

    The record holds the generation settings, both storage URLs and a
    server-side timestamp. Success or failure is reported with
    ``st.success`` / ``st.error``; nothing is returned or raised.
    """
    record = {
        'prompt': prompt,
        'aspect_ratio': aspect_ratio,
        'realism': realism,
        'image_url': image_url,
        'thumbnail_url': thumbnail_url,
        # Placeholder that Firebase replaces with the server time on write.
        'timestamp': {'.sv': 'timestamp'},
    }
    try:
        # push(value) creates the child with its data in a single write.
        # push() followed by set() first writes an empty-string child; if the
        # set fails, that bare '' stays behind as a malformed record.
        db.reference(f'users/{user_id}/images').push(record)
        st.success("Image data saved successfully!")
    except Exception as e:
        st.error(f"Failed to save image data: {e}")
# Function to upload image to cloud storage
def upload_image_to_storage(image, user_id, is_thumbnail=False):
    """Encode *image* as PNG, upload it to the default Storage bucket and return its public URL.

    Objects are named ``user_images/<user_id>/<uuid4>.png``, or
    ``user_images/<user_id>/thumbnails/<uuid4>.png`` when *is_thumbnail*
    is true. NOTE: ``make_public`` makes every uploaded file readable by
    anyone who has the URL. On any failure ``st.error`` is shown and
    ``None`` is returned.
    """
    try:
        bucket = storage.bucket()
        folder = f"user_images/{user_id}/thumbnails" if is_thumbnail else f"user_images/{user_id}"
        blob = bucket.blob(f"{folder}/{uuid.uuid4()}.png")
        png_buffer = BytesIO()
        image.save(png_buffer, format='PNG')
        blob.upload_from_string(png_buffer.getvalue(), content_type='image/png')
        blob.make_public()
        return blob.public_url
    except Exception as e:
        st.error(f"Failed to upload image to cloud storage: {e}")
        return None
# Function to load image data from the database
def load_image_data(user_id, start_index, batch_size):
    """Return up to *batch_size* image records of *user_id*, newest first, skipping the newest *start_index*.

    Records are the values stored under ``users/<user_id>/images``, ordered by
    their ``timestamp`` child. On a database error the message is shown with
    ``st.error`` and an empty list is returned.
    """
    try:
        newest = (
            db.reference(f'users/{user_id}/images')
            .order_by_child('timestamp')
            .limit_to_last(start_index + batch_size)
            .get()
        )
        if not newest:
            return []
        # The query result runs oldest to newest; flip it so the latest comes first.
        records = list(newest.values())
        records.reverse()
        return records[start_index:]
    except Exception as e:
        st.error(f"Failed to fetch image data from database: {e}")
        return []
# Function to create low resolution thumbnail
def create_thumbnail(image, thumbnail_size=(150, 150)):
    """Return a copy of *image* shrunk to fit within *thumbnail_size*, keeping its aspect ratio.

    ``Image.thumbnail`` resizes in place, so it is applied to a copy and the
    caller's image is left untouched. On failure ``st.error`` is shown and
    ``None`` is returned.
    """
    try:
        thumbnail = image.copy()
        thumbnail.thumbnail(thumbnail_size)
        return thumbnail
    except Exception as e:
        st.error(f"Failed to create thumbnail: {e}")
        return None
# Registration form
def registration_form():
    """Render the sign-up form; submitting it runs ``register_callback``.

    The widgets' return values are not needed here: the callback reads the
    entered values from ``st.session_state`` through the widget keys.
    """
    with st.form("Registration"):
        st.subheader("Register")
        st.text_input("Email", key="reg_email")
        st.text_input("Display Name", key="reg_display_name")
        st.text_input("Password (min 6 characters)", type="password", key="reg_password")
        st.form_submit_button("Register", on_click=register_callback)
# Login form
def login_form():
    """Render the sign-in form; submitting it runs ``login_callback``.

    ``login_callback`` reads the identifier and password from
    ``st.session_state`` via the widget keys, so the return values are unused.
    """
    with st.form("Login"):
        st.subheader("Login")
        st.text_input("Email or Username", key="login_identifier")
        st.text_input("Password", type="password", key="login_password")
        st.form_submit_button("Login", on_click=login_callback)
def _fit_within(image, bound):
    """Return *image* resized (up or down) to fit a bound x bound box, keeping its aspect ratio."""
    width, height = image.size
    scale = min(bound / width, bound / height)
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return image.resize(new_size, Image.LANCZOS)


def _offer_download(image_url, key):
    """Render a "Download Image" button for *image_url*, or nothing if the download fails.

    ``download_image`` returns a temporary file it does not delete. The file is
    read into memory and removed right away, so neither an open handle nor
    the file outlives this run.
    """
    download_path = download_image(image_url)
    if not download_path:
        return
    try:
        with open(download_path, "rb") as downloaded:
            data = downloaded.read()
    finally:
        try:
            os.remove(download_path)
        except OSError:
            pass
    # Name the download after the actual file type rather than always ".png".
    extension = os.path.splitext(download_path)[1] or ".png"
    st.download_button(label="Download Image", data=data, file_name=f"image{extension}", key=key)


def _close_selected_image():
    """``on_click`` for "Close": clear the selection before the rerun renders the page."""
    st.session_state.selected_image = None


def _generate_and_show(prompt, aspect_ratio, realism):
    """Generate an image, upload it and its thumbnail, record it, then show a preview.

    Stops at the first failing step and reports it with ``st.error``.
    """
    image_result = generate_image(prompt, aspect_ratio, realism)
    if not (isinstance(image_result, tuple) and len(image_result) == 2):
        st.error(f"Image generation failed: {image_result}")
        return
    image, image_url = image_result
    if not isinstance(image, Image.Image):
        st.error(f"Image generation failed: {image}")
        return
    # Preview scaled to fit a 400x400 box.
    preview = _fit_within(image, 400)
    user_id = st.session_state.current_user
    cloud_storage_url = upload_image_to_storage(image, user_id, is_thumbnail=False)
    if not cloud_storage_url:
        st.error("Failed to upload image to cloud storage.")
        return
    thumbnail = create_thumbnail(image)
    if thumbnail is None:
        st.error("Failed to create thumbnail")
        return
    thumbnail_url = upload_image_to_storage(thumbnail, user_id, is_thumbnail=True)
    if not thumbnail_url:
        st.error("Failed to upload thumbnail to cloud storage.")
        return
    # store_image_data_in_db shows its own success/failure message; an extra
    # unconditional "stored successfully" here would be wrong when it fails.
    store_image_data_in_db(user_id, prompt, aspect_ratio, realism, cloud_storage_url, thumbnail_url)
    with st.container(border=True):
        st.image(preview, use_column_width=False)
        st.write(f"**Prompt:** {prompt}")
        st.write(f"**Aspect Ratio:** {aspect_ratio}")
        st.write(f"**Realism:** {realism}")
        _offer_download(image_url, key="download_high_res")


def _render_generator():
    """Render the prompt inputs and run a generation when "Generate Image" is clicked."""
    prompt = st.text_input("Prompt", key="image_prompt", placeholder="Describe the image you want to generate")
    aspect_ratio = st.radio(
        "Aspect Ratio",
        options=["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],
        index=5,
    )
    realism = st.checkbox("Realism", value=False)
    if not st.button("Generate Image"):
        return
    if not prompt:
        st.warning("Please enter a prompt to generate an image.")
        return
    with st.spinner("Generating Image..."):
        _generate_and_show(prompt, aspect_ratio, realism)


def _render_gallery():
    """Show the user's saved images in pages of ``window_size`` with ◀️/▶️ navigation."""
    st.header("Your Generated Images")
    for state_key, default in (("current_window_start", 0), ("window_size", 5), ("selected_image", None)):
        if state_key not in st.session_state:
            st.session_state[state_key] = default
    col_left, _col_center, col_right = st.columns([1, 8, 1])
    with col_left:
        if st.button("◀️"):
            st.session_state.current_window_start = max(
                0, st.session_state.current_window_start - st.session_state.window_size
            )
    with col_right:
        if st.button("▶️"):
            st.session_state.current_window_start += st.session_state.window_size

    # Up to the newest 1000 records, reloaded on every rerun. Non-dict entries
    # (e.g. a placeholder value from an interrupted write) are skipped so one
    # bad record cannot break the whole gallery.
    all_images = [
        record
        for record in load_image_data(st.session_state.current_user, 0, 1000)
        if isinstance(record, dict)
    ]
    if not all_images:
        st.session_state.current_window_start = 0
        st.write("No image generated yet!")
        return

    window_size = st.session_state.window_size
    # Clamp so that ▶️ past the end keeps showing the last page; an empty
    # window would call st.columns(0), which Streamlit rejects.
    last_page_start = (len(all_images) - 1) // window_size * window_size
    start = min(st.session_state.current_window_start, last_page_start)
    st.session_state.current_window_start = start
    images_for_window = all_images[start:start + window_size]

    cols = st.columns(len(images_for_window))
    for i, (col, image_data) in enumerate(zip(cols, images_for_window)):
        with col:
            if image_data.get('thumbnail_url') and image_data.get('image_url'):
                if st.button("More", key=f"more_{i}"):
                    st.session_state.selected_image = image_data
                st.image(image_data['thumbnail_url'], width=150)  # display thumbnail
            elif image_data.get('image_url'):
                st.image(image_data['image_url'], width=150)
            st.write(f"**Prompt:** {image_data.get('prompt')}")
            st.write(f"**Aspect Ratio:** {image_data.get('aspect_ratio')}")
            st.write(f"**Realism:** {image_data.get('realism')}")
            st.markdown("---")


def _render_selected_image():
    """Show the image picked with "More" at full width, with download and close controls."""
    selected = st.session_state.selected_image
    if not selected:
        return
    with st.container(border=True):
        st.image(selected['image_url'], use_column_width=True)
        st.write(f"**Prompt:** {selected['prompt']}")
        st.write(f"**Aspect Ratio:** {selected['aspect_ratio']}")
        st.write(f"**Realism:** {selected['realism']}")
        _offer_download(selected['image_url'], key="download_overlay")
        # A callback runs before the rerun, so a single click hides the panel;
        # checking the button's return value would only clear it one run later.
        st.button("Close", on_click=_close_selected_image)


def main_app():
    """Render the signed-in page: generator, gallery, selected-image panel and logout button."""
    st.subheader(f"Welcome, {st.session_state.display_name}!")
    st.write("Enter a prompt below to generate an image.")
    _render_generator()
    _render_gallery()
    _render_selected_image()
    st.button("Logout", on_click=logout_callback)
# Entry point: Streamlit re-runs this script top to bottom on every interaction.
if not st.session_state.logged_in:
    # Anonymous visitors see both the sign-up and the sign-in forms.
    registration_form()
    login_form()
else:
    main_app()