|
import base64 |
|
import dash |
|
from dash import dcc, html, Input, Output, State |
|
import dash_bootstrap_components as dbc |
|
from dash.exceptions import PreventUpdate |
|
import google.generativeai as genai |
|
import requests |
|
import logging |
|
import threading |
|
import time |
|
|
|
|
|
# Verbose, timestamped logging so each API round-trip can be traced.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')


# Style presets offered in the style dropdown. The selected value is interpolated
# into both the enhancement request and the image request.
STYLES = [
    "photographic", "3d-model", "analog-film", "anime", "cinematic", "comic-book",
    "digital-art", "enhance", "fantasy-art", "isometric", "line-art", "low-poly",
    "modeling-compound", "neon-punk", "origami", "pixel-art", "tile-texture",
]


# Negative prompt sent with every image request. The text is kept free of stray
# separator characters and doubled blank lines so only real terms reach the API.
DEFAULT_NEGATIVE_PROMPT = """
ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame,
extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature,
cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face,
plastic, cartoonish, artificial, fake, unnatural
"""
|
|
|
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Aspect ratios selectable under "Advanced Settings".
_ASPECT_RATIOS = ["16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"]

# Accordion (collapsed on load) holding the optional generation settings.
_advanced_settings = dbc.Accordion(
    [
        dbc.AccordionItem(
            [
                dbc.Label("Aspect Ratio"),
                dcc.Dropdown(
                    id="aspect-ratio",
                    options=[{"label": ratio, "value": ratio} for ratio in _ASPECT_RATIOS],
                    value="1:1",
                ),
                dbc.Label("Steps"),
                dcc.Slider(id="steps", min=4, max=50, step=1, value=20, marks={4: '4', 25: '25', 50: '50'}),
            ],
            title="Advanced Settings",
        ),
    ],
    start_collapsed=True,
    className="mb-3",
)

# Left column: API keys, prompt, style selection and the submit button.
_input_column = dbc.Col(
    [
        dbc.Card(
            [
                dbc.CardBody(
                    [
                        dbc.Input(id="google-api-key", type="password", placeholder="Enter Google AI API Key", className="mb-3"),
                        dbc.Input(id="stability-api-key", type="password", placeholder="Enter Stability AI API Key", className="mb-3"),
                        dbc.Textarea(id="prompt", placeholder="Enter your prompt", className="mb-3"),
                        dcc.Dropdown(
                            id="style",
                            options=[
                                {"label": preset.replace("-", " ").title(), "value": preset}
                                for preset in STYLES
                            ],
                            value="photographic",
                            placeholder="Select style",
                            className="mb-3",
                        ),
                        dbc.Button("Generate Image", id="submit-btn", color="primary", className="mb-3"),
                        _advanced_settings,
                    ]
                ),
            ],
            className="mb-4",
        ),
    ],
    width=6,
)

# Right column: status text, generated image, enhanced prompt and download
# controls, all wrapped in a loading spinner.
_output_column = dbc.Col(
    [
        dbc.Card(
            [
                dbc.CardBody(
                    [
                        dcc.Loading(
                            id="loading",
                            type="circle",
                            children=[
                                html.Div(id="status-message", className="mb-3"),
                                html.Img(id="image-output", className="img-fluid mb-3"),
                                html.Div(id="enhanced-prompt-output", className="mb-3"),
                                dbc.Button("Download Image", id="download-btn", color="secondary", className="mb-3", disabled=True),
                                dcc.Download(id="download-image"),
                            ],
                        ),
                    ]
                ),
            ]
        ),
    ],
    width=6,
)

# Page skeleton: title on top, then the two half-width columns side by side.
app.layout = dbc.Container(
    [
        html.H1("ImaGen", className="my-4"),
        dbc.Row([_input_column, _output_column]),
    ],
    fluid=True,
)
|
|
|
def enhance_prompt(google_api_key, prompt, style):
    """Ask the Gemini model to rewrite ``prompt`` as a richer, style-aware prompt.

    Args:
        google_api_key: API key passed to ``genai.configure``.
        prompt: The user's original prompt text.
        style: Name of the selected style; interpolated into the request.

    Returns:
        The model's response text, stripped of surrounding whitespace and of
        any known lead-in phrase such as "Enhanced prompt:".

    Raises:
        Exception: Any error raised while generating or reading the response
            is logged and re-raised unchanged.
    """
    genai.configure(api_key=google_api_key)
    model = genai.GenerativeModel("gemini-2.0-flash-lite")

    # Assembled from explicit lines so the text sent to the model does not
    # depend on source indentation and carries no stray separator characters.
    enhanced_prompt_request = (
        "\n"
        "Task: Enhance the following prompt with details to match the specified style\n"
        f"Style: {style}\n"
        f"Original prompt: '{prompt}'\n"
        "\n"
        "Instructions:\n"
        "1. Expand the prompt to be more detailed, vivid, and realistic with camera used and the setting for that camera like ISO etc.\n"
        "2. Incorporate elements of the specified style.\n"
        "3. Add details that enhance the scene to the specified style\n"
        "4. Emphasize natural lighting and enhance the realism of textures and colors based on the specified style.\n"
        "5. Avoid terms that might result in artificial or cartoonish appearance unless specified by user.\n"
        "6. Maintain the original intent of the prompt while significantly improving its descriptive quality with details.\n"
        "7. Provide ONLY the enhanced prompt, without any explanations or options.\n"
        "8. Keep the enhanced prompt concise, ideally under 100 words.\n"
        "\n"
        "Enhanced prompt:\n"
    )

    try:
        response = model.generate_content(enhanced_prompt_request)
        enhanced_prompt = response.text.strip()

        # The model sometimes echoes a lead-in despite instruction 7; drop it
        # (case-insensitively) so only the prompt text remains.
        prefixes_to_remove = ["Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:"]
        for prefix in prefixes_to_remove:
            if enhanced_prompt.lower().startswith(prefix.lower()):
                enhanced_prompt = enhanced_prompt[len(prefix):].strip()

        logging.info("Enhanced prompt: %s", enhanced_prompt)
        return enhanced_prompt
    except Exception as e:
        logging.error("Error in enhance_prompt: %s", e)
        raise
