Spaces:
Build error
Build error
| """ | |
| project @ NTO-TCP-HF | |
| created @ 2024-10-28 | |
| author @ github.com/ishworrsubedii | |
| """ | |
| import base64 | |
| import os | |
| import time | |
| from io import BytesIO | |
| import cv2 | |
| import numpy as np | |
| import replicate | |
| import requests | |
| from PIL import Image | |
| from fastapi import APIRouter, UploadFile, File, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from src.components.auto_crop import crop_transparent_image | |
| from src.components.color_extraction import ColorExtractionRMBG | |
| from src.components.title_des_gen import NecklaceProductListing | |
| from src.utils.logger import logger | |
# Router for the preprocessing endpoints defined in this module.
# NOTE(review): none of the handlers visible in this file carry a
# ``@preprocessing_router.<method>(...)`` decorator — confirm they are registered.
preprocessing_router = APIRouter()

# Configuration read from the environment. os.environ.get yields None for an
# unset variable, so despite the ``str`` annotations these may hold None.
rmbg: str = os.environ.get("RMBG")  # Replicate model id used by replicate_bg
enhancer: str = os.environ.get("ENHANCER")  # Replicate model id used by replicate_enhancer
prod_listing_api_key: str = os.environ.get("PROD_LISTING_API_KEY")

# Helper objects shared by every request; constructed once at import time.
color_extraction_rmbg = ColorExtractionRMBG()
product_listing_obj = NecklaceProductListing(prod_listing_api_key)
def replicate_bg(input):
    """Run the background-removal model named by ``rmbg`` (env ``RMBG``).

    Args:
        input: Model input mapping passed straight to ``replicate.run``
            (``remove_background`` sends ``{"image": <data URI>}``). The name
            shadows the builtin but is kept so keyword callers keep working.

    Returns:
        Whatever ``replicate.run`` returns; ``remove_background`` passes it
        to ``requests.get``.
    """
    return replicate.run(rmbg, input=input)
def replicate_enhancer(input):
    """Run the upscaling model named by ``enhancer`` (env ``ENHANCER``).

    Args:
        input: Model input mapping passed straight to ``replicate.run``
            (``upscale_image`` sends ``image``, ``scale`` and ``face_enhance``).
            The name shadows the builtin but is kept for caller compatibility.

    Returns:
        Whatever ``replicate.run`` returns; ``upscale_image`` passes it to
        ``requests.get``.
    """
    return replicate.run(enhancer, input=input)
async def remove_background(image: UploadFile = File(...)):
    """Remove the background of an uploaded image via the Replicate RMBG model.

    The upload is decoded with Pillow, converted to RGB, re-encoded as WEBP and
    sent to ``replicate_bg`` as a base64 data URI. The model output is then
    downloaded and inlined into the response.

    Args:
        image: The uploaded image file.

    Returns:
        JSONResponse(200) with ``output`` (``data:image/WEBP;base64,...``),
        ``inference_time`` in seconds (rounded to 2 decimals) and ``code``;
        JSONResponse(500) with ``error`` and ``code`` if any stage fails.
    """
    # Local import keeps this edit self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    # Upper bound (seconds) for fetching the model output so a stalled
    # download cannot hang the request indefinitely.
    download_timeout_s = 60

    logger.info("-" * 50)
    logger.info(">>> REMOVE BACKGROUND STARTED <<<")
    start_time = time.time()

    try:
        image_bytes = await image.read()
        pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")
        logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
    except Exception as e:
        logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500, content={"error": f"Error reading image: {str(e)}", "code": 500})

    try:
        buffer = BytesIO()
        pil_image.save(buffer, format="WEBP")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        image_data_uri = f"data:image/WEBP;base64,{encoded}"
        logger.info(">>> IMAGE ENCODING COMPLETED <<<")
    except Exception as e:
        logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error converting image to base64: {str(e)}", "code": 500})

    try:
        # replicate_bg is a synchronous call; run it in a worker thread so the
        # event loop keeps serving other requests while the model runs.
        output = await run_in_threadpool(replicate_bg, {"image": image_data_uri})
        logger.info(">>> BACKGROUND REMOVAL COMPLETED <<<")
    except Exception as e:
        logger.error(f">>> BACKGROUND REMOVAL ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error running background removal: {str(e)}", "code": 500})

    try:
        download = await run_in_threadpool(requests.get, output, timeout=download_timeout_s)
        # Without this check an HTTP error page would be base64-encoded and
        # returned to the client as though it were the image.
        download.raise_for_status()
        base_64 = base64.b64encode(download.content).decode("utf-8")
        base64_prefix = "data:image/WEBP;base64,"
        total_inference_time = round((time.time() - start_time), 2)
        payload = {
            "output": f"{base64_prefix}{base_64}",
            "inference_time": total_inference_time,
            "code": 200
        }
        logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
        logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
        logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
        logger.info("-" * 50)
        return JSONResponse(content=payload, status_code=200)
    except Exception as e:
        logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error processing response: {str(e)}", "code": 500})
