pranavajay committed on
Commit
5322ffd
1 Parent(s): d091be2

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +443 -0
main.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import datetime
import logging
import os
import random
import string
import time
from collections import defaultdict

import boto3
import numpy as np
import torch
from diffusers import (FluxPipeline, FluxControlNetPipeline,
                       FluxControlNetModel, FluxControlNetInpaintPipeline,
                       FluxImg2ImgPipeline, FluxInpaintPipeline,
                       CogVideoXImageToVideoPipeline)
from diffusers.utils import load_image
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel, constr, conint
17
+
18
# Setup logging: INFO and above goes both to error.txt and to the console.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[
                        logging.FileHandler("error.txt"),
                        logging.StreamHandler()
                    ])

app = FastAPI()

# Allow CORS for specific origins if needed
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with specific domains as necessary
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Largest accepted generation seed (2**31 - 1); also the upper bound for
# randomize_seed draws in the endpoints.
MAX_SEED = np.iinfo(np.int32).max
38
+
39
# AWS S3 Configuration.
# SECURITY FIX: credentials were hard-coded in source. They are now read from
# the environment; the original placeholder strings remain as defaults so
# behavior is unchanged when the variables are unset.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID", "your-access-key-id")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", "your-secret-access-key")
AWS_REGION = os.environ.get("AWS_REGION", "your-region")
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "your-bucket-name")

# Initialize S3 client (shared by upload_image_to_s3)
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)
52
+
53
def log_requests(user_key: str, prompt: str):
    """Append a timestamped "(timestamp), (user_key), (prompt)" record to
    key_requests.txt. Called when a request is rejected by the rate limiter.
    """
    # BUG FIX: the module does `import datetime`, so the class is
    # datetime.datetime — the original `datetime.now()` raised AttributeError.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_entry = f"{timestamp}, {user_key}, {prompt}\n"
    with open("key_requests.txt", "a") as log_file:
        log_file.write(log_entry)
58
+
59
# Function to upload image to S3
def upload_image_to_s3(image_path: str, s3_path: str):
    """Upload a local file to the configured bucket and return its public URL.

    Raises:
        HTTPException: 500 when the upload fails for any reason.
    """
    try:
        s3_client.upload_file(image_path, S3_BUCKET_NAME, s3_path)
    except Exception as e:
        logging.error(f"Error uploading image to S3: {e}")
        raise HTTPException(status_code=500, detail=f"Image upload failed: {str(e)}")
    return f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{s3_path}"
67
+
68
# Generate a random sequence of 12 numbers and 11 words
def generate_random_sequence():
    """Return a "<12 digits>_<11 lowercase letters>" token, used to build
    unique local filenames / S3 keys for generated images."""
    digit_part = ''.join(random.choices(string.digits, k=12))
    letter_part = ''.join(random.choices(string.ascii_lowercase, k=11))
    return '_'.join((digit_part, letter_part))
73
+
74
# Load each pipeline once at import time so all requests share the models.
# NOTE(review): raising HTTPException at module import time never reaches a
# client (the app hasn't started); kept as-is to preserve existing behavior,
# but RuntimeError would be clearer — confirm before changing.
try:
    flux_pipe = FluxPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
    flux_pipe.enable_model_cpu_offload()  # keep GPU memory low between calls
    logging.info("FluxPipeline loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load FluxPipeline: {e}")
    raise HTTPException(status_code=500, detail=f"Failed to load the model: {str(e)}")

try:
    img_pipe = FluxImg2ImgPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
    img_pipe.enable_model_cpu_offload()
    logging.info("FluxImg2ImgPipeline loaded successfully.")
except Exception as e:
    # BUG FIX: this handler previously logged "Failed to load FluxPipeline",
    # copy-pasted from the block above, which made failures misleading.
    logging.error(f"Failed to load FluxImg2ImgPipeline: {e}")
    raise HTTPException(status_code=500, detail=f"Failed to load the model: {str(e)}")

try:
    inpainting_pipe = FluxInpaintPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
    inpainting_pipe.enable_model_cpu_offload()
    logging.info("FluxInpaintPipeline loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load FluxInpaintPipeline: {e}")
    raise HTTPException(status_code=500, detail=f"Failed to load the model: {str(e)}")

try:
    video = CogVideoXImageToVideoPipeline.from_pretrained(
        "THUDM/CogVideoX-5b-I2V",
        torch_dtype=torch.bfloat16
    )
    # Aggressive memory savers for the (large) video pipeline.
    video.enable_sequential_cpu_offload()
    video.vae.enable_tiling()
    video.vae.enable_slicing()
    logging.info("CogVideoXImageToVideoPipeline loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load CogVideoXImageToVideoPipeline: {e}")
    raise HTTPException(status_code=500, detail=f"Failed to load the model: {str(e)}")