|
|
|
def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, steps=30, aspect_ratio="1:1"):
    """Request one JPEG from the Stability AI SD3 endpoint and return its bytes.

    Args:
        stability_api_key: Bearer token for the Stability AI API.
        enhanced_prompt: Prompt text; the style name and quality keywords are
            appended before sending.
        style: Style name, sent both inside the prompt and as ``style_preset``.
        negative_prompt: Terms the image should avoid.
        steps: Value sent as the ``steps`` form field.
        aspect_ratio: Value sent as the ``aspect_ratio`` form field.

    Returns:
        The raw image bytes from the response body.

    Raises:
        Exception: If the HTTP request fails (the original
            ``requests`` error is chained as the cause), if the response is
            not an image, or if the image payload is under 1000 bytes.
    """
    url = "https://api.stability.ai/v2beta/stable-image/generate/sd3"

    headers = {
        "Accept": "image/*",
        "Authorization": f"Bearer {stability_api_key}",
    }

    data = {
        "prompt": f"{enhanced_prompt}, Style: {style}, highly detailed, high quality, descriptive",
        "negative_prompt": negative_prompt,
        "model": "sd3.5-large-turbo",
        "output_format": "jpeg",
        "num_images": 1,
        "steps": steps,
        "style_preset": style,
        "aspect_ratio": aspect_ratio,
    }

    try:
        # The dummy ``files`` entry makes requests encode the body as
        # multipart/form-data instead of a URL-encoded form.
        response = requests.post(url, headers=headers, files={"none": ''}, data=data, timeout=60)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logging.error("Request failed: %s", e)
        # Chain the cause so the underlying HTTP/connection error stays visible.
        raise Exception(f"Request failed: {str(e)}") from e

    raw_content_type = response.headers.get('content-type')
    logging.debug("Response headers: %s", response.headers)
    logging.debug("Response content type: %s", raw_content_type)

    # A missing Content-Type header is treated as "not an image" rather than
    # crashing on ``None.startswith``.
    content_type = raw_content_type or ''
    if not content_type.startswith('image/'):
        error_message = response.text
        logging.error("Unexpected content type: %s. Response: %s", raw_content_type, error_message)
        raise Exception(f"Unexpected content type: {raw_content_type}. Response: {error_message}")

    image_data = response.content
    # Very small payloads are assumed to be truncated or empty images.
    if len(image_data) < 1000:
        raise Exception("Received incomplete image data")
    return image_data
|
|
|
def process_and_generate(google_api_key, stability_api_key, prompt, style, steps, aspect_ratio, set_status):
    """Enhance the prompt, then generate an image with up to three attempts.

    Progress and failures are reported through ``set_status(message)``.

    Returns:
        ``(image_bytes, enhanced_prompt)`` on success, or
        ``(None, error_text)`` when enhancement fails or every image
        attempt fails.
    """
    try:
        set_status("Enhancing prompt...")
        refined_prompt = enhance_prompt(google_api_key, prompt, style)

        set_status("Generating image...")
        attempts_allowed = 3
        attempt_number = 0
        while attempt_number < attempts_allowed:
            attempt_number += 1
            try:
                picture = generate_image(
                    stability_api_key,
                    refined_prompt,
                    style,
                    DEFAULT_NEGATIVE_PROMPT,
                    steps,
                    aspect_ratio,
                )
                set_status("Image generated successfully!")
                return picture, refined_prompt
            except Exception:
                # Out of attempts: hand the failure to the outer handler.
                if attempt_number >= attempts_allowed:
                    raise
                set_status(f"Attempt {attempt_number} failed. Retrying...")
                # Short pause before the next attempt.
                time.sleep(2)
    except Exception as failure:
        failure_text = str(failure)
        logging.error(f"Error in process_and_generate: {failure_text}")
        set_status(f"Error: {failure_text}")
        return None, failure_text