async def upscale_image(image: UploadFile = File(...), scale: int = 1):
    """Upscale an uploaded image via the Replicate enhancer model.

    The upload is decoded with Pillow, converted to RGBA, re-encoded as PNG and
    sent to ``replicate_enhancer`` with ``scale`` and ``face_enhance=False``.
    The model output is then downloaded and inlined into the response.

    Args:
        image: The uploaded image file.
        scale: Upscaling factor forwarded unchanged to the model.

    Returns:
        JSONResponse(200) with ``output`` (``data:image/png;base64,...``),
        ``inference_time`` in seconds (rounded to 2 decimals) and ``code``;
        JSONResponse(500) with ``error`` and ``code`` if any stage fails.
    """
    # Local import keeps this edit self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    # Upper bound (seconds) for fetching the model output so a stalled
    # download cannot hang the request indefinitely.
    download_timeout_s = 60

    logger.info("-" * 50)
    logger.info(">>> IMAGE UPSCALING STARTED <<<")
    start_time = time.time()

    try:
        image_bytes = await image.read()
        pil_image = Image.open(BytesIO(image_bytes)).convert("RGBA")
        logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
    except Exception as e:
        logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500, content={"error": f"Error reading image: {str(e)}", "code": 500})

    try:
        buffer = BytesIO()
        pil_image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        image_data_uri = f"data:image/png;base64,{encoded}"
        logger.info(">>> IMAGE ENCODING COMPLETED <<<")
    except Exception as e:
        logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error converting image to base64: {str(e)}", "code": 500})

    try:
        model_input = {
            "image": image_data_uri,
            "scale": scale,
            "face_enhance": False
        }
        # replicate_enhancer is a synchronous call; run it in a worker thread so
        # the event loop keeps serving other requests while the model runs.
        output = await run_in_threadpool(replicate_enhancer, model_input)
        logger.info(">>> IMAGE ENHANCEMENT COMPLETED <<<")
    except Exception as e:
        logger.error(f">>> IMAGE ENHANCEMENT ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error running image enhancement: {str(e)}", "code": 500})

    try:
        download = await run_in_threadpool(requests.get, output, timeout=download_timeout_s)
        # Reject HTTP errors instead of returning an error page as image data.
        download.raise_for_status()
        base_64 = base64.b64encode(download.content).decode("utf-8")
        # Reuse the input's "data:image/png;base64," header.
        # NOTE(review): assumes the enhancer returns PNG — confirm.
        base64_prefix = image_data_uri.split(",")[0] + ","
        total_inference_time = round((time.time() - start_time), 2)
        payload = {
            "output": f"{base64_prefix}{base_64}",
            "inference_time": total_inference_time,
            "code": 200
        }
        logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
        logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
        logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
        logger.info("-" * 50)
        return JSONResponse(content=payload, status_code=200)
    except Exception as e:
        logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error processing response: {str(e)}", "code": 500})
