VikramSingh178 commited on
Commit
07da480
1 Parent(s): ed57be9

Add new endpoints for product diffusion API and SDXL-LoRA inference

Browse files
product_diffusion_api/__pycache__/endpoints.cpython-310.pyc ADDED
Binary file (1.32 kB). View file
 
product_diffusion_api/endpoints.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
 
4
 
5
 
@@ -14,7 +15,7 @@ app.add_middleware(
14
 
15
  )
16
 
17
- #app.include_router(sdxl_text_to_image.router, prefix='/api/v1/product-diffusion')
18
 
19
 
20
 
@@ -30,4 +31,8 @@ async def root():
30
  'github': 'https://github.com/vikramxD'
31
  },
32
  'license': 'MIT',
33
- }
 
 
 
 
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from routers import sdxl_text_to_image
4
 
5
 
6
 
 
15
 
16
  )
17
 
18
+ app.include_router(sdxl_text_to_image.router, prefix='/api/v1/product-diffusion')
19
 
20
 
21
 
 
31
  'github': 'https://github.com/vikramxD'
32
  },
33
  'license': 'MIT',
34
+ }
35
+
36
+ @app.get("/health")
37
+ def check_health():
38
+ return {"status": "ok"}
product_diffusion_api/routers/__pycache__/sdxl_text_to_image.cpython-310.pyc ADDED
Binary file (3.18 kB). View file
 
product_diffusion_api/routers/sdxl_text_to_image.py CHANGED
@@ -1,37 +1,106 @@
1
  from diffusers import DiffusionPipeline
2
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  class SDXLLoraInference:
6
  """
7
  Class for running inference using the SDXL-LoRA model to generate stunning product photographs.
8
-
9
  Args:
10
  prompt (str): The input prompt for generating the product photograph.
11
  num_inference_steps (int): The number of inference steps to perform.
12
  guidance_scale (float): The scale factor for guidance during inference.
13
  """
14
- def __init__(self, prompt: str, num_inference_steps: int, guidance_scale: float) -> None:
 
 
 
15
  self.model_path = "VikramSingh178/sdxl-lora-finetune-product-caption"
16
- self.pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
 
 
17
  self.pipe.to("cuda")
18
  self.pipe.load_lora_weights(self.model_path)
19
  self.num_inference_steps = num_inference_steps
20
  self.guidance_scale = guidance_scale
21
  self.prompt = prompt
22
-
 
23
 
24
  def run_inference(self):
25
  """
26
  Runs inference using the SDXL-LoRA model to generate a stunning product photograph.
27
-
28
  Returns:
29
  images: The generated product photograph(s).
30
  """
31
-
32
  prompt = self.prompt
33
- images = self.pipe(prompt, num_inference_steps=self.num_inference_steps, guidance_scale=self.guidance_scale).images
34
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- inference = SDXLLoraInference(num_inference_steps=100, guidance_scale=2.5)
37
- inference.run_inference(prompt= "A stunning 4k Shot of a Balenciaga X Anime Hoodie with a person wearing it in a party" )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from diffusers import DiffusionPipeline
2
  import torch
3
+ from fastapi import APIRouter
4
+ from pydantic import BaseModel
5
+ import json
6
+ import base64
7
+ from PIL import Image
8
+ from io import BytesIO
9
+
10
+
11
+
12
+ router = APIRouter()
13
+
14
+
15
+
16
+ def pil_to_b64_json(image):
17
+ buffered = BytesIO()
18
+ image.save(buffered, format="PNG")
19
+ b64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
20
+ json_data = {"b64_image": b64_image}
21
+ return json_data
22
 
23
 
24
  class SDXLLoraInference:
25
  """
26
  Class for running inference using the SDXL-LoRA model to generate stunning product photographs.
27
+
28
  Args:
29
  prompt (str): The input prompt for generating the product photograph.
30
  num_inference_steps (int): The number of inference steps to perform.
31
  guidance_scale (float): The scale factor for guidance during inference.
32
  """
33
+
34
+ def __init__(
35
+ self, prompt: str, negative_prompt:str,num_images:int ,num_inference_steps: int, guidance_scale: float
36
+ ) -> None:
37
  self.model_path = "VikramSingh178/sdxl-lora-finetune-product-caption"
38
+ self.pipe = DiffusionPipeline.from_pretrained(
39
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
40
+ )
41
  self.pipe.to("cuda")
42
  self.pipe.load_lora_weights(self.model_path)
43
  self.num_inference_steps = num_inference_steps
44
  self.guidance_scale = guidance_scale
45
  self.prompt = prompt
46
+ self.negative_prompt = negative_prompt
47
+ self.num_images = num_images
48
 
49
  def run_inference(self):
50
  """
51
  Runs inference using the SDXL-LoRA model to generate a stunning product photograph.
52
+
53
  Returns:
54
  images: The generated product photograph(s).
55
  """
56
+
57
  prompt = self.prompt
58
+ negative_prompt = self.negative_prompt
59
+ num_images = self.num_images
60
+
61
+ image = self.pipe(
62
+ prompt=prompt,
63
+ negative_prompt=negative_prompt,
64
+ num_inference_steps=self.num_inference_steps,
65
+ guidance_scale=self.guidance_scale,
66
+ num_images_per_prompt=num_images
67
+ ).images[0]
68
+ image_json = pil_to_b64_json(image)
69
+ return image_json
70
+
71
+
72
+
73
+ class InputFormat(BaseModel):
74
+ prompt : str
75
+ negative_prompt : str
76
+ num_images : int
77
+ num_inference_steps : int
78
+ guidance_scale : float
79
+
80
+
81
+
82
+
83
+ @router.post("/sdxl_v0_lora_inference")
84
+ async def sdxl_v0_lora_inference(data: InputFormat):
85
+ """
86
+ Perform SDXL V0 LoRa inference.
87
+
88
+ Args:
89
+ data (InputFormat): The input data containing the prompt, number of inference steps, and guidance scale.
90
 
91
+ Returns:
92
+ The output of the inference.
93
+ """
94
+ prompt = data.prompt
95
+ negative_prompt = data.negative_prompt,
96
+ num_images = data.num_images
97
+ num_inference_steps = data.num_inference_steps
98
+ guidance_scale = data.guidance_scale
99
+ inference = SDXLLoraInference(prompt,negative_prompt, num_inference_steps, guidance_scale,num_images)
100
+ output_json = inference.run_inference()
101
+ return output_json
102
+
103
+
104
+
105
+
106
+
requirements.txt CHANGED
@@ -19,3 +19,4 @@ tensorboard
19
  Jinja2
20
  datasets
21
  peft
 
 
19
  Jinja2
20
  datasets
21
  peft
22
+ async-batcher