|
|
|
@app.callback(
    [Output("image-output", "src"),
     Output("enhanced-prompt-output", "children"),
     Output("status-message", "children"),
     Output("download-btn", "disabled")],
    [Input("submit-btn", "n_clicks")],
    [State("google-api-key", "value"),
     State("stability-api-key", "value"),
     State("prompt", "value"),
     State("style", "value"),
     State("steps", "value"),
     State("aspect-ratio", "value")],
    prevent_initial_call=True
)
def update_output(n_clicks, google_api_key, stability_api_key, prompt, style, steps, aspect_ratio):
    """Run prompt enhancement and image generation for a "Generate Image" click.

    The pipeline runs exactly once, in a worker thread, so the callback can
    give up after 90 seconds instead of blocking indefinitely.

    Returns:
        A 4-tuple matching the callback outputs: image ``src`` (a base64 JPEG
        data URI, or ``""`` on failure), the enhanced-prompt/error text, the
        status message, and whether the download button is disabled.

    Raises:
        PreventUpdate: If the button has never been clicked.
    """
    if n_clicks is None:
        raise PreventUpdate

    # Reject missing inputs up front instead of failing deep inside an API call.
    required_inputs = (
        ("Google AI API key", google_api_key),
        ("Stability AI API key", stability_api_key),
        ("prompt", prompt),
    )
    missing = [label for label, value in required_inputs if not value or not str(value).strip()]
    if missing:
        problem = f"Missing required input: {', '.join(missing)}"
        return "", f"Error: {problem}", problem, True

    # Deliberately no logging of API key material, not even a prefix.
    logging.debug("Starting image generation request")

    status = {"message": "Starting process..."}
    outcome = {}

    def set_status(message):
        status["message"] = message

    def run_process():
        # A thread target's return value is discarded, so the result (or the
        # failure) is stored in ``outcome`` for the callback to read after join.
        try:
            outcome["result"] = process_and_generate(
                google_api_key, stability_api_key, prompt, style, steps, aspect_ratio, set_status
            )
        except Exception as exc:
            outcome["error"] = exc

    try:
        # Daemon thread: a worker that outlives the timeout must not keep the
        # interpreter alive at shutdown.
        worker = threading.Thread(target=run_process, daemon=True)
        worker.start()
        worker.join(timeout=90)

        if worker.is_alive():
            return "", "Error: Image generation timed out", "Process timed out", True

        if "error" in outcome:
            raise outcome["error"]

        image_bytes, enhanced_prompt = outcome["result"]
        if image_bytes:
            encoded_image = base64.b64encode(image_bytes).decode('ascii')
            return (
                f"data:image/jpeg;base64,{encoded_image}",
                f"Enhanced Prompt: {enhanced_prompt}",
                status["message"],
                False,
            )
        # On failure the second element of the result holds the error text.
        return "", f"Error: {enhanced_prompt}", status["message"], True
    except Exception as e:
        logging.error("Unexpected error in update_output: %s", e)
        return "", f"Unexpected error: {str(e)}", "An unexpected error occurred", True
|
|
|
@app.callback(
    Output("download-image", "data"),
    Input("download-btn", "n_clicks"),
    State("image-output", "src"),
    prevent_initial_call=True
)
def download_image(n_clicks, image_src):
    """Send the currently displayed image to the browser as a JPEG download.

    Args:
        n_clicks: Click count of the download button.
        image_src: The ``src`` of the displayed image, expected to be a
            ``data:image/jpeg;base64,<payload>`` URI.

    Returns:
        The ``dcc.send_bytes`` payload for ``generated_image.jpeg``.

    Raises:
        PreventUpdate: If the button was never clicked, or there is no
            decodable image to download.
    """
    # The image src is "" after a failed generation and unset before the first
    # one, so there is nothing to download in either case.
    if n_clicks is None or not image_src:
        raise PreventUpdate

    # Split the "data:...;base64" header from the payload at the first comma.
    _, separator, image_data = image_src.partition(",")
    if not separator or not image_data:
        raise PreventUpdate

    try:
        image_bytes = base64.b64decode(image_data)
    except ValueError as e:  # binascii.Error is a ValueError subclass
        logging.error("Could not decode image data for download: %s", e)
        raise PreventUpdate

    return dcc.send_bytes(image_bytes, "generated_image.jpeg")
|
|
|
if __name__ == '__main__':
    # Script entry point: serve the app on all interfaces, port 7860, with
    # Dash debug mode off. The final message prints once the server returns.
    server_options = {"debug": False, "host": '0.0.0.0', "port": 7860}
    print("Starting the Dash application...")
    app.run(**server_options)
    print("Dash application has finished running.")