File size: 3,595 Bytes
6d3950c
 
 
 
 
 
 
 
 
 
 
 
 
d2a2d86
6d3950c
 
d2a2d86
6d3950c
 
 
 
 
 
 
 
 
d2a2d86
6d3950c
 
d2a2d86
 
6d3950c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from fastapi import FastAPI, UploadFile, File,APIRouter,HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional
from PIL import Image
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
from utils import (accelerator, ImageAugmentation, clear_memory)
import hydra
from omegaconf import DictConfig
import lightning.pytorch as pl
import io

# Define FastAPI app
router = APIRouter()

class InpaintingRequest(BaseModel):
    
    prompt: str
    negative_prompt: Optional[str] = None
    num_inference_steps: int
    strength: float
    guidance_scale: float
    target_width: int
    target_height: int

class InpaintingBatchRequest(BaseModel):
    batch_input: List[InpaintingRequest]




def pil_to_s3_json(image: Image.Image, file_name: str):
    image_id = str(uuid.uuid4())
    s3_uploader = S3ManagerService()
    image_bytes = io.BytesIO()
    image.save(image_bytes, format="PNG")
    image_bytes.seek(0)

    unique_file_name = s3_uploader.generate_unique_file_name(file_name)
    s3_uploader.upload_file(image_bytes, unique_file_name)
    signed_url = s3_uploader.generate_signed_url(unique_file_name, exp=43200)  # 12 hours
    return {"image_id": image_id, "url": signed_url}

class AutoPaintingPipeline:
    def __init__(self, model_name: str, image: Image.Image, mask_image: Image.Image, target_width: int, target_height: int):
        self.model_name = model_name
        self.device = accelerator()
        self.pipeline = AutoPipelineForInpainting.from_pretrained(self.model_name, torch_dtype=torch.float16)
        self.image = load_image(image)
        self.mask_image = load_image(mask_image)
        self.target_width = target_width
        self.target_height = target_height
        self.pipeline.to(self.device)
        
    def run_inference(self, prompt: str, negative_prompt: Optional[str], num_inference_steps: int, strength: float, guidance_scale: float):
        clear_memory()
        image = load_image(self.image)
        mask_image = load_image(self.mask_image)
        output = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask_image,
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=guidance_scale,
            height=self.target_height,
            width=self.target_width
        ).images[0]
        return output

@app.post("/inpaint/")
async def inpaint(
    file: UploadFile = File(...),
    request: InpaintingRequest
):
    image = Image.open(file.file)
    augmenter = ImageAugmentation(target_width=request.target_width, target_height=request.target_height)  # Use fixed size or set dynamically
    extended_image = augmenter.extend_image(image)
    mask_image = augmenter.generate_mask_from_bbox(extended_image, 'segmentation_model', 'detection_model')
    mask_image = augmenter.invert_mask(mask_image)

    pipeline = AutoPaintingPipeline(
        model_name="model_name",
        image=extended_image,
        mask_image=mask_image,
        target_width=request.target_width,
        target_height=request.target_height
    )
    output_image = pipeline.run_inference(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        num_inference_steps=request.num_inference_steps,
        strength=request.strength,
        guidance_scale=request.guidance_scale,
        
    )

    
    result = pil_to_s3_json(output_image, "output_image.png")
    return result