VikramSingh178 commited on
Commit
11de6ea
1 Parent(s): b08da72

chore: remove unused logfire dependency

Browse files

Former-commit-id: fa272ae77f88e2df6fbbccf957dc6e5999f916ea

Files changed (2) hide show
  1. api/requirements.txt +0 -1
  2. api/routers/painting.py +85 -175
api/requirements.txt CHANGED
@@ -4,7 +4,6 @@ diffusers==0.27.2
4
  fastapi==0.111.0
5
  hydra-core==1.3.2
6
  lightning==2.2.3
7
- logfire==0.42.0
8
  Pillow==10.3.0
9
  pydantic==2.7.4
10
  torch
 
4
  fastapi==0.111.0
5
  hydra-core==1.3.2
6
  lightning==2.2.3
 
7
  Pillow==10.3.0
8
  pydantic==2.7.4
9
  torch
api/routers/painting.py CHANGED
@@ -1,64 +1,48 @@
1
  import os
2
  import uuid
3
- from typing import List, Tuple, Any, Dict
 
4
  from fastapi import APIRouter, File, UploadFile, HTTPException, Form
5
- from pydantic import BaseModel, Field
6
- from PIL import Image
7
  import lightning.pytorch as pl
8
- from scripts.api_utils import pil_to_s3_json, pil_to_b64_json, ImageAugmentation, accelerator
9
- from scripts.inpainting_pipeline import AutoPaintingPipeline, load_pipeline
10
- from hydra import compose, initialize
11
  from async_batcher.batcher import AsyncBatcher
12
- import json
13
  from functools import lru_cache
 
14
  pl.seed_everything(42)
15
  router = APIRouter()
16
 
17
-
18
- with initialize(version_base=None, config_path="../../configs"):
19
- cfg = compose(config_name="inpainting")
20
-
21
- # Load the inpainting pipeline
22
  @lru_cache(maxsize=1)
23
- def load_pipeline_wrapper():
24
- """
25
- Load the inpainting pipeline with the specified configuration.
26
 
27
- Returns:
28
- pipeline: The loaded inpainting pipeline.
29
  """
30
- pipeline = load_pipeline(cfg.model, accelerator(), enable_compile=True)
31
- return pipeline
32
- inpainting_pipeline = load_pipeline_wrapper()
33
-
34
- class InpaintingRequest(BaseModel):
35
- """
36
- Model representing a request for inpainting inference.
37
  """
38
- prompt: str = Field(..., description="Prompt text for inference")
39
- negative_prompt: str = Field(..., description="Negative prompt text for inference")
40
- num_inference_steps: int = Field(..., description="Number of inference steps")
41
- strength: float = Field(..., description="Strength of the inference")
42
- guidance_scale: float = Field(..., description="Guidance scale for inference")
43
- mode: str = Field(..., description="Mode for output ('b64_json' or 's3_json')")
44
- num_images: int = Field(..., description="Number of images to generate")
45
- use_augmentation: bool = Field(True, description="Whether to use image augmentation")
46
-
47
- class InpaintingBatchRequestModel(BaseModel):
 
 
 
48
  """
49
- Model representing a batch request for inpainting inference.
50
  """
51
- requests: List[InpaintingRequest]
52
 
53
  async def save_image(image: UploadFile) -> str:
54
  """
55
  Save an uploaded image to a temporary file and return the file path.
56
-
57
- Args:
58
- image (UploadFile): The uploaded image file.
59
-
60
- Returns:
61
- str: File path where the image is saved.
62
  """
63
  file_name = f"{uuid.uuid4()}.png"
64
  file_path = os.path.join("/tmp", file_name)
@@ -66,149 +50,75 @@ async def save_image(image: UploadFile) -> str:
66
  f.write(await image.read())
67
  return file_path
68
 
69
- def augment_image(image_path, target_width, target_height, roi_scale, segmentation_model_name, detection_model_name):
70
- """
71
- Augment an image by extending its dimensions and generating masks.
72
-
73
- Args:
74
- image_path (str): Path to the image file.
75
- target_width (int): Target width for augmentation.
76
- target_height (int): Target height for augmentation.
77
- roi_scale (float): Scale factor for region of interest.
78
- segmentation_model_name (str): Name of the segmentation model.
79
- detection_model_name (str): Name of the detection model.
80
-
81
- Returns:
82
- Tuple[Image.Image, Image.Image]: Augmented image and inverted mask.
83
- """
84
- image = Image.open(image_path)
85
- image_augmentation = ImageAugmentation(target_width, target_height, roi_scale)
86
- image = image_augmentation.extend_image(image)
87
- mask = image_augmentation.generate_mask_from_bbox(image, segmentation_model_name, detection_model_name)
88
- inverted_mask = image_augmentation.invert_mask(mask)
89
- return image, inverted_mask
90
-
91
- def run_inference(cfg, image_path: str, request: InpaintingRequest):
92
- """
93
- Run inference using an inpainting pipeline on an image.
94
-
95
- Args:
96
- cfg (dict): Configuration dictionary.
97
- image_path (str): Path to the image file.
98
- request (InpaintingRequest): Pydantic model containing inference parameters.
99
-
100
- Returns:
101
- dict: Resulting image in the specified mode ('b64_json' or 's3_json').
102
-
103
- Raises:
104
- ValueError: If an invalid mode is provided.
105
- """
106
- if request.use_augmentation:
107
- image, mask_image = augment_image(image_path,
108
- cfg['target_width'],
109
- cfg['target_height'],
110
- cfg['roi_scale'],
111
- cfg['segmentation_model'],
112
- cfg['detection_model'])
113
- else:
114
- image = Image.open(image_path)
115
- mask_image = None
116
-
117
- painting_pipeline = AutoPaintingPipeline(
118
- pipeline=inpainting_pipeline,
119
- image=image,
120
- mask_image=mask_image,
121
- target_height=cfg['target_height'],
122
- target_width=cfg['target_width']
123
  )