async def crop_transparent(image: UploadFile):
    """Crop an uploaded PNG with ``crop_transparent_image``.

    ``crop_transparent_image`` presumably trims transparent margins (inferred
    from its name — confirm). It returns the cropped bytes plus a metadata
    object, and both are passed back to the client.

    Args:
        image: Uploaded file; only a declared ``image/png`` content type is accepted.

    Returns:
        JSONResponse(200) ``{"status", "code", "data": {"image", "metadata",
        "inference_time"}}`` where ``image`` is a PNG data URI;
        JSONResponse(400) for a non-PNG content type;
        JSONResponse(500) if cropping or response building fails.
    """
    logger.info("-" * 50)
    logger.info(">>> CROP TRANSPARENT STARTED <<<")
    start_time = time.time()

    # Only the client-declared Content-Type is checked; the bytes are not sniffed here.
    if image.content_type != "image/png":
        logger.error(">>> INVALID FILE TYPE: NOT PNG <<<")
        return JSONResponse(status_code=400,
                            content={"error": "Only PNG files are supported", "code": 400})

    try:
        contents = await image.read()
        cropped_image_bytes, metadata = crop_transparent_image(contents)
        logger.info(">>> IMAGE CROPPING COMPLETED <<<")
    except Exception as e:
        logger.error(f">>> IMAGE CROPPING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error cropping image: {str(e)}", "code": 500})

    try:
        base64_image = base64.b64encode(cropped_image_bytes).decode("utf-8")
        base64_prefix = "data:image/png;base64,"
        total_inference_time = round((time.time() - start_time), 2)
        logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
        logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
        logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
        logger.info("-" * 50)
        # metadata must be JSON-serialisable; if not, JSONResponse raises and
        # the handler below turns that into a 500.
        return JSONResponse(content={
            "status": "success",
            "code": 200,
            "data": {
                "image": f"{base64_prefix}{base64_image}",
                "metadata": metadata,
                "inference_time": total_inference_time
            }
        }, status_code=200)
    except Exception as e:
        logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error processing response: {str(e)}", "code": 500})
async def bg_replace(image: UploadFile = File(...), bg_image: UploadFile = File(...)):
    """Composite an uploaded foreground onto an uploaded background image.

    The foreground is converted to RGBA and resized to exactly the background's
    size, so its aspect ratio is not preserved. It is then pasted at the origin,
    using its own alpha channel as the mask. The composite is returned as a
    WEBP data URI.

    Args:
        image: Foreground image file (its alpha channel decides what is kept).
        bg_image: Background image file; its size determines the output size.

    Returns:
        JSONResponse(200) with ``output`` (``data:image/webp;base64,...``),
        ``code`` and ``inference_time``; JSONResponse(500) with ``error`` and
        ``code`` on failure.
    """
    logger.info("-" * 50)
    logger.info(">>> BACKGROUND REPLACE STARTED <<<")
    start_time = time.time()

    try:
        image_bytes = await image.read()
        bg_bytes = await bg_image.read()
        foreground = Image.open(BytesIO(image_bytes)).convert("RGBA")
        background = Image.open(BytesIO(bg_bytes)).convert("RGB")
        logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
    except Exception as e:
        logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error reading images: {str(e)}", "code": 500})

    try:
        # Work on the decoded PIL images directly: converting through numpy and
        # resizing the background to its own size would only add copies.
        foreground = foreground.resize(background.size)
        # Passing the RGBA foreground as the mask uses its alpha channel.
        background.paste(foreground, (0, 0), mask=foreground)
        logger.info(">>> IMAGE PROCESSING COMPLETED <<<")
    except Exception as e:
        logger.error(f">>> IMAGE PROCESSING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error processing images: {str(e)}", "code": 500})

    try:
        buffer = BytesIO()
        background.save(buffer, format="WEBP")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        image_data_uri = f"data:image/webp;base64,{encoded}"
        total_inference_time = round((time.time() - start_time), 2)
        logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
        logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
        logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
        logger.info("-" * 50)
        return JSONResponse(content={
            "output": image_data_uri,
            "code": 200,
            "inference_time": total_inference_time
        }, status_code=200)
    except Exception as e:
        logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error creating response: {str(e)}", "code": 500})
