Abhlash committed on
Commit
cb9b654
1 Parent(s): bca57e9

Create image_generation.py

Browse files
Files changed (1) hide show
  1. image_generation.py +92 -0
image_generation.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import replicate
2
+ from PIL import Image
3
+ import io
4
+ import requests
5
+ import base64
6
+
7
def generate_image(
    prompt,
    num_steps=30,
    guidance_scale=7.5,
    aspect_ratio="1:1",
    replicate_api_key=None,
    lora_url=None,
    negative_prompt=None,
    download_timeout=60,
):
    """
    Generate an image using Stable Diffusion (SDXL) via the Replicate API.

    Args:
        prompt (str): The text prompt for image generation
        num_steps (int): Number of inference steps
        guidance_scale (float): Guidance scale for generation
        aspect_ratio (str): Desired aspect ratio ("1:1", "16:9", "3:2",
            "2:3", "4:5" or "5:4"); any other value falls back to 512x512
        replicate_api_key (str): API key for Replicate
        lora_url (str, optional): URL to LoRA weights
        negative_prompt (str, optional): Negative prompt for generation;
            a generic quality-oriented default is used when omitted
        download_timeout (float, optional): Seconds to wait for the generated
            image to download before giving up

    Returns:
        tuple: ``(image, message)``. On success ``image`` is a PIL image and
        ``message`` is "Success"; on any failure ``image`` is None and
        ``message`` describes the problem. The function does not raise.
    """
    if not replicate_api_key:
        return None, "Please provide a Replicate API key"

    try:
        # Map the requested aspect ratio to pixel dimensions.
        aspect_ratios = {
            "1:1": (512, 512),
            "16:9": (912, 512),
            "3:2": (768, 512),
            "2:3": (512, 768),
            "4:5": (512, 640),
            "5:4": (640, 512),
        }
        width, height = aspect_ratios.get(aspect_ratio, (512, 512))

        # Configure model parameters
        model_params = {
            "prompt": prompt,
            "negative_prompt": negative_prompt or "ugly, blurry, low quality, distorted, deformed",
            "num_inference_steps": num_steps,
            "guidance_scale": guidance_scale,
            "width": width,
            "height": height,
            "scheduler": "DPMSolverMultistep",  # You can experiment with different schedulers
            "num_outputs": 1,
        }

        # Add LoRA if specified
        if lora_url:
            model_params["lora_urls"] = lora_url

        # A per-call client keeps the API key out of global/environment state.
        client = replicate.Client(api_token=replicate_api_key)

        # Run the model (SDXL, pinned to a specific version hash).
        output = client.run(
            "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
            input=model_params,
        )

        if not output:
            return None, "No image generated"

        # Only one output was requested, so the first entry is the image.
        image_url = output[0]

        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(image_url, timeout=download_timeout)
        if response.status_code != 200:
            return None, f"Failed to download image: {response.status_code}"

        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt or truncated data is
        # reported through this function's error message, not in the caller.
        image.load()
        return image, "Success"

    except Exception as e:
        # Boundary handler: callers rely on the (None, message) contract.
        return None, f"Error generating image: {str(e)}"
85
+
86
def encode_image_to_base64(image):
    """Serialize a PIL image as PNG and return it as a base64 text string.

    Returns None when ``image`` is not a PIL ``Image.Image`` instance.
    """
    if not isinstance(image, Image.Image):
        return None

    # Render the PNG into memory rather than touching the filesystem.
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')