124
- output = painting_pipeline.run_inference(prompt=request.prompt,
125
- negative_prompt=request.negative_prompt,
126
- num_inference_steps=request.num_inference_steps,
127
- strength=request.strength,
128
- guidance_scale=request.guidance_scale,
129
- num_images=request.num_images)
130
- if request.mode == "s3_json":
131
- return pil_to_s3_json(output, file_name="output.png")
132
- elif request.mode == "b64_json":
133
- return pil_to_b64_json(output)
134
- else:
135
- raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.")
136
-
137
- class InpaintingBatcher(AsyncBatcher):
138
- async def process_batch(self, batch: Tuple[List[str], List[InpaintingRequest]]) -> List[Dict[str, Any]]:
139
- """
140
- Process a batch of images and requests for inpainting inference.
141
 
142
- Args:
143
- batch (Tuple[List[str], List[InpaintingRequest]]): Tuple of image paths and corresponding requests.
144
-
145
- Returns:
146
- List[Dict[str, Any]]: List of resulting images in the specified mode ('b64_json' or 's3_json').
147
- """
148
- image_paths, requests = batch
149
- results = []
150
- for image_path, request in zip(image_paths, requests):
151
- result = run_inference(cfg, image_path, request)
152
- results.append(result)
153
- return results
154
-
155
- @router.post("/inpainting")
156
- async def inpainting_inference(
157
  image: UploadFile = File(...),
158
- request_data: str = Form(...),
159
  ):
160
- """
161
- Handle POST request for inpainting inference.
162
-
163
- Args:
164
- image (UploadFile): Uploaded image file.
165
- request_data (str): JSON string of the request parameters.
166
-
167
- Returns:
168
- dict: Resulting image in the specified mode ('b64_json' or 's3_json').
169
-
170
- Raises:
171
- HTTPException: If there is an error during image processing.
172
- """
173
  try:
 
 
 
174
  image_path = await save_image(image)
175
- request_dict = json.loads(request_data)
176
- request = InpaintingRequest(**request_dict)
177
- result = run_inference(cfg, image_path, request)
178
- return result
 
 
 
 
 
 
 
179
  except Exception as e:
180
  raise HTTPException(status_code=500, detail=str(e))
181
 
182
- @router.post("/inpainting/batch")
183
- async def inpainting_batch_inference(
184
- images: List[UploadFile] = File(...),
185
- request_data: str = Form(...),
186
- ):
187
- """
188
- Handle POST request for batch inpainting inference.
189
-
190
- Args:
191
- images (List[UploadFile]): List of uploaded image files.
192
- request_data (str): JSON string of the request parameters.
193
-
194
- Returns:
195
- List[dict]: List of resulting images in the specified mode ('b64_json' or 's3_json').
196
 
197
- Raises:
198
- HTTPException: If there is an error during image processing.
199
- """
200
  try:
201
- request_dict = json.loads(request_data)
202
- batch_request = InpaintingBatchRequestModel(**request_dict)
203
- requests = batch_request.requests
204
-
205
- if len(images) != len(requests):
206
- raise HTTPException(status_code=400, detail="The number of images and requests must match.")
207
-
208
- batcher = InpaintingBatcher(max_batch_size=64)
209
- image_paths = [await save_image(image) for image in images]
210
- results = batcher.process_batch((image_paths, requests))
211
-
212
- return results
 
213
  except Exception as e:
214
  raise HTTPException(status_code=500, detail=str(e))
 
1
  import os
2
  import uuid
3
+ import json
4
+ from typing import List
5
  from fastapi import APIRouter, File, UploadFile, HTTPException, Form
6
+ from pydantic import BaseModel, Field, ValidationError
 
7
  import lightning.pytorch as pl
8
+ from scripts.api_utils import pil_to_b64_json
9
+ from scripts.outpainting import ControlNetZoeDepthOutpainting
 
10
  from async_batcher.batcher import AsyncBatcher
 
11
  from functools import lru_cache
