# iris / app.py — Hugging Face Space entry point.
# (Removed non-Python page chrome captured from the Hub web UI:
#  author byline and commit message "Update app.py", db768d2.)
import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image
import time
import os
import base64
from io import BytesIO
# API token read from the environment (or a Gradio/Space secret). A missing
# token is recorded as a human-readable error string instead of raising here,
# so the module still imports; the handlers below raise gr.Error lazily.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    HF_TOKEN_ERROR = "Hugging Face API token (HF_TOKEN) not found. Please set it as an environment variable or Gradio secret."
else:
    HF_TOKEN_ERROR = None

# Shared Hub inference client used for both prompt improvement and
# image generation.
client = InferenceClient(token=HF_TOKEN)

# LLM used to expand/embellish user prompts before image generation.
PROMPT_IMPROVER_MODEL = "Qwen/Qwen2.5-Coder-32B-Instruct"
def improve_prompt(original_prompt):
    """Rewrite *original_prompt* into a richer image-generation prompt.

    Uses the chat-completion endpoint so the Hub applies the model's own
    chat template; the previous hand-rolled ``<|system|>…</s>`` markup is
    Zephyr-style and does not match Qwen2.5's ChatML format, which could
    degrade or garble the model's output.

    Returns the improved prompt, or the original prompt unchanged if the
    API call fails for any reason (best effort — the improver must never
    block image generation).

    Raises:
        gr.Error: if no HF_TOKEN was configured at startup.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    try:
        system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Improve this prompt: {original_prompt}"},
            ],
            model=PROMPT_IMPROVER_MODEL,
            max_tokens=1280,
            temperature=0.7,
            top_p=0.9,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Deliberately broad: any API/parse failure falls back to the
        # user's original prompt rather than failing the whole request.
        print(f"Error improving prompt: {e}")
        return original_prompt
def generate_image(prompt, progress=gr.Progress()):
    """Generate an image for *prompt* with FLUX.1-schnell.

    The prompt is first enriched via ``improve_prompt`` (best effort),
    then sent to the text-to-image endpoint. Progress is reported through
    Gradio's tracker. The pointless ``time.sleep(0.5)`` cosmetic delay
    from the original implementation has been removed.

    Returns:
        PIL.Image.Image: the generated image.

    Raises:
        gr.Error: on a missing token or any generation failure, with a
        user-readable message (rate limits are called out explicitly).
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    progress(0, desc="Improving prompt...")
    improved_prompt = improve_prompt(prompt)
    progress(0.2, desc="Sending request ")
    try:
        image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
        if not isinstance(image, Image.Image):
            # Defensive guard: the client should already return a decoded
            # PIL image; TypeError (not bare Exception) keeps intent clear.
            raise TypeError(f"Expected a PIL Image, but got: {type(image)}")
        progress(1.0, desc="Done!")
        return image
    except Exception as e:
        # Distinguish rate limiting from other failures in the UI message.
        if "rate limit" in str(e).lower():
            error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
        else:
            error_message = f"An error occurred: {e}"
        raise gr.Error(error_message)
def pil_to_base64(img):
    """Serialize a PIL image to a ``data:image/png;base64,...`` URI string."""
    png_buffer = BytesIO()
    img.save(png_buffer, format="PNG")
    payload = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + payload
# Custom CSS hook for the class names used below ("title", "input-section",
# "output-section", "submit-button"); currently empty.
css = """
"""

# UI layout: prompt + button on the left, generated image on the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Xylaria Iris v3
""",
        elem_classes="title"
    )
    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., A cat", lines=3)
                generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section") as output_group:
                image_output = gr.Image(label="Generated Image", interactive=False)

    def on_generate_click(prompt):
        # Click handler: run the full improve-then-generate pipeline.
        # NOTE(review): mutating elem_classes on an already-rendered
        # component presumably has no visible effect on the frontend —
        # confirm against the Gradio docs before relying on the animation.
        output_group.elem_classes = ["output-section", "animate"]
        image = generate_image(prompt) # Ignore the improved prompt
        output_group.elem_classes = ["output-section"]
        return image # Return only the generated image

    # Both the button click and pressing Enter in the textbox trigger
    # the same handler.
    generate_button.click(on_generate_click, inputs=prompt_input, outputs=image_output)
    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=image_output)

    # One-click sample prompts.
    gr.Examples(
        [["A dog"],
         ["A house on a hill"],
         ["A spaceship"]],
        inputs=prompt_input
    )

if __name__ == "__main__":
    # Enable request queueing (used by the gr.Progress tracker) and serve.
    demo.queue().launch()