File size: 7,404 Bytes
3e01790
 
 
32a04ae
07da480
 
 
32a04ae
 
 
 
3e01790
1979816
 
 
f3cfe0c
3e01790
7f2257d
3e01790
 
 
 
51c84d5
 
07da480
 
 
3e01790
51c84d5
1979816
 
07da480
7b86905
 
 
 
 
 
 
 
 
 
32a04ae
07da480
 
 
32a04ae
ed57be9
1979816
7b86905
 
 
 
 
 
 
 
 
 
 
 
1979816
 
 
 
 
 
 
 
 
 
 
 
 
ed57be9
3e01790
7b86905
 
 
 
 
 
 
 
 
 
 
 
3e01790
8af21e8
51c84d5
3e01790
7b86905
 
3e01790
 
 
 
7b86905
3e01790
 
32a04ae
ed57be9
 
32a04ae
07da480
ed57be9
32a04ae
 
 
 
 
 
 
 
 
 
 
ed57be9
32a04ae
 
 
 
ed57be9
07da480
 
32a04ae
 
 
 
 
 
d8b5430
07da480
3e01790
ed57be9
07da480
 
32a04ae
 
d8b5430
 
ed57be9
d8b5430
ed57be9
32a04ae
07da480
1979816
 
 
 
 
ed57be9
1979816
ed57be9
32a04ae
 
07da480
 
32a04ae
 
07da480
1979816
d8b5430
 
 
 
1979816
 
 
07da480
3e01790
32a04ae
07da480
32a04ae
 
 
 
 
d8b5430
07da480
3e01790
32a04ae
 
 
07da480
3e01790
32a04ae
07da480
 
32a04ae
 
 
 
 
 
d8b5430
32a04ae
 
 
 
3e01790
1979816
32a04ae
 
07da480
32a04ae
07da480
 
32a04ae
ed57be9
07da480
32a04ae
 
 
 
07da480
c58035a
32a04ae
 
3e01790
 
 
 
32a04ae
 
 
 
 
 
 
 
 
d2a2d86
32a04ae
 
 
07da480
32a04ae
654965f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import sys
sys.path.append("../scripts")  # Path of the scripts directory
import config
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
import base64
from io import BytesIO
from typing import List
from typing import Literal
import uuid
from diffusers import DiffusionPipeline
import torch
from functools import lru_cache
from s3_manager import S3ManagerService
from PIL import Image
import io
from scripts.utils import accelerator

# Resolve the compute device once, at import time. `accelerator` is a project
# helper (scripts.utils); its selection logic is not visible in this file.
device = accelerator()
# Process-wide TorchInductor settings. They only influence the `torch.compile`
# calls made in `load_pipeline`, and must be set before that function runs
# (it is invoked at module import time further down this file).
# NOTE(review): some of these keys (e.g. `use_mixed_mm`,
# `force_fuse_int_mm_with_mul`) may not exist in every torch release —
# confirm against the pinned torch version.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Router holding this module's inference endpoints; presumably included into
# the FastAPI app elsewhere — verify against the application entry point.
router = APIRouter()





