"""Gradio service that illustrates story paragraphs with SDXL-Turbo images.

Each paragraph in an incoming sentence mapping is turned into a prompt via
``generate_prompt`` and rendered by the ``stabilityai/sdxl-turbo`` pipeline.
Results are returned as base64-encoded JPEG strings keyed by paragraph.
"""

import base64
import os
import threading
from io import BytesIO

import gradio as gr
from diffusers import AutoPipelineForText2Image
from PIL import Image

from generate_prompts import generate_prompt

CONCURRENCY_LIMIT = 10

# The pipeline is expensive to load, so it is loaded once and shared by every
# request instead of being reloaded for each generated image.
_model = None
# Guards the one-time load so concurrent first requests don't each load a copy.
_model_load_lock = threading.Lock()
# Serializes pipeline calls: the shared pipeline's scheduler keeps per-call
# step state, so concurrent calls on one instance could interfere.
_model_call_lock = threading.Lock()


def load_model():
    """Load the SDXL-Turbo text-to-image pipeline.

    Returns:
        The loaded pipeline, or ``None`` if loading failed (the error is printed).
    """
    print("Loading the Stable Diffusion model...")
    try:
        model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
        print("Model loaded successfully.")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


def _get_model():
    """Return the shared pipeline, loading it on first use.

    A failed load (``None``) is not cached, so a later call will retry.
    """
    global _model
    if _model is None:
        with _model_load_lock:
            # Re-check: another thread may have finished loading while we waited.
            if _model is None:
                _model = load_model()
    return _model


def generate_image(prompt):
    """Render ``prompt`` with SDXL-Turbo and return it as a base64 JPEG string.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        ``(img_str, None)`` on success, where ``img_str`` is the base64-encoded
        JPEG; ``(None, error_message)`` on any failure.
    """
    try:
        model = _get_model()
        if model is None:
            raise ValueError("Model not loaded properly.")
        print(f"Generating image with prompt: {prompt}")
        # Per the SDXL-Turbo model card: one step, guidance disabled (0.0).
        with _model_call_lock:
            output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")
        if output is None:
            raise ValueError("Model returned None")
        images = getattr(output, "images", None)
        if not images:
            print("No images found in model output")
            raise ValueError("No images found in model output")
        print("Image generated successfully")
        image = images[0]
        # JPEG cannot store alpha or palette modes; normalize to RGB first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        with BytesIO() as buffered:
            image.save(buffered, format="JPEG")
            image_bytes = buffered.getvalue()
        img_str = base64.b64encode(image_bytes).decode("utf-8")
        print("Image encoded to base64")
        print(f'img_str: {img_str[:100]}...')  # Print a snippet of the base64 string
        return img_str, None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)


def inference(sentence_mapping, character_dict, selected_style):
    """Generate one illustration per paragraph.

    Args:
        sentence_mapping: JSON object mapping paragraph identifiers to either a
            list of sentences or a single string.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Style name passed through to ``generate_prompt``.

    Returns:
        A dict mapping each paragraph identifier to a base64 JPEG string, or to
        ``"Error: <message>"`` if that paragraph failed; ``{"error": message}``
        if the request as a whole fails.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}, type: {type(sentence_mapping)}")
        print(f"Received character_dict: {character_dict}, type: {type(character_dict)}")
        print(f"Received selected_style: {selected_style}, type: {type(selected_style)}")
        if not isinstance(sentence_mapping, dict):
            raise ValueError(
                "sentence_mapping must be a JSON object mapping paragraphs to "
                f"sentences, got {type(sentence_mapping).__name__}"
            )
        images = {}
        for paragraph_number, sentences in sentence_mapping.items():
            # A bare string would otherwise be joined character by character.
            if isinstance(sentences, str):
                combined_sentence = sentences
            else:
                combined_sentence = " ".join(str(s) for s in sentences)
            prompt = generate_prompt(combined_sentence, character_dict, selected_style)
            print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
            img_str, error = generate_image(prompt)
            if error:
                images[paragraph_number] = f"Error: {error}"
            else:
                images[paragraph_number] = img_str
        return images
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return {"error": str(e)}


gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style"),
    ],
    outputs="json",
    concurrency_limit=CONCURRENCY_LIMIT,
)

if __name__ == "__main__":
    # Warm the model cache so the first request doesn't pay the load cost.
    _get_model()
    print("Launching Gradio interface...")
    gradio_interface.launch()