VikramSingh178 commited on
Commit
1979816
β€’
1 Parent(s): 654965f

Feature added an Option for Either Generating an S3 url or b64 json

Browse files
product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc CHANGED
Binary files a/product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc and b/product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc differ
 
product_diffusion_api/routers/sdxl_text_to_image.py CHANGED
@@ -1,5 +1,6 @@
1
  import sys
2
  from torchao.quantization import apply_dynamic_quant
 
3
  sys.path.append("../scripts") # Path of the scripts directory
4
  import config
5
  from fastapi import APIRouter, HTTPException
@@ -11,6 +12,9 @@ import uuid
11
  from diffusers import DiffusionPipeline
12
  import torch
13
  from functools import lru_cache
 
 
 
14
 
15
  torch._inductor.config.conv_1x1_as_mm = True
16
  torch._inductor.config.coordinate_descent_tuning = True
@@ -52,18 +56,30 @@ def dynamic_quant_filter_fn(mod, *args):
52
 
53
 
54
 
55
- # Utility function to convert PIL image to base64 encoded JSON
 
56
  def pil_to_b64_json(image):
57
- # Generate a UUID for the image
58
  image_id = str(uuid.uuid4())
59
  buffered = BytesIO()
60
  image.save(buffered, format="PNG")
61
  b64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
62
  return {"image_id": image_id, "b64_image": b64_image}
63
 
64
- def upload_pil_to_s3(image):
65
- image
66
-
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  @lru_cache(maxsize=1)
69
  def load_pipeline(model_name, adapter_name):
@@ -77,9 +93,6 @@ def load_pipeline(model_name, adapter_name):
77
  pipe.fuse_qkv_projections()
78
  apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
79
  apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
80
-
81
-
82
-
83
  return pipe
84
 
85
 
@@ -125,12 +138,17 @@ class SDXLLoraInference:
125
  self.num_inference_steps = num_inference_steps
126
  self.guidance_scale = guidance_scale
127
 
128
- def run_inference(self):
129
  """
130
  Runs the inference process and returns the generated image.
131
 
 
 
 
 
 
132
  Returns:
133
- str: The generated image in base64-encoded JSON format.
134
  """
135
  image = self.pipe(
136
  prompt=self.prompt,
@@ -139,7 +157,14 @@ class SDXLLoraInference:
139
  negative_prompt=self.negative_prompt,
140
  num_images_per_prompt=self.num_images,
141
  ).images[0]
142
- return pil_to_b64_json(image)
 
 
 
 
 
 
 
143
 
144
 
145
  # Input format for single request
@@ -170,7 +195,7 @@ async def sdxl_v0_lora_inference(data: InputFormat):
170
  return output_json
171
 
172
 
173
- # Endpoint for batch requests
174
  @router.post("/sdxl_v0_lora_inference/batch")
175
  async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
176
  """
 
1
  import sys
2
  from torchao.quantization import apply_dynamic_quant
3
+
4
  sys.path.append("../scripts") # Path of the scripts directory
5
  import config
6
  from fastapi import APIRouter, HTTPException
 
12
  from diffusers import DiffusionPipeline
13
  import torch
14
  from functools import lru_cache
15
+ from s3_manager import S3ManagerService
16
+ from PIL import Image
17
+ import io
18
 
19
  torch._inductor.config.conv_1x1_as_mm = True
20
  torch._inductor.config.coordinate_descent_tuning = True
 
56
 
57
 
58
 
59
+
60
+
61
  def pil_to_b64_json(image):
 
62
  image_id = str(uuid.uuid4())
63
  buffered = BytesIO()
64
  image.save(buffered, format="PNG")
65
  b64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
66
  return {"image_id": image_id, "b64_image": b64_image}
67
 
68
+
69
+ def pil_to_s3_json(image: Image.Image, file_name: str) -> str:
70
+ image_id = str(uuid.uuid4())
71
+ s3_uploader = S3ManagerService()
72
+ image_bytes = io.BytesIO()
73
+ image.save(image_bytes, format="PNG")
74
+ image_bytes.seek(0)
75
+
76
+ unique_file_name = s3_uploader.generate_unique_file_name(file_name)
77
+ s3_uploader.upload_file(image_bytes, unique_file_name)
78
+ signed_url = s3_uploader.generate_signed_url(
79
+ unique_file_name, exp=43200
80
+ ) # 12 hours
81
+ return {"image_id": image_id, "url": signed_url}
82
+
83
 
84
  @lru_cache(maxsize=1)
85
  def load_pipeline(model_name, adapter_name):
 
93
  pipe.fuse_qkv_projections()
94
  apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
95
  apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
 
 
 
96
  return pipe
97
 
98
 
 
138
  self.num_inference_steps = num_inference_steps
139
  self.guidance_scale = guidance_scale
140
 
141
+ def run_inference(self, mode: str = "b64_json") -> str:
142
  """
143
  Runs the inference process and returns the generated image.
144
 
