ahmedmbutt commited on
Commit
d6d2c2a
1 Parent(s): 3f083f0

added safety checker

Browse files
Files changed (1) hide show
  1. main.py +36 -15
main.py CHANGED
@@ -1,32 +1,47 @@
1
- from fastapi import FastAPI, Request, Form, File, UploadFile
2
  from fastapi.responses import StreamingResponse
3
  from contextlib import asynccontextmanager
4
  from starlette.middleware.cors import CORSMiddleware
5
 
6
  from PIL import Image
7
  from io import BytesIO
 
8
  from diffusers import (
9
  AutoPipelineForText2Image,
10
  AutoPipelineForImage2Image,
11
  AutoPipelineForInpainting,
12
  )
 
13
 
14
 
15
  @asynccontextmanager
16
  async def lifespan(app: FastAPI):
17
- text2img = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo").to(
18
- "cpu"
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
21
  img2img = AutoPipelineForImage2Image.from_pipe(text2img).to("cpu")
22
 
23
  inpaint = AutoPipelineForInpainting.from_pipe(img2img).to("cpu")
24
 
25
  yield {"text2img": text2img, "img2img": img2img, "inpaint": inpaint}
26
 
27
- del text2img
28
- del img2img
29
  del inpaint
 
 
 
 
 
30
 
31
 
32
  app = FastAPI(lifespan=lifespan)
@@ -49,7 +64,9 @@ async def root():
49
 
50
  @app.post("/text-to-image/")
51
  async def text_to_image(
52
- request: Request, prompt: str = Form(...), num_inference_steps: int = Form(1)
 
 
53
  ):
54
  image = request.state.text2img(
55
  prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=0.0
@@ -67,10 +84,11 @@ async def image_to_image(
67
  prompt: str = Form(...),
68
  init_image: UploadFile = File(...),
69
  num_inference_steps: int = Form(2),
70
- strength: float = Form(0.5),
71
  ):
72
- bytes = await init_image.read()
73
- init_image = Image.open(BytesIO(bytes))
 
74
  init_image = init_image.convert("RGB").resize((512, 512))
75
 
76
  image = request.state.img2img(
@@ -80,6 +98,7 @@ async def image_to_image(
80
  strength=strength,
81
  guidance_scale=0.0,
82
  ).images[0]
 
83
 
84
  bytes = BytesIO()
85
  image.save(bytes, "PNG")
@@ -93,14 +112,15 @@ async def inpainting(
93
  prompt: str = Form(...),
94
  init_image: UploadFile = File(...),
95
  mask_image: UploadFile = File(...),
96
- num_inference_steps: int = Form(3),
97
- strength: float = Form(0.5),
98
  ):
99
- bytes = await init_image.read()
100
- init_image = Image.open(BytesIO(bytes))
 
101
  init_image = init_image.convert("RGB").resize((512, 512))
102
- bytes = await mask_image.read()
103
- mask_image = Image.open(BytesIO(bytes))
104
  mask_image = mask_image.convert("RGB").resize((512, 512))
105
 
106
  image = request.state.inpaint(
@@ -111,6 +131,7 @@ async def inpainting(
111
  strength=strength,
112
  guidance_scale=0.0,
113
  ).images[0]
 
114
 
115
  bytes = BytesIO()
116
  image.save(bytes, "PNG")
 
1
+ from fastapi import FastAPI, Request, UploadFile, Form, File
2
  from fastapi.responses import StreamingResponse
3
  from contextlib import asynccontextmanager
4
  from starlette.middleware.cors import CORSMiddleware
5
 
6
  from PIL import Image
7
  from io import BytesIO
8
+ from transformers import CLIPFeatureExtractor
9
  from diffusers import (
10
  AutoPipelineForText2Image,
11
  AutoPipelineForImage2Image,
12
  AutoPipelineForInpainting,
13
  )
14
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
15
 
16
 
17
  @asynccontextmanager
18
  async def lifespan(app: FastAPI):
19
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(
20
+ "openai/clip-vit-base-patch32"
21
  )
22
 
23
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
24
+ "CompVis/stable-diffusion-safety-checker"
25
+ )
26
+
27
+ text2img = AutoPipelineForText2Image.from_pretrained(
28
+ "stabilityai/sd-turbo",
29
+ safety_checker=safety_checker,
30
+ feature_extractor=feature_extractor,
31
+ ).to("cpu")
32
+
33
  img2img = AutoPipelineForImage2Image.from_pipe(text2img).to("cpu")
34
 
35
  inpaint = AutoPipelineForInpainting.from_pipe(img2img).to("cpu")
36
 
37
  yield {"text2img": text2img, "img2img": img2img, "inpaint": inpaint}
38
 
 
 
39
  del inpaint
40
+ del img2img
41
+ del text2img
42
+
43
+ del safety_checker
44
+ del feature_extractor
45
 
46
 
47
  app = FastAPI(lifespan=lifespan)
 
64
 
65
  @app.post("/text-to-image/")
66
  async def text_to_image(
67
+ request: Request,
68
+ prompt: str = Form(...),
69
+ num_inference_steps: int = Form(1),
70
  ):
71
  image = request.state.text2img(
72
  prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=0.0
 
84
  prompt: str = Form(...),
85
  init_image: UploadFile = File(...),
86
  num_inference_steps: int = Form(2),
87
+ strength: float = Form(1.0),
88
  ):
89
+ init_bytes = await init_image.read()
90
+ init_image = Image.open(BytesIO(init_bytes))
91
+ init_width, init_height = init_image.size
92
  init_image = init_image.convert("RGB").resize((512, 512))
93
 
94
  image = request.state.img2img(
 
98
  strength=strength,
99
  guidance_scale=0.0,
100
  ).images[0]
101
+ image = image.resize((init_width, init_height))
102
 
103
  bytes = BytesIO()
104
  image.save(bytes, "PNG")
 
112
  prompt: str = Form(...),
113
  init_image: UploadFile = File(...),
114
  mask_image: UploadFile = File(...),
115
+ num_inference_steps: int = Form(2),
116
+ strength: float = Form(1.0),
117
  ):
118
+ init_bytes = await init_image.read()
119
+ init_image = Image.open(BytesIO(init_bytes))
120
+ init_width, init_height = init_image.size
121
  init_image = init_image.convert("RGB").resize((512, 512))
122
+ mask_bytes = await mask_image.read()
123
+ mask_image = Image.open(BytesIO(mask_bytes))
124
  mask_image = mask_image.convert("RGB").resize((512, 512))
125
 
126
  image = request.state.inpaint(
 
131
  strength=strength,
132
  guidance_scale=0.0,
133
  ).images[0]
134
+ image = image.resize((init_width, init_height))
135
 
136
  bytes = BytesIO()
137
  image.save(bytes, "PNG")