12
+
13
  pl.seed_everything(42)
14
  router = APIRouter()
15
 
 
 
 
 
 
16
  @lru_cache(maxsize=1)
17
+ def load_pipeline():
18
+ outpainting_pipeline = ControlNetZoeDepthOutpainting(target_size=(1024, 1024))
19
+ return outpainting_pipeline
20
 
21
+ class OutpaintingRequest(BaseModel):
 
22
  """
23
+ Model representing a request for outpainting inference.
 
 
 
 
 
 
24
  """
25
+ controlnet_prompt: str = Field(...)
26
+ controlnet_negative_prompt: str = Field(...)
27
+ controlnet_conditioning_scale: float = Field(...)
28
+ controlnet_guidance_scale: float = Field(...)
29
+ controlnet_num_inference_steps: int = Field(...)
30
+ controlnet_guidance_end: float = Field(...)
31
+ inpainting_prompt: str = Field(...)
32
+ inpainting_negative_prompt: str = Field(...)
33
+ inpainting_guidance_scale: float = Field(...)
34
+ inpainting_strength: float = Field(...)
35
+ inpainting_num_inference_steps: int = Field(...)
36
+
37
+ class OutpaintingBatchRequestModel(BaseModel):
38
  """
39
+ Model representing a batch request for outpainting inference.
40
  """
41
+ requests: List[OutpaintingRequest]
42
 
43
  async def save_image(image: UploadFile) -> str:
44
  """
45
  Save an uploaded image to a temporary file and return the file path.
 
 
 
 
 
 
46
  """
47
  file_name = f"{uuid.uuid4()}.png"
48
  file_path = os.path.join("/tmp", file_name)
 
50
  f.write(await image.read())
51
  return file_path
52
 
53
+ def run_inference(image_path: str, request: OutpaintingRequest):
54
+ pipeline = load_pipeline()
55
+ result = pipeline.run_pipeline(
56
+ image_path,
57
+ controlnet_prompt=request.controlnet_prompt,
58
+ controlnet_negative_prompt=request.controlnet_negative_prompt,
59
+ controlnet_conditioning_scale=request.controlnet_conditioning_scale,
60
+ controlnet_guidance_scale=request.controlnet_guidance_scale,
61
+ controlnet_num_inference_steps=request.controlnet_num_inference_steps,
62
+ controlnet_guidance_end=request.controlnet_guidance_end,
63
+ inpainting_prompt=request.inpainting_prompt,
64
+ inpainting_negative_prompt=request.inpainting_negative_prompt,
65
+ inpainting_guidance_scale=request.inpainting_guidance_scale,
66
+ inpainting_strength=request.inpainting_strength,
67
+ inpainting_num_inference_steps=request.inpainting_num_inference_steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  )
69
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ @router.post("/outpaint")
72
+ async def outpaint(
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  image: UploadFile = File(...),
74
+ request: str = Form(...)
75
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  try:
77
+ request_dict = json.loads(request)
78
+ outpainting_request = OutpaintingRequest(**request_dict)
79
+
80
  image_path = await save_image(image)
81
+ result = run_inference(image_path, outpainting_request)
82
+
83
+ result_json = pil_to_b64_json(result)
84
+
85
+ os.remove(image_path)
86
+
87
+ return {"result": result_json}
88
+ except json.JSONDecodeError:
89
+ raise HTTPException(status_code=400, detail="Invalid JSON in request data")
90
+ except ValidationError as e:
91
+ raise HTTPException(status_code=422, detail=str(e))
92
  except Exception as e:
93
  raise HTTPException(status_code=500, detail=str(e))
94
 
95
+ class OutpaintingBatcher(AsyncBatcher):
96
+ async def process_batch(self, batch):
97
+ results = []
98
+ for image, request in batch:
99
+ image_path = await save_image(image)
100
+ try:
101
+ result = run_inference(image_path, request)
102
+ results.append(result)
103
+ finally:
104
+ os.remove(image_path)
105
+ return results
 
 
 
106
 
107
+ @router.post("/batch_outpaint")
108
+ async def batch_outpaint(images: List[UploadFile] = File(...), batch_request: str = Form(...)):
 
109
  try:
110
+ batch_request_dict = json.loads(batch_request)
111
+ batch_outpainting_request = OutpaintingBatchRequestModel(**batch_request_dict)
112
+
113
+ batcher = OutpaintingBatcher(max_queue_size=64)
114
+ results = await batcher.process_batch(list(zip(images, batch_outpainting_request.requests)))
115
+
116
+ result_jsons = [pil_to_b64_json(result) for result in results]
117
+
118
+ return {"results": result_jsons}
119
+ except json.JSONDecodeError:
120
+ raise HTTPException(status_code=400, detail="Invalid JSON in batch request data")
121
+ except ValidationError as e:
122
+ raise HTTPException(status_code=422, detail=str(e))
123
  except Exception as e:
124
  raise HTTPException(status_code=500, detail=str(e))