145
+ Parameters:
146
+ mode (str): The mode for returning the generated image.
147
+ Possible values: "b64_json", "s3_json".
148
+ Defaults to "b64_json".
149
+
150
  Returns:
151
+ str: The generated image in the specified format.
152
  """
153
  image = self.pipe(
154
  prompt=self.prompt,
 
157
  negative_prompt=self.negative_prompt,
158
  num_images_per_prompt=self.num_images,
159
  ).images[0]
160
+
161
+ if mode == "s3_json":
162
+ s3_url = pil_to_s3_json(image)
163
+ return pil_to_s3_json(image, s3_url)
164
+ elif mode == "b64_json":
165
+ return pil_to_b64_json(image)
166
+ else:
167
+ raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
168
 
169
 
170
  # Input format for single request
 
195
  return output_json
196
 
197
 
198
+
199
  @router.post("/sdxl_v0_lora_inference/batch")
200
  async def sdxl_v0_lora_inference_batch(data: BatchInputFormat):
201
  """
scripts/__pycache__/s3_manager.cpython-310.pyc ADDED
Binary file (2.51 kB). View file
 
scripts/s3_manager.py CHANGED
@@ -1,33 +1,33 @@
1
  import base64
2
  import io
 
3
  import boto3
4
  from botocore.config import Config
5
  import random
6
  import string
7
  from dotenv import load_dotenv
8
 
9
- env = load_dotenv('../config.env')
 
10
 
11
-
12
- class ImageService:
13
- def _init_(self):
14
  self.s3 = boto3.client(
15
  "s3",
16
  config=Config(signature_version="s3v4"),
17
- aws_access_key_id=env.AWS_ACCESS_KEY_ID,
18
- aws_secret_access_key=env.AWS_SECRET_ACCESS_KEY,
19
- region_name=env.AWS_REGION,
20
  )
21
-
22
- def generate_signed_url(self, file_name: str, exp: int = 1800) -> str:
23
  return self.s3.generate_presigned_url(
24
  "get_object",
25
- Params={"Bucket": env.AWS_BUCKET_NAME, "Key": file_name},
26
  ExpiresIn=exp,
27
  )
28
 
29
- def generate_unique_file_name(self, file) -> str:
30
- file_name = file.filename
31
  random_string = "".join(
32
  random.choices(string.ascii_uppercase + string.digits, k=10)
33
  )
@@ -36,7 +36,7 @@ class ImageService:
36
  return f"{file_real_name}-{random_string}.{file_extension}"
37
 
38
  def upload_file(self, file, file_name) -> str:
39
- self.s3.upload_fileobj(file, env.AWS_BUCKET_NAME, file_name)
40
  return file_name
41
 
42
  def upload_base64_file(self, base64_file: str, file_name: str) -> str:
@@ -45,5 +45,9 @@ class ImageService:
45
  def get_object(self, file_name: str, bucket: str):
46
  try:
47
  return self.s3.get_object(Bucket=bucket, Key=file_name)
48
- except: # noqa: E722
 
 
 
 
49
  return None
 
1
  import base64
2
  import io
3
+ import os
4
  import boto3
5
  from botocore.config import Config
6
  import random
7
  import string
8
  from dotenv import load_dotenv
9
 
10
+ # Load environment variables from the .env file
11
+ load_dotenv('../config.env')
12
 
13
+ class S3ManagerService:
14
+ def __init__(self):
 
15
  self.s3 = boto3.client(
16
  "s3",
17
  config=Config(signature_version="s3v4"),
18
+ aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
19
+ aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
20
+ region_name=os.getenv('AWS_REGION'),
21
  )
22
+
23
+ def generate_signed_url(self, file_name: str, exp: int = 43200) -> str: # 43200 seconds = 12 hours
24
  return self.s3.generate_presigned_url(
25
  "get_object",
26
+ Params={"Bucket": os.getenv('AWS_BUCKET_NAME'), "Key": file_name},
27
  ExpiresIn=exp,
28
  )
29
 
30
+ def generate_unique_file_name(self, file_name: str) -> str:
 
31
  random_string = "".join(
32
  random.choices(string.ascii_uppercase + string.digits, k=10)
33
  )
 
36
  return f"{file_real_name}-{random_string}.{file_extension}"
37
 
38
  def upload_file(self, file, file_name) -> str:
39
+ self.s3.upload_fileobj(file, os.getenv('AWS_BUCKET_NAME'), file_name)
40
  return file_name
41
 
42
  def upload_base64_file(self, base64_file: str, file_name: str) -> str:
 
45
  def get_object(self, file_name: str, bucket: str):
46
  try:
47
  return self.s3.get_object(Bucket=bucket, Key=file_name)
48
+ except self.s3.exceptions.NoSuchKey:
49
+ print(f"The file {file_name} does not exist in the bucket {bucket}.")
50
+ return None
51
+ except Exception as e:
52
+ print(f"An error occurred: {e}")
53
  return None