# The ControlNet pipeline is built lazily by set_controlnet_adapter().
flux_controlnet_pipe = None
116
+
117
+
118
+
119
# Rate limiting variables (consumed by rate_limit())
request_timestamps = defaultdict(list)  # Store timestamps of requests per user key
RATE_LIMIT = 30  # Maximum requests allowed per key within TIME_WINDOW
TIME_WINDOW = 5  # Time window in seconds
123
+
124
# Available LoRA styles and ControlNet adapters.
# Each style maps to a LoRA weight repo plus the trigger word that
# apply_lora_style() prepends to the prompt.
style_lora_mapping = {
    "Uncensored": {"path": "enhanceaiteam/Flux-uncensored", "triggered_word": "nsfw"},
    "Logo": {"path": "Shakker-Labs/FLUX.1-dev-LoRA-Logo-Design", "triggered_word": "logo"},
    "Yarn": {"path": "Shakker-Labs/FLUX.1-dev-LoRA-MiaoKa-Yarn-World", "triggered_word": "mkym this is made of wool"},
    "Anime": {"path": "prithivMLmods/Canopus-LoRA-Flux-Anime", "triggered_word": "anime"},
    "Comic": {"path": "wkplhc/comic", "triggered_word": "comic"}
}

# Adapter name -> ControlNet model repo, consumed by set_controlnet_adapter().
adapter_controlnet_mapping = {
    "Canny": "InstantX/FLUX.1-dev-controlnet-canny",
    "Depth": "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
    "Pose": "Shakker-Labs/FLUX.1-dev-ControlNet-Pose",
    "Upscale": "jasperai/Flux.1-dev-Controlnet-Upscaler"
}
139
+
140
# Request model for query parameters
class GenerateImageRequest(BaseModel):
    """Payload for POST /text_to_image/."""
    prompt: constr(min_length=1)  # Ensures prompt is not empty
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False  # if True, a fresh random seed replaces `seed`
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: str = "https://enhanceai.s3.amazonaws.com/792e2322-77fe-4070-aac4-7fa8d9e29c11_1.png"
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1  # Limit to max 5 images per request
    style: str = None  # Optional LoRA style; key into style_lora_mapping
    adapter: str = None  # Optional ControlNet adapter; key into adapter_controlnet_mapping
    user_key: str  # API user key
155
+
156
def log_request(key: str, query: str):
    """Append a timestamped key/query record to key.txt."""
    entry = f"{datetime.datetime.now()} - Key: {key} - Query: {query}\n"
    with open("key.txt", "a") as logfile:
        logfile.write(entry)
159
+
160
def apply_lora_style(pipe, style, prompt):
    """Load the LoRA weights for *style* onto *pipe* and return the prompt
    prefixed with the style's trigger word; unknown/None styles return the
    prompt unchanged without touching the pipeline."""
    entry = style_lora_mapping.get(style)
    if entry is None:
        return prompt
    pipe.load_lora_weights(entry["path"])
    return f"{entry['triggered_word']} {prompt}"
168
+
169
def set_controlnet_adapter(adapter: str, is_inpainting: bool = False):
    """
    Set the ControlNet adapter for the pipeline.

    Loads the requested ControlNet model and rebuilds the module-level
    `flux_controlnet_pipe` around it (side effect: replaces the global).

    Parameters:
        adapter (str): The key to identify which ControlNet adapter to load.
        is_inpainting (bool, optional): Whether to use the inpainting pipeline. Defaults to False.

    Raises:
        ValueError: If the adapter is not found in the adapter_controlnet_mapping.
    """
    global flux_controlnet_pipe

    # Check if the adapter is valid
    if adapter not in adapter_controlnet_mapping:
        raise ValueError(f"Invalid ControlNet adapter: {adapter}")

    # Get the ControlNet model path based on the adapter
    controlnet_model_path = adapter_controlnet_mapping[adapter]

    # Load the ControlNet model with the specified torch_dtype
    controlnet = FluxControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)

    # Select the appropriate pipeline (inpainting or standard).
    # NOTE(review): FluxControlNetInpaintPipeline must be present in the
    # diffusers import at the top of the file — verify it is imported, or
    # the is_inpainting=True path raises NameError.
    pipeline_cls = FluxControlNetInpaintPipeline if is_inpainting else FluxControlNetPipeline

    # Load the pipeline
    flux_controlnet_pipe = pipeline_cls.from_pretrained(
        "pranavajay/flow", controlnet=controlnet, torch_dtype=torch.bfloat16
    )

    # Move the pipeline to the GPU.
    # NOTE(review): the other pipelines use enable_model_cpu_offload();
    # putting this one fully on CUDA may spike GPU memory — confirm intended.
    flux_controlnet_pipe.to("cuda")

    logging.info(f"ControlNet adapter '{adapter}' loaded successfully.")
