Spaces:
Runtime error
Runtime error
VikramSingh178
committed on
Commit
•
3e01790
1
Parent(s):
5e29265
Update SDXL-LoRA inference pipeline and model weights
Browse filesFormer-commit-id: 550c615e6a453f0586ab834a0366c230320361d5
- product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc +0 -0
- product_diffusion_api/routers/sdxl_text_to_image.py +42 -7
- scripts/__init__.py +0 -0
- scripts/__pycache__/config.cpython-310.pyc +0 -0
- scripts/config.py +1 -0
- scripts/load_pipeline.py +25 -0
- scripts/wandb/debug.log +1 -1
product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc
CHANGED
Binary files a/product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc and b/product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc differ
|
|
product_diffusion_api/routers/sdxl_text_to_image.py
CHANGED
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
1 |
from fastapi import APIRouter, HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
import base64
|
@@ -6,9 +10,17 @@ from typing import List
|
|
6 |
import uuid
|
7 |
from diffusers import DiffusionPipeline
|
8 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
router = APIRouter()
|
11 |
|
|
|
12 |
# Utility function to convert PIL image to base64 encoded JSON
|
13 |
def pil_to_b64_json(image):
|
14 |
# Generate a UUID for the image
|
@@ -19,6 +31,27 @@ def pil_to_b64_json(image):
|
|
19 |
return {"image_id": image_id, "b64_image": b64_image}
|
20 |
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
# SDXLLoraInference class for running inference
|
23 |
class SDXLLoraInference:
|
24 |
"""
|
@@ -51,12 +84,7 @@ class SDXLLoraInference:
|
|
51 |
num_inference_steps: int,
|
52 |
guidance_scale: float,
|
53 |
) -> None:
|
54 |
-
self.pipe =
|
55 |
-
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
56 |
-
)
|
57 |
-
self.model_path = "VikramSingh178/sdxl-lora-finetune-product-caption"
|
58 |
-
self.pipe.load_lora_weights(self.model_path)
|
59 |
-
self.pipe.to('cuda')
|
60 |
self.prompt = prompt
|
61 |
self.negative_prompt = negative_prompt
|
62 |
self.num_images = num_images
|
@@ -79,6 +107,7 @@ class SDXLLoraInference:
|
|
79 |
).images[0]
|
80 |
return pil_to_b64_json(image)
|
81 |
|
|
|
82 |
# Input format for single request
|
83 |
class InputFormat(BaseModel):
|
84 |
prompt: str
|
@@ -87,10 +116,12 @@ class InputFormat(BaseModel):
|
|
87 |
negative_prompt: str
|
88 |
num_images: int
|
89 |
|
|
|
90 |
# Input format for batch requests
|
91 |
class BatchInputFormat(BaseModel):
|
92 |
batch_input: List[InputFormat]
|
93 |
|
|
|
94 |
# Endpoint for single request
|
95 |
@router.post("/sdxl_v0_lora_inference")
|
96 |
async def sdxl_v0_lora_inference(data: InputFormat):
|
@@ -104,6 +135,7 @@ async def sdxl_v0_lora_inference(data: InputFormat):
|
|
104 |
output_json = inference.run_inference()
|
105 |
return output_json
|
106 |
|
|
|
107 |
# Endpoint for batch requests
|
108 |
@router.post("/sdxl_v0_lora_inference/batch")
|
109 |
async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
|
@@ -122,7 +154,10 @@ async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
|
|
122 |
MAX_QUEUE_SIZE = 64
|
123 |
|
124 |
if len(data.batch_input) > MAX_QUEUE_SIZE:
|
125 |
-
raise HTTPException(
|
|
|
|
|
|
|
126 |
|
127 |
processed_requests = []
|
128 |
for item in data.batch_input:
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
sys.path.append("../scripts") # Path of the scripts directory
|
4 |
+
import config
|
5 |
from fastapi import APIRouter, HTTPException
|
6 |
from pydantic import BaseModel
|
7 |
import base64
|
|
|
10 |
import uuid
|
11 |
from diffusers import DiffusionPipeline
|
12 |
import torch
|
13 |
+
import torch_tensorrt
|
14 |
+
from functools import lru_cache
|
15 |
+
|
16 |
+
torch._inductor.config.conv_1x1_as_mm = True
|
17 |
+
torch._inductor.config.coordinate_descent_tuning = True
|
18 |
+
torch._inductor.config.epilogue_fusion = False
|
19 |
+
torch._inductor.config.coordinate_descent_check_all_directions = True
|
20 |
|
21 |
router = APIRouter()
|
22 |
|
23 |
+
|
24 |
# Utility function to convert PIL image to base64 encoded JSON
|
25 |
def pil_to_b64_json(image):
|
26 |
# Generate a UUID for the image
|
|
|
31 |
return {"image_id": image_id, "b64_image": b64_image}
|
32 |
|
33 |
|
34 |
+
@lru_cache(maxsize=1)
|
35 |
+
def load_pipeline(model_name, adapter_name):
|
36 |
+
pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(
|
37 |
+
"cuda"
|
38 |
+
)
|
39 |
+
pipe.load_lora_weights(adapter_name)
|
40 |
+
pipe.unet.to(memory_format=torch.channels_last)
|
41 |
+
pipe.vae.to(memory_format=torch.channels_last)
|
42 |
+
# pipe.unet = torch.compile(
|
43 |
+
# pipe.unet,
|
44 |
+
# mode = 'max-autotime'
|
45 |
+
# )
|
46 |
+
|
47 |
+
pipe.fuse_qkv_projections()
|
48 |
+
|
49 |
+
return pipe
|
50 |
+
|
51 |
+
|
52 |
+
loaded_pipeline = load_pipeline(config.MODEL_NAME, config.ADAPTER_NAME)
|
53 |
+
|
54 |
+
|
55 |
# SDXLLoraInference class for running inference
|
56 |
class SDXLLoraInference:
|
57 |
"""
|
|
|
84 |
num_inference_steps: int,
|
85 |
guidance_scale: float,
|
86 |
) -> None:
|
87 |
+
self.pipe = loaded_pipeline
|
|
|
|
|
|
|
|
|
|
|
88 |
self.prompt = prompt
|
89 |
self.negative_prompt = negative_prompt
|
90 |
self.num_images = num_images
|
|
|
107 |
).images[0]
|
108 |
return pil_to_b64_json(image)
|
109 |
|
110 |
+
|
111 |
# Input format for single request
|
112 |
class InputFormat(BaseModel):
|
113 |
prompt: str
|
|
|
116 |
negative_prompt: str
|
117 |
num_images: int
|
118 |
|
119 |
+
|
120 |
# Input format for batch requests
|
121 |
class BatchInputFormat(BaseModel):
|
122 |
batch_input: List[InputFormat]
|
123 |
|
124 |
+
|
125 |
# Endpoint for single request
|
126 |
@router.post("/sdxl_v0_lora_inference")
|
127 |
async def sdxl_v0_lora_inference(data: InputFormat):
|
|
|
135 |
output_json = inference.run_inference()
|
136 |
return output_json
|
137 |
|
138 |
+
|
139 |
# Endpoint for batch requests
|
140 |
@router.post("/sdxl_v0_lora_inference/batch")
|
141 |
async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
|
|
|
154 |
MAX_QUEUE_SIZE = 64
|
155 |
|
156 |
if len(data.batch_input) > MAX_QUEUE_SIZE:
|
157 |
+
raise HTTPException(
|
158 |
+
status_code=400,
|
159 |
+
detail=f"Number of requests exceeds maximum queue size ({MAX_QUEUE_SIZE})",
|
160 |
+
)
|
161 |
|
162 |
processed_requests = []
|
163 |
for item in data.batch_input:
|
scripts/__init__.py
ADDED
File without changes
|
scripts/__pycache__/config.cpython-310.pyc
CHANGED
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ
|
|
scripts/config.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
|
|
2 |
VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
|
3 |
DATASET_NAME= "hahminlew/kream-product-blip-captions"
|
4 |
PROJECT_NAME = "Product Photography"
|
|
|
1 |
MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
2 |
+
ADAPTER_NAME = "VikramSingh178/sdxl-lora-finetune-product-caption"
|
3 |
VAE_NAME= "madebyollin/sdxl-vae-fp16-fix"
|
4 |
DATASET_NAME= "hahminlew/kream-product-blip-captions"
|
5 |
PROJECT_NAME = "Product Photography"
|
scripts/load_pipeline.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from config import MODEL_NAME,ADAPTER_NAME
|
2 |
+
import torch
|
3 |
+
from diffusers import DiffusionPipeline
|
4 |
+
from wandb.integration.diffusers import autolog
|
5 |
+
from config import PROJECT_NAME
|
6 |
+
autolog(init=dict(project=PROJECT_NAME))
|
7 |
+
|
8 |
+
|
9 |
+
def load_pipeline(model_name, adapter_name):
|
10 |
+
pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16).to(
|
11 |
+
"cuda"
|
12 |
+
)
|
13 |
+
pipe.load_lora_weights(adapter_name)
|
14 |
+
pipe.unet.to(memory_format=torch.channels_last)
|
15 |
+
pipe.vae.to(memory_format=torch.channels_last)
|
16 |
+
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
|
17 |
+
pipe.vae.decode = torch.compile(
|
18 |
+
pipe.vae.decode, mode="max-autotune", fullgraph=True
|
19 |
+
)
|
20 |
+
pipe.fuse_qkv_projections()
|
21 |
+
|
22 |
+
return pipe
|
23 |
+
|
24 |
+
loaded_pipeline = load_pipeline(MODEL_NAME, ADAPTER_NAME)
|
25 |
+
images = loaded_pipeline('toaster', num_inference_steps=30).images[0]
|
scripts/wandb/debug.log
CHANGED
@@ -1 +1 @@
|
|
1 |
-
run-
|
|
|
1 |
+
run-20240507_154024-2j1bt71e/logs/debug.log
|