# Hugging Face Space — "Xylaria Iris v3" text-to-image demo.
# (The original "Spaces: / Sleeping / Sleeping" lines were Space-status
# residue from a page scrape, not part of the program.)
import base64
import logging
import os
import time
from io import BytesIO

import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image
# Read the Hugging Face API token from the environment (or a Gradio secret).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Only build the inference client when a token is present. When it is
# missing, `client` is intentionally left undefined: every UI callback
# checks HF_TOKEN_ERROR and raises before touching `client`.
if HF_TOKEN:
    HF_TOKEN_ERROR = None
    client = InferenceClient(token=HF_TOKEN)
else:
    HF_TOKEN_ERROR = (
        "Hugging Face API token (HF_TOKEN) not found. "
        "Please set it as an environment variable or Gradio secret."
    )

# LLM used to rewrite the user's prompt before image generation.
PROMPT_IMPROVER_MODEL = "Qwen/Qwen2.5-Coder-32B-Instruct"
def improve_prompt(original_prompt):
    """Rewrite *original_prompt* into a richer image-generation prompt.

    Sends the prompt to PROMPT_IMPROVER_MODEL via the chat-completion
    endpoint, which applies the model's own chat template server-side.
    (The previous version hand-rolled a Zephyr-style template with
    ``<|system|>``/``</s>`` tokens, which is the wrong format for the
    ChatML-based Qwen2.5 models and degraded output quality.)

    Args:
        original_prompt: The user's raw prompt text.

    Returns:
        The improved prompt, or *original_prompt* unchanged if the
        inference call fails (best-effort enhancement).

    Raises:
        gr.Error: If the HF_TOKEN secret is missing.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    system_prompt = "You are a helpful assistant that improves text prompts for image generation models. Make the prompt more descriptive, detailed, and artistic, while keeping the user's original intent."
    try:
        response = client.chat_completion(
            model=PROMPT_IMPROVER_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Improve this prompt: {original_prompt}"},
            ],
            max_tokens=1280,
            temperature=0.7,
            top_p=0.9,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: log the failure and fall back to the raw prompt so
        # image generation can still proceed.
        logging.getLogger(__name__).warning("Error improving prompt: %s", e)
        return original_prompt
def generate_image(prompt, progress=gr.Progress()):
    """Generate an image for *prompt* using FLUX.1-schnell.

    The prompt is first passed through :func:`improve_prompt` (which falls
    back to the raw prompt on failure); progress is reported through the
    Gradio progress bar. The artificial ``time.sleep(0.5)`` that previously
    padded the progress animation has been removed — it only added latency.

    Args:
        prompt: The user's prompt text.
        progress: Gradio-injected progress tracker (default per Gradio docs).

    Returns:
        PIL.Image.Image: The generated image.

    Raises:
        gr.Error: If the token is missing, the rate limit is hit, or the
            inference call fails for any other reason.
    """
    if HF_TOKEN_ERROR:
        raise gr.Error(HF_TOKEN_ERROR)
    progress(0, desc="Improving prompt...")
    improved_prompt = improve_prompt(prompt)
    progress(0.2, desc="Sending request...")
    try:
        image = client.text_to_image(improved_prompt, model="black-forest-labs/FLUX.1-schnell")
        # The client may return other payloads on odd responses; fail loudly
        # so the except-branch converts it into a user-visible gr.Error.
        if not isinstance(image, Image.Image):
            raise Exception(f"Expected a PIL Image, but got: {type(image)}")
        progress(1.0, desc="Done!")
        return image
    except Exception as e:
        # Surface rate-limit errors with a friendlier message.
        if "rate limit" in str(e).lower():
            error_message = f"Rate limit exceeded. Please try again later. Error: {e}"
        else:
            error_message = f"An error occurred: {e}"
        raise gr.Error(error_message) from e
def pil_to_base64(img):
    """Encode *img* (a PIL image) as a PNG ``data:`` URI string."""
    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
# Empty stylesheet placeholder; rules for the elem_classes used below
# ("title", "input-section", "output-section", "submit-button") go here.
css = """
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Xylaria Iris v3", elem_classes="title")
    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="input-section"):
                prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., A cat", lines=3)
                generate_button = gr.Button("Generate Image", elem_classes="submit-button")
        with gr.Column():
            with gr.Group(elem_classes="output-section"):
                image_output = gr.Image(label="Generated Image", interactive=False)

    def on_generate_click(prompt):
        """Event handler: generate and return the image for *prompt*."""
        # NOTE(review): the original toggled output_group.elem_classes around
        # this call to add an "animate" class, but mutating component
        # attributes after render has no effect on the Gradio frontend, so
        # that dead code was removed.
        return generate_image(prompt)

    # Both the button and pressing Enter in the textbox trigger generation.
    generate_button.click(on_generate_click, inputs=prompt_input, outputs=image_output)
    prompt_input.submit(on_generate_click, inputs=prompt_input, outputs=image_output)

    gr.Examples(
        [["A dog"],
         ["A house on a hill"],
         ["A spaceship"]],
        inputs=prompt_input,
    )

if __name__ == "__main__":
    # queue() serializes inference requests so concurrent users don't collide.
    demo.queue().launch()