204
+
205
+
206
+
207
+
208
+
209
+
210
+
211
def rate_limit(user_key: str):
    """Sliding-window rate limiter: allow at most RATE_LIMIT requests per
    user_key within the last TIME_WINDOW seconds.

    Returns True (and records the request) when allowed, False otherwise.
    """
    now = time.time()

    # Keep only timestamps still inside the window.
    recent = [ts for ts in request_timestamps[user_key] if now - ts < TIME_WINDOW]
    request_timestamps[user_key] = recent

    if len(recent) >= RATE_LIMIT:
        logging.info(f"Rate limit exceeded for user_key: {user_key}")
        return False

    # Record the new request timestamp (recent is the stored list object).
    recent.append(now)
    return True
225
+
226
@app.post("/text_to_image/")
async def generate_image(req: GenerateImageRequest):
    """Generate images from text (optionally ControlNet-guided), upload them
    to S3, and return their public URLs.

    Raises:
        HTTPException: 429 when rate-limited; 400 for a bad adapter or
        control image; 500 on generation/upload failure.
    """
    seed = req.seed
    if not rate_limit(req.user_key):
        log_requests(req.user_key, req.prompt)  # Log the rejected request
        # BUG FIX: the original only logged and fell through, so rate-limited
        # users were still served. Now mirrors /image_to_image/ and rejects.
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    retries = 3  # Number of retries for transient errors

    for attempt in range(retries):
        try:
            # constr(min_length=1) already blocks ""; this also rejects
            # whitespace-only prompts.
            if not req.prompt or req.prompt.strip() == "":
                raise ValueError("Prompt cannot be empty.")

            original_prompt = req.prompt  # Save the original prompt

            # Resolve the seed once, up front, so it can be reported back.
            if req.randomize_seed:
                seed = random.randint(0, MAX_SEED)
            generator = torch.Generator().manual_seed(seed)

            if req.adapter:
                # --- ControlNet-guided generation ---
                try:
                    set_controlnet_adapter(req.adapter)
                except Exception as e:
                    logging.error(f"Error setting ControlNet adapter: {e}")
                    raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}")

                # BUG FIX: the return value was discarded, so the LoRA trigger
                # word never reached the pipeline and `modified_prompt` below
                # was an undefined name (NameError at request time).
                modified_prompt = apply_lora_style(flux_controlnet_pipe, req.style, req.prompt)

                # Load control image
                try:
                    control_image = load_image(req.control_image_url)
                except Exception as e:
                    logging.error(f"Error loading control image from URL: {e}")
                    raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.")

                try:
                    images = flux_controlnet_pipe(
                        prompt=modified_prompt,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        control_image=control_image,
                        generator=generator,
                        controlnet_conditioning_scale=req.controlnet_conditioning_scale
                    ).images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images with ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation with ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
            else:
                # --- Plain text-to-image generation ---
                try:
                    modified_prompt = apply_lora_style(flux_pipe, req.style, req.prompt)
                    images = flux_pipe(
                        prompt=modified_prompt,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        generator=generator
                    ).images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images without ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation without ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")

            # Saving images and uploading to S3.
            # BUG FIX: the output directory was never created, so img.save()
            # failed on a fresh deployment.
            os.makedirs("generated_images", exist_ok=True)
            image_urls = []
            for img in images:
                image_path = f"generated_images/{generate_random_sequence()}.png"
                img.save(image_path)
                image_urls.append(upload_image_to_s3(image_path, image_path))
                os.remove(image_path)  # Clean up local files after upload

            # BUG FIX: the original response referenced undefined names
            # (`step`, `req.sytle`), misspelled the "style" key, and returned
            # only the last URL. It now returns every uploaded URL plus the
            # seed actually used.
            return {
                "status": "success",
                "output": image_urls,
                "prompt": original_prompt,
                "seed": seed,
                "height": req.height,
                "width": req.width,
                "scale": req.guidance_scale,
                "step": req.num_inference_steps,
                "style": req.style,
                "adapter": req.adapter,
            }

        except HTTPException:
            # BUG FIX: deliberate 4xx/5xx responses were previously caught by
            # the generic handler below and pointlessly retried.
            raise
        except Exception as e:
            logging.error(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retries - 1:  # Last attempt
                raise HTTPException(status_code=500, detail=f"Failed to generate image after multiple attempts: {str(e)}")
            continue  # Retry on transient errors
322
+
323
+
324
# Image-to-Image request model
class GenerateImageToImageRequest(BaseModel):
    """Payload for POST /image_to_image/."""
    prompt: str = None  # Prompt can be None
    image: str = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    strength: float = 0.7  # img2img denoising strength (0..1)
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False  # if True, a fresh random seed replaces `seed`
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: str = None  # Optional ControlNet image
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    style: str = None  # Optional LoRA style; key into style_lora_mapping
    adapter: str = None  # Optional ControlNet adapter; key into adapter_controlnet_mapping
    user_key: str  # API user key
341
+
342
+ @app.post("/image_to_image/")
343
+ async def generate_image_to_image(req: GenerateImageToImageRequest):
344
+ seed = req.seed
345
+ original_prompt = req.prompt
346
+ modified_prompt = original_prompt
347
+
348
+ # Check if user is exceeding rate limit
349
+ if not rate_limit(req.user_key):
350
+ log_requests(req.user_key, req.prompt if req.prompt else "No prompt")
351
+ raise HTTPException(status_code=429, detail="Rate limit exceeded")
352
+
353
+ retries = 3 # Number of retries for transient errors
354
+
355
+ for attempt in range(retries):
356
+ try:
357
+ # Check if prompt is None or empty
358
+ if not req.prompt or req.prompt.strip() == "":
359
+ raise ValueError("Prompt cannot be empty.")
360
+
361
+ original_prompt = req.prompt # Save the original prompt
362
+
363
+ # Set ControlNet if adapter is provided
364
+ if req.adapter:
365
+ try:
366
+ set_controlnet_adapter(req.adapter)
367
+ except Exception as e:
368
+ logging.error(f"Error setting ControlNet adapter: {e}")
369
+ raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}")
370
+ apply_lora_style(flux_controlnet_pipe, req.style, req.prompt)
371
+
372
+
373
+ # Load control image
374
+ try:
375
+ control_image = load_image(req.control_image_url)
376
+ except Exception as e:
377
+ logging.error(f"Error loading control image from URL: {e}")
378
+ raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.")
379
+
380
+ # Image generation with ControlNet
381
+ try:
382
+ if req.randomize_seed:
383
+ seed = random.randint(0, MAX_SEED)
384
+ generator = torch.Generator().manual_seed(seed)
385
+
386
+ images = flux_controlnet_pipe(
387
+ prompt=modified_prompt,
388
+ guidance_scale=req.guidance_scale,
389
+ height=req.height,
390
+ width=req.width,
391
+ num_inference_steps=req.num_inference_steps,
392
+ num_images_per_prompt=req.num_images_per_prompt,
393
+ control_image=control_image,
394
+ generator=generator,
395
+ controlnet_conditioning_scale=req.controlnet_conditioning_scale
396
+ ).images
397
+ except torch.cuda.OutOfMemoryError:
398
+ logging.error("GPU out of memory error while generating images with ControlNet.")
399
+ raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
400
+ except Exception as e:
401
+ logging.error(f"Error during image generation with ControlNet: {e}")
402
+ raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
403
+ else:
404
+ # Image generation without ControlNet
405
+ try:
406
+ apply_lora_style(img_pipe, req.style, req.prompt)
407
+ if req.randomize_seed:
408
+ seed = random.randint(0, MAX_SEED)
409
+ generator = torch.Generator().manual_seed(seed)
410
+ source = load_image(req.image)
411
+ images = img_pipe(
412
+ prompt=modified_prompt,
413
+ image=source,
414
+ strength=req.strength,
415
+ guidance_scale=req.guidance_scale,
416
+ height=req.height,
417
+ width=req.width,
418
+ num_inference_steps=req.num_inference_steps,
419
+ num_images_per_prompt=req.num_images_per_prompt,
420
+ generator=generator
421
+ ).images
422
+ except torch.cuda.OutOfMemoryError:
423
+ logging.error("GPU out of memory error while generating images without ControlNet.")
424
+ raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
425
+ except Exception as e:
426
+ logging.error(f"Error during image generation without ControlNet: {e}")
427
+ raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
428
+
429
+ # Saving images and uploading to S3
430
+ image_urls = []
431
+ for i, img in enumerate(images):
432
+ image_path = f"generated_images/{generate_random_sequence()}.png"
433
+ img.save(image_path)
434
+ image_url = upload_image_to_s3(image_path, image_path)
435
+ image_urls.append(image_url)
436
+ os.remove(image_path) # Clean up local files after upload
437
+
438
+ return {"status": "success", "output": image_url, "prompt": original_prompt, "height": req.height, "width": width, "image": req.image, "strength": req.strength, "scale": req.guidance_scale, "step": step, "sytle": req.sytle, "adapter": req.adapter}
439
+
440
+ except Exception as e:
441
+ logging.error(f"Attempt {attempt + 1} failed: {e}")
442
+ if attempt == retries - 1: # Last attempt
443
+ raise HTTPException(status_code=500, detail=f"Failed to generate image after m