text2image_3 / app.py
RanM's picture
Update app.py
0f197c4 verified
raw
history blame
3.01 kB
import os
from io import BytesIO
from diffusers import AutoPipelineForText2Image
import gradio as gr
import base64
from generate_prompts import generate_prompt
# Load the model once at the start
# Module-level load: the pipeline is created a single time per process rather
# than once per request. NOTE(review): there is no .to("cuda") call here, so
# this runs on whatever device diffusers defaults to (CPU) — confirm intended.
print("Loading the Stable Diffusion model...")
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
print("Model loaded successfully.")
def truncate_prompt(prompt, max_length=77):
    """Clamp *prompt* to at most *max_length* whitespace-separated words.

    If the prompt already fits, it is returned unchanged (original spacing
    preserved); otherwise the first *max_length* words are rejoined with
    single spaces.

    NOTE(review): 77 matches CLIP's token limit, but these are words, not
    tokenizer tokens — the model may still truncate further. Confirm intent.
    """
    words = prompt.split()
    if len(words) <= max_length:
        return prompt
    return " ".join(words[:max_length])
def generate_image(prompt):
    """Render a single image for *prompt* using the module-level pipeline.

    Returns a ``(image, error)`` pair: ``(PIL image, None)`` on success,
    ``(None, message)`` when generation fails or yields no images.
    """
    try:
        short_prompt = truncate_prompt(prompt)
        print(f"Generating image with truncated prompt: {short_prompt}")
        # One inference step with zero guidance — sdxl-turbo style settings.
        result = model(prompt=short_prompt, num_inference_steps=1, guidance_scale=0.0)
        # Guard against a missing/empty .images attribute on the output.
        if result is None or not getattr(result, "images", None):
            print(f"No images found or generated output is None")
            return None, "No images found or generated output is None"
        print(f"Image generated")
        return result.images[0], None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)
def inference(prompt):
    """Generate an image for *prompt* and return it as a base64 PNG string.

    On failure, returns a string beginning with ``"Error: "`` instead.
    """
    print(f"Received prompt: {prompt}")  # Debugging statement
    image, error = generate_image(prompt)
    # Guard clause: surface the failure as a tagged string for the caller.
    if error:
        print(f"Error generating image: {error}")  # Debugging statement
        return "Error: " + error
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph and render each to a base64 image.

    *sentence_mapping* maps paragraph numbers to lists of sentences; each
    paragraph's sentences are joined and turned into a prompt, which is then
    rendered. Returns ``{paragraph_number: base64_png_or_error_string}``.
    """
    print("Processing prompt...")
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")
    paragraph_prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        merged_text = " ".join(sentences)
        # The negative prompt is produced but never used downstream —
        # presumably intentional; verify against generate_prompt's contract.
        prompt, _negative = generate_prompt(
            merged_text, sentence_mapping, character_dict, selected_style
        )
        paragraph_prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
    images = {}
    for paragraph_number, prompt in paragraph_prompts:
        images[paragraph_number] = inference(prompt)
    print("Prompt processing complete. Generated images: ", images)
    return images
# Gradio interface: accepts the sentence mapping and character dict as JSON
# plus a style choice, and returns process_prompt's dict of base64 PNG
# strings (or "Error: ..." strings) as JSON.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
).queue(default_concurrency_limit=20) # Set concurrency limit if needed
# Script entry point: launch the Gradio app only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()