import os
import time
import random
import string
import datetime
import logging
from collections import defaultdict

import numpy as np
import torch
import boto3
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, constr, conint
from diffusers import (FluxPipeline, FluxControlNetPipeline,
                       FluxControlNetModel, FluxControlNetInpaintPipeline,
                       FluxImg2ImgPipeline, FluxInpaintPipeline,
                       CogVideoXImageToVideoPipeline)
from diffusers.utils import load_image
from PIL import Image


logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[
                        logging.FileHandler("error.txt"),
                        logging.StreamHandler()
                    ])


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


MAX_SEED = np.iinfo(np.int32).max

# S3 configuration. Prefer environment variables over hardcoding credentials;
# the string literals below are placeholders, not working values.
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID", "your-access-key-id")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", "your-secret-access-key")
AWS_REGION = os.getenv("AWS_REGION", "your-region")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "your-bucket-name")

s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION
)


def log_requests(user_key: str, prompt: str):
    """Append a timestamped record of a rate-limited request."""
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_entry = f"{timestamp}, {user_key}, {prompt}\n"
    with open("key_requests.txt", "a") as log_file:
        log_file.write(log_entry)


def upload_image_to_s3(image_path: str, s3_path: str):
    """Upload a local file to S3 and return its public URL."""
    try:
        s3_client.upload_file(image_path, S3_BUCKET_NAME, s3_path)
        return f"https://{S3_BUCKET_NAME}.s3.{AWS_REGION}.amazonaws.com/{s3_path}"
    except Exception as e:
        logging.error(f"Error uploading image to S3: {e}")
        raise HTTPException(status_code=500, detail=f"Image upload failed: {str(e)}")


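# Illustrative alternative (not part of the original flow): if the bucket is
# private, a time-limited presigned URL avoids exposing objects publicly. The
# function name is hypothetical; it assumes the credentials above carry
# s3:GetObject permission.
def generate_presigned_image_url(s3_path: str, expires_in: int = 3600) -> str:
    return s3_client.generate_presigned_url(
        "get_object",
        Params={"Bucket": S3_BUCKET_NAME, "Key": s3_path},
        ExpiresIn=expires_in,
    )

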
def generate_random_sequence():
    """Return a random filename stem, e.g. '123456789012_abcdefghijk'."""
    random_numbers = ''.join(random.choices(string.digits, k=12))
    random_words = ''.join(random.choices(string.ascii_lowercase, k=11))
    return f"{random_numbers}_{random_words}"


# Load the generation pipelines once at startup. A load failure is fatal, so
# raise a plain RuntimeError rather than an HTTPException (there is no request
# context at import time).
try:
    flux_pipe = FluxPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
    flux_pipe.enable_model_cpu_offload()
    logging.info("FluxPipeline loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load FluxPipeline: {e}")
    raise RuntimeError(f"Failed to load the model: {str(e)}")

try:
    img_pipe = FluxImg2ImgPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
    img_pipe.enable_model_cpu_offload()
    logging.info("FluxImg2ImgPipeline loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load FluxImg2ImgPipeline: {e}")
    raise RuntimeError(f"Failed to load the model: {str(e)}")

try:
    inpainting_pipe = FluxInpaintPipeline.from_pretrained("pranavajay/flow", torch_dtype=torch.bfloat16)
    inpainting_pipe.enable_model_cpu_offload()
    logging.info("FluxInpaintPipeline loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load FluxInpaintPipeline: {e}")
    raise RuntimeError(f"Failed to load the model: {str(e)}")

try:
    video = CogVideoXImageToVideoPipeline.from_pretrained(
        "THUDM/CogVideoX-5b-I2V",
        torch_dtype=torch.bfloat16
    )
    video.enable_sequential_cpu_offload()
    video.vae.enable_tiling()
    video.vae.enable_slicing()
    logging.info("CogVideoXImageToVideoPipeline loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load CogVideoXImageToVideoPipeline: {e}")
    raise RuntimeError(f"Failed to load the model: {str(e)}")


# Built lazily by set_controlnet_adapter().
flux_controlnet_pipe = None

# Simple in-memory rate limiting: at most RATE_LIMIT requests per user_key
# within a sliding TIME_WINDOW (seconds).
request_timestamps = defaultdict(list)
RATE_LIMIT = 30
TIME_WINDOW = 5


style_lora_mapping = {
    "Uncensored": {"path": "enhanceaiteam/Flux-uncensored", "triggered_word": "nsfw"},
    "Logo": {"path": "Shakker-Labs/FLUX.1-dev-LoRA-Logo-Design", "triggered_word": "logo"},
    "Yarn": {"path": "Shakker-Labs/FLUX.1-dev-LoRA-MiaoKa-Yarn-World", "triggered_word": "mkym this is made of wool"},
    "Anime": {"path": "prithivMLmods/Canopus-LoRA-Flux-Anime", "triggered_word": "anime"},
    "Comic": {"path": "wkplhc/comic", "triggered_word": "comic"}
}

adapter_controlnet_mapping = {
    "Canny": "InstantX/FLUX.1-dev-controlnet-canny",
    "Depth": "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
    "Pose": "Shakker-Labs/FLUX.1-dev-ControlNet-Pose",
    "Upscale": "jasperai/Flux.1-dev-Controlnet-Upscaler"
}


class GenerateImageRequest(BaseModel):
    prompt: constr(min_length=1)
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: str = "https://enhanceai.s3.amazonaws.com/792e2322-77fe-4070-aac4-7fa8d9e29c11_1.png"
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    style: str = None
    adapter: str = None
    user_key: str


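# Example /text_to_image/ request body (illustrative values only):
# {
#     "prompt": "a lighthouse on a cliff at dawn",
#     "num_inference_steps": 28,
#     "randomize_seed": true,
#     "style": "Anime",
#     "user_key": "demo-key"
# }

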
def log_request(key: str, query: str):
    with open("key.txt", "a") as f:
        f.write(f"{datetime.datetime.now()} - Key: {key} - Query: {query}\n")


def apply_lora_style(pipe, style, prompt):
    """Load the LoRA weights for `style` into `pipe` and return the prompt
    with the style's trigger word prepended. Callers must use the return
    value; the prompt is not modified in place."""
    if style in style_lora_mapping:
        lora_path = style_lora_mapping[style]["path"]
        triggered_word = style_lora_mapping[style]["triggered_word"]
        pipe.load_lora_weights(lora_path)
        return f"{triggered_word} {prompt}"
    return prompt


def set_controlnet_adapter(adapter: str, is_inpainting: bool = False):
    """
    Set the ControlNet adapter for the pipeline.

    Parameters:
        adapter (str): The key identifying which ControlNet adapter to load.
        is_inpainting (bool, optional): Whether to build the inpainting
            pipeline. Defaults to False.

    Raises:
        ValueError: If the adapter is not found in adapter_controlnet_mapping.
    """
    global flux_controlnet_pipe

    if adapter not in adapter_controlnet_mapping:
        raise ValueError(f"Invalid ControlNet adapter: {adapter}")

    controlnet_model_path = adapter_controlnet_mapping[adapter]
    controlnet = FluxControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)

    pipeline_cls = FluxControlNetInpaintPipeline if is_inpainting else FluxControlNetPipeline
    flux_controlnet_pipe = pipeline_cls.from_pretrained(
        "pranavajay/flow", controlnet=controlnet, torch_dtype=torch.bfloat16
    )
    flux_controlnet_pipe.to("cuda")

    logging.info(f"ControlNet adapter '{adapter}' loaded successfully.")


def rate_limit(user_key: str):
    """Return True if the request is allowed, False if the user has exceeded
    RATE_LIMIT requests within the sliding TIME_WINDOW."""
    current_time = time.time()

    # Drop timestamps that have aged out of the window.
    request_timestamps[user_key] = [t for t in request_timestamps[user_key] if current_time - t < TIME_WINDOW]

    if len(request_timestamps[user_key]) >= RATE_LIMIT:
        logging.info(f"Rate limit exceeded for user_key: {user_key}")
        return False

    request_timestamps[user_key].append(current_time)
    return True


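# Note: the limiter above is in-memory and per-process. Under multiple
# uvicorn/gunicorn workers each process keeps its own counters, so a shared
# store (e.g. Redis) would be needed to enforce a global limit.

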
@app.post("/text_to_image/")
async def generate_image(req: GenerateImageRequest):
    seed = req.seed
    if not rate_limit(req.user_key):
        log_requests(req.user_key, req.prompt)
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    retries = 3

    for attempt in range(retries):
        try:
            if not req.prompt or req.prompt.strip() == "":
                raise ValueError("Prompt cannot be empty.")

            original_prompt = req.prompt
            modified_prompt = original_prompt

            if req.adapter:
                try:
                    set_controlnet_adapter(req.adapter)
                except Exception as e:
                    logging.error(f"Error setting ControlNet adapter: {e}")
                    raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}")
                modified_prompt = apply_lora_style(flux_controlnet_pipe, req.style, req.prompt)

                try:
                    control_image = load_image(req.control_image_url)
                except Exception as e:
                    logging.error(f"Error loading control image from URL: {e}")
                    raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.")

                try:
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    images = flux_controlnet_pipe(
                        prompt=modified_prompt,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        control_image=control_image,
                        generator=generator,
                        controlnet_conditioning_scale=req.controlnet_conditioning_scale
                    ).images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images with ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation with ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
            else:
                try:
                    modified_prompt = apply_lora_style(flux_pipe, req.style, req.prompt)
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    images = flux_pipe(
                        prompt=modified_prompt,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        generator=generator
                    ).images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images without ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation without ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")

            # Upload each generated image to S3, then remove the local copy.
            os.makedirs("generated_images", exist_ok=True)
            image_urls = []
            for img in images:
                image_path = f"generated_images/{generate_random_sequence()}.png"
                img.save(image_path)
                image_urls.append(upload_image_to_s3(image_path, image_path))
                os.remove(image_path)

            return {"status": "success", "output": image_urls, "prompt": original_prompt,
                    "height": req.height, "width": req.width, "scale": req.guidance_scale,
                    "step": req.num_inference_steps, "style": req.style, "adapter": req.adapter}

        except HTTPException:
            # Client and setup errors carry their own status codes; don't retry them.
            raise
        except Exception as e:
            logging.error(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retries - 1:
                raise HTTPException(status_code=500, detail=f"Failed to generate image after multiple attempts: {str(e)}")
            continue
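

# Example call (assumes the server is running locally on port 8000, e.g. via
# `uvicorn app:app`; the module name is hypothetical):
# curl -X POST http://localhost:8000/text_to_image/ \
#      -H "Content-Type: application/json" \
#      -d '{"prompt": "a red fox in snow", "user_key": "demo-key"}'

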
class GenerateImageToImageRequest(BaseModel):
    prompt: str = None
    image: str = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    strength: float = 0.7
    guidance_scale: float = 7.5
    seed: conint(ge=0, le=MAX_SEED) = 42
    randomize_seed: bool = False
    height: conint(gt=0) = 768
    width: conint(gt=0) = 1360
    control_image_url: str = None
    controlnet_conditioning_scale: float = 0.6
    num_inference_steps: conint(gt=0) = 50
    num_images_per_prompt: conint(gt=0, le=5) = 1
    style: str = None
    adapter: str = None
    user_key: str


@app.post("/image_to_image/")
async def generate_image_to_image(req: GenerateImageToImageRequest):
    seed = req.seed
    original_prompt = req.prompt
    modified_prompt = original_prompt

    if not rate_limit(req.user_key):
        log_requests(req.user_key, req.prompt if req.prompt else "No prompt")
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    retries = 3

    for attempt in range(retries):
        try:
            if not req.prompt or req.prompt.strip() == "":
                raise ValueError("Prompt cannot be empty.")

            original_prompt = req.prompt

            if req.adapter:
                try:
                    set_controlnet_adapter(req.adapter)
                except Exception as e:
                    logging.error(f"Error setting ControlNet adapter: {e}")
                    raise HTTPException(status_code=400, detail=f"Failed to load ControlNet adapter: {str(e)}")
                modified_prompt = apply_lora_style(flux_controlnet_pipe, req.style, req.prompt)

                try:
                    control_image = load_image(req.control_image_url)
                except Exception as e:
                    logging.error(f"Error loading control image from URL: {e}")
                    raise HTTPException(status_code=400, detail="Invalid control image URL or image could not be loaded.")

                try:
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)

                    images = flux_controlnet_pipe(
                        prompt=modified_prompt,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        control_image=control_image,
                        generator=generator,
                        controlnet_conditioning_scale=req.controlnet_conditioning_scale
                    ).images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images with ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation with ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")
            else:
                try:
                    modified_prompt = apply_lora_style(img_pipe, req.style, req.prompt)
                    if req.randomize_seed:
                        seed = random.randint(0, MAX_SEED)
                    generator = torch.Generator().manual_seed(seed)
                    source = load_image(req.image)

                    images = img_pipe(
                        prompt=modified_prompt,
                        image=source,
                        strength=req.strength,
                        guidance_scale=req.guidance_scale,
                        height=req.height,
                        width=req.width,
                        num_inference_steps=req.num_inference_steps,
                        num_images_per_prompt=req.num_images_per_prompt,
                        generator=generator
                    ).images
                except torch.cuda.OutOfMemoryError:
                    logging.error("GPU out of memory error while generating images without ControlNet.")
                    raise HTTPException(status_code=500, detail="GPU overload occurred while generating images. Try reducing the resolution or number of steps.")
                except Exception as e:
                    logging.error(f"Error during image generation without ControlNet: {e}")
                    raise HTTPException(status_code=500, detail=f"Error during image generation: {str(e)}")

            # Upload each generated image to S3, then remove the local copy.
            os.makedirs("generated_images", exist_ok=True)
            image_urls = []
            for img in images:
                image_path = f"generated_images/{generate_random_sequence()}.png"
                img.save(image_path)
                image_urls.append(upload_image_to_s3(image_path, image_path))
                os.remove(image_path)

            return {"status": "success", "output": image_urls, "prompt": original_prompt,
                    "height": req.height, "width": req.width, "image": req.image,
                    "strength": req.strength, "scale": req.guidance_scale,
                    "step": req.num_inference_steps, "style": req.style, "adapter": req.adapter}

        except HTTPException:
            # Client and setup errors carry their own status codes; don't retry them.
            raise
        except Exception as e:
            logging.error(f"Attempt {attempt + 1} failed: {e}")
            if attempt == retries - 1:
                raise HTTPException(status_code=500, detail=f"Failed to generate image after multiple attempts: {str(e)}")
            continue