ahmedmbutt commited on
Commit
4fd0935
1 Parent(s): 3fe0b34

Update AutoPipelineForInpainting model in lifespan function

Browse files
Files changed (1) hide show
  1. main.py +3 -4
main.py CHANGED
@@ -3,7 +3,6 @@ from fastapi.responses import StreamingResponse
3
  from contextlib import asynccontextmanager
4
  from starlette.middleware.cors import CORSMiddleware
5
 
6
- import torch
7
  from PIL import Image
8
  from io import BytesIO
9
  from diffusers import (
@@ -21,7 +20,7 @@ async def lifespan(app: FastAPI):
21
 
22
  img2img = AutoPipelineForImage2Image.from_pipe(text2img).to("cpu")
23
 
24
- inpaint = AutoPipelineForInpainting.from_pipe(text2img).to("cpu")
25
 
26
  yield {"text2img": text2img, "img2img": img2img, "inpaint": inpaint}
27
 
@@ -68,7 +67,7 @@ async def image_to_image(
68
  init_image = Image.open(BytesIO(bytes))
69
  init_image = init_image.convert("RGB").resize((512, 512))
70
 
71
- image = request.state.img2img.pipe(
72
  prompt,
73
  image=init_image,
74
  num_inference_steps=2,
@@ -96,7 +95,7 @@ async def inpainting(
96
  mask_image = Image.open(BytesIO(bytes))
97
  mask_image = mask_image.convert("RGB").resize((512, 512))
98
 
99
- image = request.state.inpaint.pipe(
100
  prompt,
101
  image=init_image,
102
  mask_image=mask_image,
 
3
  from contextlib import asynccontextmanager
4
  from starlette.middleware.cors import CORSMiddleware
5
 
 
6
  from PIL import Image
7
  from io import BytesIO
8
  from diffusers import (
 
20
 
21
  img2img = AutoPipelineForImage2Image.from_pipe(text2img).to("cpu")
22
 
23
+ inpaint = AutoPipelineForInpainting.from_pipe(img2img).to("cpu")
24
 
25
  yield {"text2img": text2img, "img2img": img2img, "inpaint": inpaint}
26
 
 
67
  init_image = Image.open(BytesIO(bytes))
68
  init_image = init_image.convert("RGB").resize((512, 512))
69
 
70
+ image = request.state.img2img(
71
  prompt,
72
  image=init_image,
73
  num_inference_steps=2,
 
95
  mask_image = Image.open(BytesIO(bytes))
96
  mask_image = mask_image.convert("RGB").resize((512, 512))
97
 
98
+ image = request.state.inpaint(
99
  prompt,
100
  image=init_image,
101
  mask_image=mask_image,