Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, Request, Form, File, UploadFile | |
from fastapi.responses import StreamingResponse | |
from contextlib import asynccontextmanager | |
from starlette.middleware.cors import CORSMiddleware | |
import torch | |
from PIL import Image | |
from io import BytesIO | |
from diffusers import ( | |
AutoPipelineForText2Image, | |
AutoPipelineForImage2Image, | |
AutoPipelineForInpainting, | |
) | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SDXL-Turbo pipelines once at startup and expose them as lifespan state.

    The dict yielded below becomes the application's lifespan state; request
    handlers read it as ``request.state.text2img``, ``request.state.img2img``
    and ``request.state.inpaint``.

    Args:
        app: The FastAPI application (required by the lifespan signature; unused).
    """
    # Load the fp16 weight files but run in float32: these pipelines are moved to
    # the CPU, and PyTorch lacks float16 kernels for many CPU ops (diffusers warns
    # that float16 pipelines on "cpu" will fail at inference time).
    text2img = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo", torch_dtype=torch.float32, variant="fp16"
    ).to("cpu")
    # Build the image-to-image and inpainting pipelines from the components
    # already loaded for text2img instead of downloading/loading them again.
    img2img = AutoPipelineForImage2Image.from_pipe(text2img).to("cpu")
    inpaint = AutoPipelineForInpainting.from_pipe(text2img).to("cpu")
    try:
        yield {"text2img": text2img, "img2img": img2img, "inpaint": inpaint}
    finally:
        # Drop the local references on shutdown, even if shutdown is triggered by an error.
        del text2img
        del img2img
        del inpaint
# Application instance; pipelines are loaded/unloaded by the lifespan handler.
app = FastAPI(lifespan=lifespan)

# Cross-origin policy: any origin, any method, any header.
# NOTE(review): the FastAPI CORS docs state allow_origins=["*"] must not be
# combined with allow_credentials=True; Starlette then echoes the caller's
# Origin back, granting credentialed access to every site. No endpoint in this
# file reads cookies or auth headers, so allow_credentials=False is probably
# what is wanted — confirm no frontend sends credentialed requests first.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
async def root():
    """Return a constant greeting payload."""
    # NOTE(review): no route decorator is visible for this handler in this copy
    # of the file (presumably something like @app.get("/")) — confirm it is registered.
    greeting = {"Hello": "World"}
    return greeting
async def text_to_image(request: Request, prompt: str = Form(...)):
    """Generate a PNG image from a text prompt with the text-to-image pipeline.

    Args:
        request: Incoming request; its state must carry the ``text2img``
            pipeline yielded by the app's lifespan.
        prompt: Text prompt, sent as a form field.

    Returns:
        A ``StreamingResponse`` streaming the generated image as PNG.
    """
    # NOTE(review): no route decorator is visible for this handler — confirm it is registered.
    # NOTE(review): the pipeline call below is synchronous inside an async handler,
    # so it blocks the event loop for its whole duration. Moving it to a thread
    # would need serialization, since the pipelines share components built via from_pipe.
    result = request.state.text2img(
        prompt=prompt, num_inference_steps=1, guidance_scale=0.0
    )
    image = result.images[0]

    # Encode into an in-memory buffer and rewind it so the response streams from the start.
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")
async def image_to_image(
    request: Request, prompt: str = Form(...), init_image: UploadFile = File(...)
):
    """Transform an uploaded image under a text prompt with the image-to-image pipeline.

    Args:
        request: Incoming request; its state must carry the ``img2img``
            pipeline yielded by the app's lifespan.
        prompt: Text prompt, sent as a form field.
        init_image: Uploaded source image; converted to RGB and resized to
            512x512 before generation.

    Returns:
        A ``StreamingResponse`` streaming the generated image as PNG.
    """
    # NOTE(review): no route decorator is visible for this handler — confirm it is registered.
    upload_data = await init_image.read()
    source = Image.open(BytesIO(upload_data)).convert("RGB").resize((512, 512))

    # Diffusers pipelines are invoked by calling the pipeline object itself;
    # they have no ``.pipe`` attribute (accessing one raises AttributeError).
    # NOTE(review): this synchronous call blocks the event loop while it runs.
    result = request.state.img2img(
        prompt,
        image=source,
        num_inference_steps=2,
        strength=0.5,
        guidance_scale=0.0,
    )
    image = result.images[0]

    # Encode into an in-memory buffer and rewind it so the response streams from the start.
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")
async def inpainting(
    request: Request,
    prompt: str = Form(...),
    init_image: UploadFile = File(...),
    mask_image: UploadFile = File(...),
):
    """Inpaint an uploaded image, guided by a mask and a text prompt.

    Args:
        request: Incoming request; its state must carry the ``inpaint``
            pipeline yielded by the app's lifespan.
        prompt: Text prompt, sent as a form field.
        init_image: Uploaded source image; converted to RGB and resized to 512x512.
        mask_image: Uploaded mask image; converted to RGB and resized to 512x512.

    Returns:
        A ``StreamingResponse`` streaming the generated image as PNG.
    """
    # NOTE(review): no route decorator is visible for this handler — confirm it is registered.
    source_data = await init_image.read()
    source = Image.open(BytesIO(source_data)).convert("RGB").resize((512, 512))

    mask_data = await mask_image.read()
    mask = Image.open(BytesIO(mask_data)).convert("RGB").resize((512, 512))

    # Diffusers pipelines are invoked by calling the pipeline object itself;
    # they have no ``.pipe`` attribute (accessing one raises AttributeError).
    # NOTE(review): this synchronous call blocks the event loop while it runs.
    result = request.state.inpaint(
        prompt,
        image=source,
        mask_image=mask,
        num_inference_steps=3,
        strength=0.5,
        guidance_scale=0.0,
    )
    image = result.images[0]

    # Encode into an in-memory buffer and rewind it so the response streams from the start.
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")