VikramSingh178 committed on
Commit
51c84d5
1 Parent(s): 27d6043

Update SDXL-LoRA inference pipeline and model weights

Browse files

Former-commit-id: c803a77350d44492c1875b98c83808ed560b6cef

product_diffusion_api/__pycache__/endpoints.cpython-310.pyc CHANGED
Binary files a/product_diffusion_api/__pycache__/endpoints.cpython-310.pyc and b/product_diffusion_api/__pycache__/endpoints.cpython-310.pyc differ
 
product_diffusion_api/endpoints.py CHANGED
@@ -28,7 +28,8 @@ async def root():
28
  'author': 'Vikramjeet Singh',
29
  'contact': {
30
  'email': 'singh.vikram.1782000@gmail.com',
31
- 'github': 'https://github.com/vikramxD'
 
32
  },
33
  'license': 'MIT',
34
  }
 
28
  'author': 'Vikramjeet Singh',
29
  'contact': {
30
  'email': 'singh.vikram.1782000@gmail.com',
31
+ 'github': 'https://github.com/vikramxD',
32
+ 'website': 'https://vikramxd.github.io'
33
  },
34
  'license': 'MIT',
35
  }
product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc CHANGED
Binary files a/product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc and b/product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc differ
 
product_diffusion_api/routers/sdxl_text_to_image.py CHANGED
@@ -1,5 +1,5 @@
1
  import sys
2
-
3
  sys.path.append("../scripts") # Path of the scripts directory
4
  import config
5
  from fastapi import APIRouter, HTTPException
@@ -10,17 +10,48 @@ from typing import List
10
  import uuid
11
  from diffusers import DiffusionPipeline
12
  import torch
13
- import torch_tensorrt
14
  from functools import lru_cache
15
 
16
  torch._inductor.config.conv_1x1_as_mm = True
17
  torch._inductor.config.coordinate_descent_tuning = True
18
  torch._inductor.config.epilogue_fusion = False
19
  torch._inductor.config.coordinate_descent_check_all_directions = True
 
 
20
 
21
  router = APIRouter()
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Utility function to convert PIL image to base64 encoded JSON
25
  def pil_to_b64_json(image):
26
  # Generate a UUID for the image
@@ -37,14 +68,14 @@ def load_pipeline(model_name, adapter_name):
37
  "cuda"
38
  )
39
  pipe.load_lora_weights(adapter_name)
40
- pipe.unet = torch.compile(
41
- pipe.unet,
42
- mode = 'max-autotune'
43
- )
44
  pipe.unet.to(memory_format=torch.channels_last)
45
- pipe.vae.to(memory_format=torch.channels_last)
46
-
47
  pipe.fuse_qkv_projections()
 
 
 
 
48
 
49
  return pipe
50
 
 
1
  import sys
2
+ from torchao.quantization import apply_dynamic_quant
3
  sys.path.append("../scripts") # Path of the scripts directory
4
  import config
5
  from fastapi import APIRouter, HTTPException
 
10
  import uuid
11
  from diffusers import DiffusionPipeline
12
  import torch
 
13
  from functools import lru_cache
14
 
15
  torch._inductor.config.conv_1x1_as_mm = True
16
  torch._inductor.config.coordinate_descent_tuning = True
17
  torch._inductor.config.epilogue_fusion = False
18
  torch._inductor.config.coordinate_descent_check_all_directions = True
19
+ torch._inductor.config.force_fuse_int_mm_with_mul = True
20
+ torch._inductor.config.use_mixed_mm = True
21
 
22
  router = APIRouter()
23
 
24
 
25
+ def dynamic_quant_filter_fn(mod, *args):
26
+ return (
27
+ isinstance(mod, torch.nn.Linear)
28
+ and mod.in_features > 16
29
+ and (mod.in_features, mod.out_features)
30
+ not in [
31
+ (1280, 640),
32
+ (1920, 1280),
33
+ (1920, 640),
34
+ (2048, 1280),
35
+ (2048, 2560),
36
+ (2560, 1280),
37
+ (256, 128),
38
+ (2816, 1280),
39
+ (320, 640),
40
+ (512, 1536),
41
+ (512, 256),
42
+ (512, 512),
43
+ (640, 1280),
44
+ (640, 1920),
45
+ (640, 320),
46
+ (640, 5120),
47
+ (640, 640),
48
+ (960, 320),
49
+ (960, 640),
50
+ ]
51
+ )
52
+
53
+
54
+
55
  # Utility function to convert PIL image to base64 encoded JSON
56
  def pil_to_b64_json(image):
57
  # Generate a UUID for the image
 
68
  "cuda"
69
  )
70
  pipe.load_lora_weights(adapter_name)
71
+ pipe.unload_lora_weights()
 
 
 
72
  pipe.unet.to(memory_format=torch.channels_last)
73
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
 
74
  pipe.fuse_qkv_projections()
75
+ apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
76
+ apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
77
+
78
+
79
 
80
  return pipe
81