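# Page-header residue from the page this file was copied from
# ("Spaces: Running"); kept only as a comment so the module parses.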
import base64
import io

import replicate
import requests
from PIL import Image
# Output (width, height) in pixels for each supported aspect ratio.
# Unknown ratios fall back to 1:1.
_ASPECT_RATIO_DIMENSIONS = {
    "1:1": (512, 512),
    "16:9": (912, 512),
    "3:2": (768, 512),
    "2:3": (512, 768),
    "4:5": (512, 640),
    "5:4": (640, 512),
}

# Pinned Replicate model version (SDXL) used for every generation.
_SDXL_MODEL_VERSION = (
    "stability-ai/sdxl:"
    "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
)

_DEFAULT_NEGATIVE_PROMPT = "ugly, blurry, low quality, distorted, deformed"

# Seconds to wait for the generated image download. Without a timeout,
# requests.get can block indefinitely on a stalled connection.
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _first_output_url(output):
    """Return the URL of the first item in a Replicate ``run`` result, or None if empty."""
    if not output:
        return None
    if isinstance(output, str):
        # A single URL string rather than a list of URLs.
        return output
    # next(iter(...)) works for lists as well as iterator/generator outputs,
    # where len() / indexing would raise.
    first = next(iter(output), None)
    if first is None:
        return None
    # Older replicate clients yield plain URL strings. Newer ones may yield
    # file-output objects whose str() is the URL -- confirm for the installed
    # client version.
    return str(first)


def _download_image(image_url):
    """Download *image_url* and decode it with PIL; return ``(image, status message)``."""
    response = requests.get(image_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    if response.status_code != 200:
        return None, f"Failed to download image: {response.status_code}"
    image = Image.open(io.BytesIO(response.content))
    # Image.open only reads the header. Decode now so a corrupt payload is
    # reported here instead of failing later in the caller.
    image.load()
    return image, "Success"


def generate_image(
    prompt,
    num_steps=30,
    guidance_scale=7.5,
    aspect_ratio="1:1",
    replicate_api_key=None,
    lora_url=None,
    negative_prompt=None,
):
    """
    Generate an image with the SDXL model via the Replicate API.

    Args:
        prompt (str): The text prompt for image generation.
        num_steps (int): Number of inference steps.
        guidance_scale (float): Guidance scale for generation.
        aspect_ratio (str): Desired aspect ratio ("1:1", "16:9", "3:2", "2:3",
            "4:5", "5:4"). Unknown values fall back to 512x512.
        replicate_api_key (str): API key for Replicate. Required.
        lora_url (str, optional): URL to LoRA weights, sent as ``lora_urls``.
        negative_prompt (str, optional): Negative prompt. When empty, a
            default quality-oriented negative prompt is used.

    Returns:
        tuple: ``(PIL.Image.Image, "Success")`` on success, otherwise
        ``(None, <error message>)``. Exceptions are not raised; they are
        reported through the message.
    """
    if not replicate_api_key:
        return None, "Please provide a Replicate API key"

    try:
        width, height = _ASPECT_RATIO_DIMENSIONS.get(aspect_ratio, (512, 512))
        model_params = {
            "prompt": prompt,
            "negative_prompt": negative_prompt or _DEFAULT_NEGATIVE_PROMPT,
            "num_inference_steps": num_steps,
            "guidance_scale": guidance_scale,
            "width": width,
            "height": height,
            "scheduler": "DPMSolverMultistep",  # other schedulers can be tried
            "num_outputs": 1,
        }
        if lora_url:
            # NOTE(review): confirm the pinned SDXL version accepts `lora_urls`.
            model_params["lora_urls"] = lora_url

        # A per-call client keeps the key scoped to this request.
        client = replicate.Client(api_token=replicate_api_key)
        output = client.run(_SDXL_MODEL_VERSION, input=model_params)

        image_url = _first_output_url(output)
        if image_url is None:
            return None, "No image generated"
        return _download_image(image_url)
    except Exception as e:
        # UI boundary: surface every failure as a message rather than raising.
        return None, f"Error generating image: {str(e)}"
def encode_image_to_base64(image):
    """Serialize a PIL image as a base64-encoded PNG string.

    Args:
        image: The value to encode. Only ``PIL.Image.Image`` instances are
            encoded.

    Returns:
        str | None: Base64 text of the PNG bytes, or None when *image* is not
        a PIL image.
    """
    if not isinstance(image, Image.Image):
        return None
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')