def pil_to_b64_json(image):
    """
    Serialize an image to PNG and return it base64-encoded in a JSON-ready dict.

    Args:
        image (PIL.Image.Image): Image to serialize. Any object exposing a
            ``save(fp, format=...)`` method is accepted.

    Returns:
        dict: ``{"image_id": <random UUID4 string>, "b64_image": <base64 PNG text>}``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return {"image_id": str(uuid.uuid4()), "b64_image": encoded}


def pil_to_s3_json(image: Image.Image, file_name: str, exp: int = 43200) -> dict:
    """
    Upload an image to S3 as a PNG and return a JSON-ready dict with a signed URL.

    Args:
        image (PIL.Image.Image): The image to upload.
        file_name (str): Base name passed to
            ``S3ManagerService.generate_unique_file_name`` to derive the object key.
        exp (int): Lifetime of the signed URL in seconds, forwarded as ``exp`` to
            ``generate_signed_url``. Defaults to 43200 (12 hours), the value
            previously hard-coded here.

    Returns:
        dict: ``{"image_id": <random UUID4 string>, "url": <signed URL>}``.
        ``image_id`` is generated independently and is not the S3 object key.
    """
    image_id = str(uuid.uuid4())
    s3_uploader = S3ManagerService()

    image_bytes = io.BytesIO()
    image.save(image_bytes, format="PNG")
    # Rewind after writing so the buffer is read from its start, not from EOF.
    image_bytes.seek(0)

    unique_file_name = s3_uploader.generate_unique_file_name(file_name)
    s3_uploader.upload_file(image_bytes, unique_file_name)
    signed_url = s3_uploader.generate_signed_url(unique_file_name, exp=exp)
    return {"image_id": image_id, "url": signed_url}


@lru_cache(maxsize=1)
def load_pipeline(model_name, adapter_name):
    """
    Load the diffusion pipeline, fuse the LoRA adapter into it, and compile it.

    Memoised with ``lru_cache(maxsize=1)``: repeated calls with the same
    ``(model_name, adapter_name)`` return the same pipeline object. A call with
    different arguments evicts the cached pipeline and builds a new one.

    Args:
        model_name (str): Pretrained model id/path for ``DiffusionPipeline.from_pretrained``.
        adapter_name (str): LoRA weights id/path for ``load_lora_weights``.

    Returns:
        DiffusionPipeline: bfloat16 pipeline on ``device`` with the LoRA fused
        into the base weights, fused QKV projections, a channels-last UNet, and
        ``torch.compile``-wrapped UNet and VAE ``decode``.
    """
    pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)

    # Merge the LoRA into the base weights, then unload the separate adapter.
    pipe.load_lora_weights(adapter_name)
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Perform every structural change to the modules *before* wrapping them in
    # torch.compile. Fusing afterwards (as before) mutated the UNet through the
    # compiled-module proxy instead of the module itself.
    pipe.fuse_qkv_projections()
    pipe.unet.to(memory_format=torch.channels_last)

    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
    pipe.vae.decode = torch.compile(pipe.vae.decode, mode="reduce-overhead")
    return pipe


# Build the shared pipeline eagerly when this module is imported, so importing
# it loads (and may download) the model and fuses the LoRA weights.
# Keep this a positional call: lru_cache keys positional and keyword calls
# differently, so `load_pipeline(model_name=..., adapter_name=...)` elsewhere
# would miss this cache entry and rebuild the pipeline.
loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME)


# SDXLLoraInference class for running inference
class SDXLLoraInference:
    """
    Run a single SDXL + LoRA generation request against the shared pipeline.

    Args:
        prompt (str): Text prompt for the image.
        negative_prompt (str): Negative prompt for the image.
        num_images (int): Forwarded to the pipeline as ``num_images_per_prompt``.
        num_inference_steps (int): Number of inference steps.
        guidance_scale (float): Guidance scale passed to the pipeline.
        mode (str): Output format, ``"b64_json"`` (default) or ``"s3_json"``.

    Attributes:
        pipe (DiffusionPipeline): The module-level ``loaded_pipeline``.
        prompt, negative_prompt, num_images, num_inference_steps,
        guidance_scale, mode: Stored copies of the constructor arguments.

    Methods:
        run_inference: Generate the image and return it in the requested format.
    """

    # Output formats accepted by run_inference.
    SUPPORTED_MODES = ("b64_json", "s3_json")

    def __init__(
        self,
        prompt: str,
        negative_prompt: str,
        num_images: int,
        num_inference_steps: int,
        guidance_scale: float,
        mode: str = "b64_json",
    ) -> None:
        self.pipe = loaded_pipeline
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.num_images = num_images
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.mode = mode

    def run_inference(self) -> dict:
        """
        Generate an image and return it in the format selected by ``self.mode``.

        ``self.mode`` is validated before the pipeline runs, so an unsupported
        mode fails immediately instead of after a full (expensive) generation.

        Returns:
            dict: ``{"image_id", "b64_image"}`` for ``"b64_json"``, or
            ``{"image_id", "url"}`` for ``"s3_json"``.

        Raises:
            ValueError: If ``self.mode`` is not one of ``SUPPORTED_MODES``.
        """
        if self.mode not in self.SUPPORTED_MODES:
            raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")

        result = self.pipe(
            prompt=self.prompt,
            num_inference_steps=self.num_inference_steps,
            guidance_scale=self.guidance_scale,
            negative_prompt=self.negative_prompt,
            num_images_per_prompt=self.num_images,
        )
        # NOTE(review): the pipeline is asked for `num_images` images but only the
        # first is returned; any others are generated and discarded. Returning all
        # of them would change the response shape, so behavior is kept as-is.
        image = result.images[0]

        if self.mode == "s3_json":
            return pil_to_s3_json(image, "sdxl_image")
        return pil_to_b64_json(image)


# Input format for single request
class InputFormat(BaseModel):
    """
    Request body for a single SDXL LoRA generation.

    ``mode`` is restricted to the output formats ``SDXLLoraInference`` supports,
    so an unsupported value is rejected during request validation, before any
    GPU work is done. It defaults to ``"b64_json"`` when omitted.
    """

    prompt: str
    num_inference_steps: int
    guidance_scale: float
    negative_prompt: str
    # Passed to the pipeline as num_images_per_prompt.
    num_images: int
    mode: Literal["b64_json", "s3_json"] = "b64_json"


# Input format for batch requests
class BatchInputFormat(BaseModel):
    """Request body for the batch endpoint: a list of single-generation requests, processed in order."""

    batch_input: List[InputFormat]


# Endpoint for single request
@router.post("/sdxl_v0_lora_inference")
async def sdxl_v0_lora_inference(data: InputFormat):
    """
    Generate one image for a single request.

    Args:
        data (InputFormat): Validated request body.

    Returns:
        dict: The output of ``SDXLLoraInference.run_inference`` for ``data.mode``.

    Raises:
        HTTPException: 400 when inference rejects the request with ``ValueError``
            (e.g. an unsupported ``mode``), instead of surfacing it as a 500.
    """
    inference = SDXLLoraInference(
        data.prompt,
        data.negative_prompt,
        data.num_images,
        data.num_inference_steps,
        data.guidance_scale,
        data.mode,
    )
    # NOTE: run_inference is synchronous, so calling it from this `async def`
    # blocks the event loop for the whole generation (which also serialises
    # requests on the shared pipeline within this process). Moving it to a
    # thread pool would presumably need a lock around the pipeline first.
    try:
        return inference.run_inference()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc



@router.post("/sdxl_v0_lora_inference/batch")
async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
    """
    Perform batch inference for SDXL V0 LoRa model.

    Requests are processed sequentially, in the order given.

    Args:
        data (BatchInputFormat): The input data containing a batch of requests.

    Returns:
        dict: ``{"message": ..., "data": [...]}``, where ``data`` holds one
        ``run_inference`` result per request, in input order.

    Raises:
        HTTPException: 400 if the number of requests exceeds the maximum queue
            size, or if inference rejects an item with ``ValueError``. In the
            latter case the detail names the failing item's index. Side effects
            of earlier items (e.g. S3 uploads) are not rolled back.
    """
    MAX_QUEUE_SIZE = 64

    if len(data.batch_input) > MAX_QUEUE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Number of requests exceeds maximum queue size ({MAX_QUEUE_SIZE})",
        )

    processed_requests = []
    # NOTE: each run_inference call is synchronous and blocks the event loop
    # while it runs.
    for index, item in enumerate(data.batch_input):
        inference = SDXLLoraInference(
            item.prompt,
            item.negative_prompt,
            item.num_images,
            item.num_inference_steps,
            item.guidance_scale,
            item.mode,
        )
        try:
            output_json = inference.run_inference()
        except ValueError as exc:
            raise HTTPException(
                status_code=400,
                detail=f"Request {index} in batch_input is invalid: {exc}",
            ) from exc
        processed_requests.append(output_json)

    return {"message": "Requests processed successfully", "data": processed_requests}