import base64
import os
import time
from io import BytesIO

import modal
import requests
from fastapi import FastAPI, HTTPException
from huggingface_hub import login
from pydantic import BaseModel
from safetensors.torch import load_file

# Standalone demo: Stable Diffusion v1.5 on Modal. It gets its own App object
# so it does not collide with the Flux API app defined further below.
sd_demo_app = modal.App()


@sd_demo_app.function(
    image=modal.Image.debian_slim().pip_install(
        "torch", "diffusers[torch]", "transformers", "ftfy"
    ),
    secrets=[modal.Secret.from_name("huggingface-token")],
    gpu="any",
)
def run_stable_diffusion(prompt: str):
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        # Assumes the "huggingface-token" secret exposes the token as HF_TOKEN;
        # adjust the key to match your secret.
        token=os.environ["HF_TOKEN"],
    ).to("cuda")

    image = pipe(prompt, num_inference_steps=10).images[0]

    buf = BytesIO()
    image.save(buf, format="PNG")
    img_bytes = buf.getvalue()

    return img_bytes


@sd_demo_app.local_entrypoint()
def main():
    img_bytes = run_stable_diffusion.remote("Wu-Tang Clan climbing Mount Everest")
    with open("/tmp/output.png", "wb") as f:
        f.write(img_bytes)

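# The demo above can be run on its own with `modal run`, pointing at this
# file's `main` entrypoint (exact command depends on your filename, e.g.
# `modal run flux_api.py::main` if the file is saved as flux_api.py).
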

cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

cuda_dev_image = modal.Image.from_registry(
    f"nvidia/cuda:{tag}", add_python="3.11"
).entrypoint([])  # clear the CUDA base image's default entrypoint

diffusers_commit_sha = "81cf3b2f155f1de322079af28f625349ee21ec6b"

flux_image = (
    cuda_dev_image.apt_install(
        "git",
        "libglib2.0-0",
        "libsm6",
        "libxrender1",
        "libxext6",
        "ffmpeg",
        "libgl1",
    )
    .pip_install(
        "invisible_watermark==0.2.0",
        "peft==0.10.0",
        "transformers==4.44.0",
        "huggingface_hub[hf_transfer]==0.26.2",
        "accelerate==0.33.0",
        "safetensors==0.4.4",
        "sentencepiece==0.2.0",
        "torch==2.5.0",
        f"git+https://github.com/huggingface/diffusers.git@{diffusers_commit_sha}",
        "numpy<2",
        "fastapi==0.104.1",
        "uvicorn==0.24.0",
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": "/cache"})
)

flux_image = flux_image.env(
    {
        "TORCHINDUCTOR_CACHE_DIR": "/root/.inductor-cache",
        "TORCHINDUCTOR_FX_GRAPH_CACHE": "1",
    }
)

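# The HF_HUB_CACHE and TORCHINDUCTOR_CACHE_DIR paths set above are backed by
# Modal Volumes in the Model class below, so model weights and torch.compile
# artifacts persist across container restarts.
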
app = modal.App(
    "flux-api-server",
    image=flux_image,
    secrets=[modal.Secret.from_name("huggingface-token")],
)

with flux_image.imports():
    import torch
    from diffusers import FluxPipeline

MINUTES = 60
VARIANT = "dev"
NUM_INFERENCE_STEPS = 50


class ImageRequest(BaseModel):
    prompt: str
    num_inference_steps: int = 50
    width: int = 1024
    height: int = 1024


class ImageResponse(BaseModel):
    image_base64: str
    generation_time: float

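# Example /generate request body (values are illustrative):
#   {"prompt": "a doodle poster of a cat", "num_inference_steps": 28,
#    "width": 1024, "height": 1024}
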

@app.cls(
    gpu="H200",
    scaledown_window=20 * MINUTES,
    timeout=60 * MINUTES,
    volumes={
        "/cache": modal.Volume.from_name("hf-hub-cache", create_if_missing=True),
        "/root/.nv": modal.Volume.from_name("nv-cache", create_if_missing=True),
        "/root/.triton": modal.Volume.from_name("triton-cache", create_if_missing=True),
        "/root/.inductor-cache": modal.Volume.from_name(
            "inductor-cache", create_if_missing=True
        ),
    },
)
class Model:
    compile: bool = modal.parameter(default=False)

    lora_loaded = False
    lora_path = "/cache/flux.1_lora_flyway_doodle-poster.safetensors"
    lora_url = "https://huggingface.co/RajputVansh/SG161222-DISTILLED-IITI-VANSH-RUHELA/resolve/main/flux.1_lora_flyway_doodle-poster.safetensors?download=true"

    def download_lora_from_url(self, url, save_path):
        """Download LoRA with proper error handling"""
        try:
            print(f"📥 Downloading LoRA from {url}")
            response = requests.get(url, timeout=300)
            response.raise_for_status()

            with open(save_path, "wb") as f:
                f.write(response.content)

            print(f"✅ LoRA downloaded successfully to {save_path}")
            print(f"📊 File size: {len(response.content)} bytes")
            return True
        except Exception as e:
            print(f"❌ LoRA download failed: {str(e)}")
            return False

    def verify_lora_file(self, lora_path):
        """Verify that the LoRA file is valid"""
        try:
            if not os.path.exists(lora_path):
                return False, "File does not exist"

            file_size = os.path.getsize(lora_path)
            if file_size == 0:
                return False, "File is empty"

            try:
                load_file(lora_path)
                return True, f"Valid LoRA file ({file_size} bytes)"
            except Exception as e:
                return False, f"Invalid LoRA file: {str(e)}"

        except Exception as e:
            return False, f"Error verifying file: {str(e)}"

    @modal.enter()
    def enter(self):
        # Log in to Hugging Face. Assumes the "huggingface-token" secret
        # exposes the token as HF_TOKEN; adjust the key to match your secret.
        login(os.environ["HF_TOKEN"])

        if not os.path.exists(self.lora_path):
            print("📥 LoRA not found, downloading...")
            download_success = self.download_lora_from_url(self.lora_url, self.lora_path)
            if not download_success:
                print("❌ Failed to download LoRA, continuing without it")
                self.lora_loaded = False
        else:
            print("📁 LoRA file found in cache")

        is_valid, message = self.verify_lora_file(self.lora_path)
        print(f"🔍 LoRA verification: {message}")

        print("🚀 Loading Flux model...")
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")

        if is_valid:
            try:
                print(f"🔄 Loading LoRA from {self.lora_path}")
                pipe.load_lora_weights(self.lora_path)
                print("✅ LoRA successfully loaded!")
                self.lora_loaded = True

                print("🧪 Testing LoRA integration...")

            except Exception as e:
                print(f"❌ LoRA loading failed: {str(e)}")
                self.lora_loaded = False
        else:
            print("⚠️ LoRA not loaded due to verification failure")
            self.lora_loaded = False

        self.pipe = optimize(pipe, compile=self.compile)

        print(f"🎯 Model ready! LoRA status: {'✅ Loaded' if self.lora_loaded else '❌ Not loaded'}")

    @modal.method()
    def get_model_status(self) -> dict:
        """Get detailed model and LoRA status"""
        lora_file_info = {}
        if os.path.exists(self.lora_path):
            try:
                file_size = os.path.getsize(self.lora_path)
                lora_file_info = {
                    "exists": True,
                    "size_bytes": file_size,
                    "size_mb": round(file_size / (1024 * 1024), 2),
                }
            except OSError:
                lora_file_info = {"exists": False}
        else:
            lora_file_info = {"exists": False}

        return {
            "status": "ready",
            "lora_loaded": self.lora_loaded,
            "lora_path": self.lora_path,
            "model_info": {
                "base_model": "black-forest-labs/FLUX.1-dev",
                "lora_file": lora_file_info,
                "lora_url": self.lora_url,
            },
        }

    @modal.method()
    def inference(self, prompt: str, num_inference_steps: int = 50, width: int = 1024, height: int = 1024) -> dict:
        final_prompt = prompt

        print("🎨 Generating image:")
        print(f"   Original prompt: {prompt}")
        print(f"   Final prompt: {final_prompt}")
        print(f"   Dimensions: {width}x{height}")
        print(f"   LoRA status: {'✅ Active' if self.lora_loaded else '❌ Inactive'}")

        start_time = time.time()

        out = self.pipe(
            final_prompt,
            output_type="pil",
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            max_sequence_length=512,
        ).images[0]

        byte_stream = BytesIO()
        out.save(byte_stream, format="PNG")
        image_bytes = byte_stream.getvalue()
        image_base64 = base64.b64encode(image_bytes).decode("utf-8")

        generation_time = time.time() - start_time
        print(f"✅ Generated image in {generation_time:.2f} seconds")

        return {
            "image_base64": image_base64,
            "generation_time": generation_time,
            "final_prompt": final_prompt,
            "lora_used": self.lora_loaded,
        }


fastapi_app = FastAPI(title="Flux Image Generation API")

model_instance = Model(compile=False)

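# Pass compile=True above to enable the torch.compile path in optimize();
# container start-up will then spend extra time compiling (see the dummy
# run at the end of optimize()).
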

@fastapi_app.post("/generate", response_model=ImageResponse)
async def generate_image(request: ImageRequest):
    try:
        print(f"Received request: {request.prompt} at {request.width}x{request.height}")
        result = model_instance.inference.remote(
            request.prompt,
            request.num_inference_steps,
            request.width,
            request.height,
        )
        return ImageResponse(**result)
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@fastapi_app.get("/health")
async def health_check():
    return {"status": "healthy", "message": "Flux API server is running"}


@app.function(
    image=flux_image,  # fastapi and uvicorn are already pinned in flux_image
    min_containers=1,  # formerly keep_warm; a single warm container is assumed
)
@modal.asgi_app()
def fastapi_server():
    return fastapi_app

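# Example client call once deployed with `modal deploy` (the URL is a
# placeholder; use the one Modal prints for fastapi_server):
#
#   import requests
#   resp = requests.post(
#       "https://<workspace>--flux-api-server-fastapi-server.modal.run/generate",
#       json={"prompt": "a doodle poster of a cat"},
#   )
#   print(resp.json()["generation_time"])
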

def optimize(pipe, compile=True):
    # fuse the attention projection matrices for a modest speedup
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    # switch to channels-last memory layout
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    if not compile:
        return pipe

    # tune the torch inductor for more aggressive optimization
    config = torch._inductor.config
    config.disable_progress = False
    config.conv_1x1_as_mm = True
    config.coordinate_descent_tuning = True
    config.coordinate_descent_check_all_directions = True
    config.epilogue_fusion = False

    pipe.transformer = torch.compile(
        pipe.transformer, mode="max-autotune", fullgraph=True
    )
    pipe.vae.decode = torch.compile(
        pipe.vae.decode, mode="max-autotune", fullgraph=True
    )

    # trigger compilation with a throwaway generation
    print("🔦 Running torch compilation (may take up to 20 minutes)...")
    pipe(
        "dummy prompt to trigger torch compilation",
        output_type="pil",
        num_inference_steps=NUM_INFERENCE_STEPS,
    ).images[0]
    print("🔦 Finished torch compilation")

    return pipe


if __name__ == "__main__":
    print("Starting Modal Flux API server...")
    # Local smoke test: run a single inference through the Model class.
    # The prompt below is only an example.
    with app.run():
        result = model_instance.inference.remote(
            "a doodle-style poster of a mountain landscape",
            NUM_INFERENCE_STEPS,
            1024,
            1024,
        )
        print(f"Generated image in {result['generation_time']:.2f} seconds")