Spaces:
Sleeping
Sleeping
| import os | |
| from io import BytesIO | |
| from diffusers import AutoPipelineForText2Image | |
| import gradio as gr | |
| import base64 | |
| from generate_prompts import generate_prompt | |
# Load the model once at import time so every request reuses the same pipeline
# (loading is expensive; doing it per-request would dominate latency).
print("Loading the Stable Diffusion model...")
# SDXL-Turbo is a distilled pipeline tuned for very few inference steps;
# generate_image() below calls it with num_inference_steps=1, guidance_scale=0.0.
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
print("Model loaded successfully.")
def truncate_prompt(prompt, max_length=77):
    """Clip *prompt* to at most *max_length* whitespace-separated words.

    A prompt already within the limit is returned untouched (original
    whitespace preserved); a longer one is rebuilt from its first
    *max_length* words joined by single spaces.

    NOTE(review): 77 mirrors the CLIP text-encoder token limit, but words
    are only a rough proxy for tokenizer tokens — confirm this is close
    enough for the model in use.
    """
    words = prompt.split()
    if len(words) <= max_length:
        return prompt
    return " ".join(words[:max_length])
def generate_image(prompt):
    """Run the text-to-image pipeline for *prompt*.

    The prompt is first truncated via truncate_prompt() to respect the
    encoder's length limit, then fed to the module-level ``model`` with
    SDXL-Turbo's recommended settings (1 step, no guidance).

    Returns:
        tuple: ``(image, None)`` on success, where ``image`` is the first
        PIL image from the pipeline output, or ``(None, error_message)``
        when generation fails or produces no images.
    """
    try:
        truncated_prompt = truncate_prompt(prompt)
        print(f"Generating image with truncated prompt: {truncated_prompt}")
        # Call the model; SDXL-Turbo is designed for single-step, guidance-free inference.
        output = model(prompt=truncated_prompt, num_inference_steps=1, guidance_scale=0.0)
        # Guard against a None result or a result object without a non-empty .images list.
        if output is not None and hasattr(output, 'images') and output.images:
            # Fix: these prints carried pointless f-string prefixes with no placeholders.
            print("Image generated")
            image = output.images[0]
            return image, None
        else:
            print("No images found or generated output is None")
            return None, "No images found or generated output is None"
    except Exception as e:
        # Broad catch is deliberate: this is the error boundary for the web handler,
        # which turns the message into a user-visible error string.
        print(f"An error occurred while generating image: {e}")
        return None, str(e)
def inference(prompt):
    """Generate an image for *prompt* and return it as a base64 PNG string.

    On failure, returns a string of the form ``"Error: <message>"`` instead
    of image data — callers must check the prefix.
    """
    print(f"Received prompt: {prompt}")  # Debugging statement
    picture, failure = generate_image(prompt)
    if failure:
        print(f"Error generating image: {failure}")  # Debugging statement
        return "Error: " + failure
    # Serialize the PIL image to PNG in memory, then base64-encode for JSON transport.
    buffer = BytesIO()
    picture.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Build one image per paragraph from *sentence_mapping*.

    For each paragraph, its sentences are joined into a single string,
    turned into a prompt via generate_prompt(), and rendered via
    inference(). The negative prompt returned by generate_prompt() is
    intentionally unused here.

    Returns:
        dict: paragraph number -> base64 PNG string (or "Error: ..." text).
    """
    print("Processing prompt...")
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")
    # First pass: build every prompt (with its paragraph key) before rendering.
    prompts = []
    for para_number, sentence_list in sentence_mapping.items():
        joined = " ".join(sentence_list)
        prompt, negative_prompt = generate_prompt(joined, sentence_mapping, character_dict, selected_style)
        prompts.append((para_number, prompt))
        print(f"Generated prompt for paragraph {para_number}: {prompt}")
    # Second pass: render each prompt in order.
    images = {number: inference(text) for number, text in prompts}
    print("Prompt processing complete. Generated images: ", images)
    return images
# Input widgets for the API: two JSON payloads plus a fixed style choice.
_interface_inputs = [
    gr.JSON(label="Sentence Mapping"),
    gr.JSON(label="Character Dict"),
    gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
]

# JSON out: paragraph number -> base64-encoded PNG (see process_prompt).
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=_interface_inputs,
    outputs="json",
).queue(default_concurrency_limit=20)  # Set concurrency limit if needed
# Script entry point: start the Gradio server (blocks until shut down).
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()