"""FastAPI service that generates images with FLUX.1-dev and uploads them to S3."""

import logging
import os
import random
import threading
import time
import uuid
from io import BytesIO

import boto3
import numpy as np
import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

# S3 configuration.
# Secrets must never live in source control: the access key that used to be
# hard-coded here must be treated as leaked and rotated. When the credential
# variables are unset, passing None lets boto3 fall back to its default
# credential chain (AWS_* env vars, shared config files, instance/role creds).
S3_BUCKET = os.environ.get("S3_BUCKET", "afri")
S3_REGION = os.environ.get("S3_REGION", "eu-west-3")
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")

# Set up S3 client
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time and shared by every request.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype
).to(device)
# The /infer endpoint runs in FastAPI's threadpool; this lock keeps calls into
# the single shared pipeline serialized so concurrent requests do not run
# several generations on the same model at once.
_pipe_lock = threading.Lock()

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

app = FastAPI()


class InferenceRequest(BaseModel):
    """Request body for ``POST /infer``."""

    prompt: str
    seed: int = 42
    # When true, ``seed`` is ignored and a random seed in [0, MAX_SEED] is used.
    randomize_seed: bool = False
    # Bounded so invalid sizes are rejected with a 422 before reaching the model.
    width: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    height: int = Field(1024, gt=0, le=MAX_IMAGE_SIZE)
    guidance_scale: float = 5.0
    num_inference_steps: int = Field(28, gt=0)


def save_image_to_s3(image):
    """Upload *image* to ``S3_BUCKET`` as a PNG and return its object URL.

    The object key combines a Unix timestamp with a random UUID so that two
    images generated within the same second do not overwrite each other.
    """
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    png_bytes = buffer.getvalue()

    filename = f"generated_image_{int(time.time())}_{uuid.uuid4().hex}.png"
    s3_client.put_object(
        Bucket=S3_BUCKET,
        Key=filename,
        Body=png_bytes,
        ContentType="image/png",
    )
    # NOTE(review): no ACL is set on the uploaded object, so this URL is only
    # readable if the bucket policy grants public read access — confirm, or
    # return s3_client.generate_presigned_url(...) instead.
    return f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com/{filename}"


@app.post("/infer")
def infer(request: InferenceRequest):
    """Generate one image for ``request.prompt`` and upload it to S3.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool; the GPU inference and the boto3 upload are both blocking and
    would otherwise stall the event loop for every other request.

    Returns:
        ``{"image_url": <S3 URL>, "seed": <seed actually used>}``.

    Raises:
        HTTPException: status 500 if generation or upload fails.
    """
    seed = random.randint(0, MAX_SEED) if request.randomize_seed else request.seed
    generator = torch.Generator().manual_seed(seed)

    try:
        with _pipe_lock:
            image = pipe(
                prompt=request.prompt,
                width=request.width,
                height=request.height,
                num_inference_steps=request.num_inference_steps,
                generator=generator,
                guidance_scale=request.guidance_scale,
            ).images[0]
        image_url = save_image_to_s3(image)
    except Exception as e:
        # Top-level request boundary: log the full traceback server-side.
        logger.exception("Image generation failed (seed=%s)", seed)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"image_url": image_url, "seed": seed}


@app.get("/")
async def root():
    """Simple liveness/welcome endpoint."""
    return {"message": "Welcome to the IG API"}