async def remove_background_color_extraction(image: UploadFile = File(...),
                                             hex_color: str = "#FFFFFF",
                                             threshold: int = 30):
    """Strip pixels matching ``hex_color`` using ``ColorExtractionRMBG``.

    Args:
        image: Uploaded image file.
        hex_color: Target colour forwarded to ``extract_color`` (default white).
        threshold: Tolerance forwarded to ``extract_color``; its exact meaning
            is defined by that helper.

    Returns:
        JSONResponse(200) with ``output`` (a PNG data URI), ``code`` and
        ``inference_time``; JSONResponse(500) with ``error``/``code`` on failure.
    """
    divider = "-" * 50
    logger.info(divider)
    logger.info(">>> COLOR EXTRACTION STARTED <<<")
    began = time.time()

    try:
        uploaded = await image.read()
        rgba_image = Image.open(BytesIO(uploaded)).convert("RGBA")
        # OpenCV's RGB2BGR accepts the 4-channel array but emits 3 channels,
        # so any alpha from the upload is discarded here.
        bgr_array = cv2.cvtColor(np.array(rgba_image), cv2.COLOR_RGB2BGR)
        logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
    except Exception as exc:
        logger.error(f">>> IMAGE LOADING ERROR: {str(exc)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error reading image: {str(exc)}", "code": 500})

    try:
        extracted = color_extraction_rmbg.extract_color(bgr_array, hex_color, threshold)
        # Swap R/B and ensure four channels before handing the array to Pillow.
        # NOTE(review): assumes extract_color returns a BGR/BGRA uint8 array — confirm.
        swapped = cv2.cvtColor(extracted, cv2.COLOR_RGB2BGRA)
        result_image = Image.fromarray(swapped).convert("RGBA")
        logger.info(">>> COLOR EXTRACTION COMPLETED <<<")
    except Exception as exc:
        logger.error(f">>> COLOR EXTRACTION ERROR: {str(exc)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error extracting colors: {str(exc)}", "code": 500})

    try:
        png_buffer = BytesIO()
        result_image.save(png_buffer, format="PNG")
        data_uri = "data:image/png;base64," + base64.b64encode(png_buffer.getvalue()).decode("utf-8")
        elapsed = round((time.time() - began), 2)
        logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
        logger.info(f">>> TOTAL INFERENCE TIME: {elapsed}s <<<")
        logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
        logger.info(divider)
        body = {"output": data_uri, "code": 200, "inference_time": elapsed}
        return JSONResponse(content=body, status_code=200)
    except Exception as exc:
        logger.error(f">>> RESPONSE PROCESSING ERROR: {str(exc)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error creating response: {str(exc)}", "code": 500})
async def product_title_description_generator(image: UploadFile = File(...)):
    """Generate a listing title and description for an uploaded necklace photo.

    The image is decoded as RGB and passed to
    ``product_listing_obj.gen_title_desc``. Its text output is expected to
    contain a ``Title:`` marker followed by a ``Description:`` marker.

    Args:
        image: The uploaded product image.

    Returns:
        JSONResponse(200) with ``code``, ``title``, ``description`` (both with
        surrounding whitespace stripped) and ``inference_time``.
        JSONResponse(500) if the image cannot be read, generation fails, or the
        generated text lacks either marker.
    """
    logger.info("-" * 50)
    logger.info(">>> TITLE DESCRIPTION GENERATION STARTED <<<")
    start_time = time.time()

    try:
        image_bytes = await image.read()
        pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")
        logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
    except Exception as e:
        logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error reading image: {str(e)}", "code": 500})

    try:
        # NOTE(review): gen_title_desc is called synchronously inside this async
        # handler; if it performs network I/O it blocks the event loop.
        result = product_listing_obj.gen_title_desc(image=pil_image)
        # partition() splits on the first occurrence only, so a description that
        # itself contains "Description:" is kept whole rather than truncated.
        _, title_marker, after_title = result.partition("Title:")
        title, desc_marker, description = after_title.partition("Description:")
        if not title_marker or not desc_marker:
            raise ValueError("generated text lacks a 'Title:' or 'Description:' marker")
        # Drop the whitespace/newlines that surround each section in the raw text.
        title = title.strip()
        description = description.strip()
        logger.info(">>> TITLE AND DESCRIPTION GENERATION COMPLETED <<<")
    except Exception as e:
        # Log the actual cause server-side; the client still gets the generic hint.
        logger.error(f">>> TITLE DESCRIPTION GENERATION ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": "Please make sure the image is clear and necklaces are visible",
                                     "code": 500})

    try:
        total_inference_time = round((time.time() - start_time), 2)
        logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
        logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
        logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
        logger.info("-" * 50)
        return JSONResponse(content={
            "code": 200,
            "title": title,
            "description": description,
            "inference_time": total_inference_time
        }, status_code=200)
    except Exception as e:
        logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
        return JSONResponse(status_code=500,
                            content={"error": f"Error creating response: {str(e)}", "code": 500})