Spaces:
Sleeping
Sleeping
File size: 4,179 Bytes
d6d2c2a bd94aa2 5cf7dee d6d2c2a bd94aa2 d6d2c2a ee55332 bd94aa2 d6d2c2a ee55332 d6d2c2a bd94aa2 4fd0935 bd94aa2 d6d2c2a 5cf7dee d6d2c2a bd94aa2 845eacd d6d2c2a 845eacd 60d26fc 5cf7dee 60d26fc bd94aa2 845eacd d6d2c2a bd94aa2 d6d2c2a bd94aa2 60d26fc bd94aa2 845eacd bd94aa2 60d26fc bd94aa2 d6d2c2a bd94aa2 d6d2c2a bd94aa2 d6d2c2a bd94aa2 60d26fc bd94aa2 845eacd bd94aa2 60d26fc bd94aa2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
from fastapi import FastAPI, Request, UploadFile, Form, File
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
from starlette.middleware.cors import CORSMiddleware
from PIL import Image
from io import BytesIO
from diffusers import (
AutoPipelineForText2Image,
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
)
from transformers import CLIPFeatureExtractor
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the diffusion pipelines at startup and expose them as lifespan state.

    A CLIP feature extractor and the Stable Diffusion safety checker are
    loaded first and attached to a ``stabilityai/sd-turbo`` text-to-image
    pipeline. The image-to-image pipeline is derived from it with
    ``from_pipe``, and the inpainting pipeline is derived from the
    image-to-image one. Every pipeline is moved to the CPU.

    Yields:
        A dict with the keys ``"text2img"``, ``"img2img"`` and ``"inpaint"``;
        the route handlers in this file read them as ``request.state.<key>``.
    """
    device = "cpu"

    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )

    text2img = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sd-turbo",
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    text2img = text2img.to(device)
    img2img = AutoPipelineForImage2Image.from_pipe(text2img).to(device)
    inpaint = AutoPipelineForInpainting.from_pipe(img2img).to(device)

    yield dict(text2img=text2img, img2img=img2img, inpaint=inpaint)

    # Shutdown: drop this function's references to the models, most
    # recently created first.
    del inpaint, img2img, text2img, safety_checker, feature_extractor
# Application instance; model loading and teardown are handled by ``lifespan``.
app = FastAPI(lifespan=lifespan)

# CORS is fully open: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy -- confirm it is intended for this deployment.
origins = ["*"]
_cors_settings = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.get("/")
async def root():
    """Return a fixed greeting payload for ``GET /``."""
    greeting = {"Hello": "World"}
    return greeting
@app.post("/text-to-image/")
async def text_to_image(
    request: Request,
    prompt: str = Form(...),
    num_inference_steps: int = Form(1),
):
    """Generate a PNG image from a text prompt.

    Args:
        request: Incoming request; ``request.state.text2img`` is called as
            the text-to-image pipeline.
        prompt: Text prompt forwarded to the pipeline.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``; must be at least 1.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``. When the
        pipeline flags the first result as NSFW, a black 512x512 image is
        sent instead of the generated one.

    Raises:
        HTTPException: 422 when ``num_inference_steps`` is less than 1.
    """
    # Imported locally so the handler does not rely on the module's
    # top-level import list being extended.
    from fastapi import HTTPException

    # Reject unusable step counts up front instead of letting the pipeline
    # fail with an unhandled error.
    if num_inference_steps < 1:
        raise HTTPException(
            status_code=422,
            detail="num_inference_steps must be at least 1",
        )

    results = request.state.text2img(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0.0,
    )
    if results.nsfw_content_detected[0]:
        # Flagged output is never returned; a black placeholder is sent.
        image = Image.new("RGB", (512, 512), "black")
    else:
        image = results.images[0]

    # Named ``buffer`` rather than ``bytes`` so the builtin is not shadowed.
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")
@app.post("/image-to-image/")
async def image_to_image(
    request: Request,
    prompt: str = Form(...),
    init_image: UploadFile = File(...),
    num_inference_steps: int = Form(2),
    strength: float = Form(1.0),
):
    """Transform an uploaded image according to a text prompt.

    The upload is decoded, converted to RGB and resized to 512x512 before it
    is passed to ``request.state.img2img``. The response image is sized to
    the dimensions of the original upload.

    Args:
        request: Incoming request; ``request.state.img2img`` is called as the
            image-to-image pipeline.
        prompt: Text prompt forwarded to the pipeline.
        init_image: Uploaded source image.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``.
        strength: Forwarded to the pipeline as ``strength``; must lie in
            ``[0, 1]`` and ``int(num_inference_steps * strength)`` must be at
            least 1.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``. When the
        pipeline flags the first result as NSFW, a black image with the
        upload's dimensions is sent instead of the generated one.

    Raises:
        HTTPException: 422 for an unusable ``strength`` /
            ``num_inference_steps`` combination, 400 when the upload cannot
            be decoded as an image.
    """
    # Imported locally so the handler does not rely on the module's
    # top-level import list being extended.
    from fastapi import HTTPException

    # A combination that leaves no denoising step makes the pipeline fail;
    # report it as a client error instead. The range check also rejects NaN.
    if not 0.0 <= strength <= 1.0 or int(num_inference_steps * strength) < 1:
        raise HTTPException(
            status_code=422,
            detail=(
                "strength must be within [0, 1] and "
                "num_inference_steps * strength must be at least 1"
            ),
        )

    init_bytes = await init_image.read()
    try:
        source = Image.open(BytesIO(init_bytes))
        init_width, init_height = source.size
        # Pixel data is decoded lazily, so corrupt files may only fail here.
        source = source.convert("RGB").resize((512, 512))
    except OSError as exc:
        raise HTTPException(
            status_code=400,
            detail="init_image is not a readable image",
        ) from exc

    results = request.state.img2img(
        prompt,
        image=source,
        num_inference_steps=num_inference_steps,
        strength=strength,
        guidance_scale=0.0,
    )
    if results.nsfw_content_detected[0]:
        # Placeholder uses the upload's dimensions so both outcomes return
        # an image of the same size.
        image = Image.new("RGB", (init_width, init_height), "black")
    else:
        image = results.images[0].resize((init_width, init_height))

    # Named ``buffer`` rather than ``bytes`` so the builtin is not shadowed.
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")
@app.post("/inpainting/")
async def inpainting(
    request: Request,
    prompt: str = Form(...),
    init_image: UploadFile = File(...),
    mask_image: UploadFile = File(...),
    num_inference_steps: int = Form(2),
    strength: float = Form(1.0),
):
    """Inpaint an uploaded image using an uploaded mask and a text prompt.

    Both uploads are decoded, converted to RGB and resized to 512x512 before
    they are passed to ``request.state.inpaint``. The response image is sized
    to the dimensions of the original ``init_image`` upload.

    Args:
        request: Incoming request; ``request.state.inpaint`` is called as the
            inpainting pipeline.
        prompt: Text prompt forwarded to the pipeline.
        init_image: Uploaded source image.
        mask_image: Uploaded mask image, forwarded as ``mask_image``.
        num_inference_steps: Forwarded to the pipeline as
            ``num_inference_steps``.
        strength: Forwarded to the pipeline as ``strength``; must lie in
            ``[0, 1]`` and ``int(num_inference_steps * strength)`` must be at
            least 1.

    Returns:
        A ``StreamingResponse`` with media type ``image/png``. When the
        pipeline flags the first result as NSFW, a black image with the
        source upload's dimensions is sent instead of the generated one.

    Raises:
        HTTPException: 422 for an unusable ``strength`` /
            ``num_inference_steps`` combination, 400 when either upload
            cannot be decoded as an image.
    """
    # Imported locally so the handler does not rely on the module's
    # top-level import list being extended.
    from fastapi import HTTPException

    def _open_rgb_512(raw, field):
        """Decode *raw* into a 512x512 RGB image plus its original size."""
        try:
            picture = Image.open(BytesIO(raw))
            original_size = picture.size
            # Pixel data is decoded lazily, so corrupt files may only fail
            # during the conversion.
            return picture.convert("RGB").resize((512, 512)), original_size
        except OSError as exc:
            raise HTTPException(
                status_code=400,
                detail=f"{field} is not a readable image",
            ) from exc

    # A combination that leaves no denoising step makes the pipeline fail;
    # report it as a client error instead. The range check also rejects NaN.
    if not 0.0 <= strength <= 1.0 or int(num_inference_steps * strength) < 1:
        raise HTTPException(
            status_code=422,
            detail=(
                "strength must be within [0, 1] and "
                "num_inference_steps * strength must be at least 1"
            ),
        )

    init_bytes = await init_image.read()
    source, (init_width, init_height) = _open_rgb_512(init_bytes, "init_image")

    mask_bytes = await mask_image.read()
    mask, _ = _open_rgb_512(mask_bytes, "mask_image")

    results = request.state.inpaint(
        prompt,
        image=source,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        strength=strength,
        guidance_scale=0.0,
    )
    if results.nsfw_content_detected[0]:
        # Placeholder uses the upload's dimensions so both outcomes return
        # an image of the same size.
        image = Image.new("RGB", (init_width, init_height), "black")
    else:
        image = results.images[0].resize((init_width, init_height))

    # Named ``buffer`` rather than ``bytes`` so the builtin is not shadowed.
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")
|