# zestBG / app.py
# krunakuamar's picture
# Update app.py
# 98ec77e
import base64
import imghdr
import os
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from ultralytics.yolo.utils.ops import scale_image
import asyncio
from fastapi import FastAPI, File, UploadFile, Request, Response
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# from mangum import Mangum
from argparse import ArgumentParser
import lama_cleaner.server2 as server
from lama_cleaner.helper import (
load_img,
)
# os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/directory"
app = FastAPI()
# handler = Mangum(app)  # optional AWS Lambda adapter (mangum), currently disabled

# CORS is wide open: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): "*" origins together with allow_credentials=True is very
# permissive; consider an explicit origin allow-list for production.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
    """Encode an image array into the bytes of an image file.

    Args:
        image_numpy: image array accepted by ``cv2.imencode``.
        ext: target format extension such as ``"png"`` or ``"jpeg"``; a
            leading dot (``".png"``) is also accepted.

    Returns:
        The encoded file contents. The encoder is given JPEG quality 100 and
        PNG compression level 0.

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    """
    ext = ext.lstrip(".")
    params = [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0]
    ok, buffer = cv2.imencode(f".{ext}", image_numpy, params)
    if not ok:
        # imencode can report failure through its flag instead of raising;
        # returning the (empty) buffer would hand callers a broken image.
        raise ValueError(f"Failed to encode image as {ext!r}")
    return buffer.tobytes()
def get_image_ext(img_bytes):
    """Guess the image format of ``img_bytes`` from its first 32 bytes.

    Args:
        img_bytes: raw encoded image data.

    Returns:
        The format name reported by ``imghdr`` (e.g. ``"png"``), or ``"jpeg"``
        when the header is not recognised.

    Raises:
        ValueError: if ``img_bytes`` is empty.
    """
    if not img_bytes:
        raise ValueError("Empty input")
    detected = imghdr.what("", img_bytes[:32])
    return "jpeg" if detected is None else detected
def predict_on_image(model, img, conf, retina_masks):
    """Run a YOLOv8 segmentation model on one image and unpack its outputs.

    Args:
        model: YOLOv8 model; called as ``model(img, conf=..., retina_masks=...,
            scale=1)`` and the first result is used.
        img: decoded image array, (H, W, C).
        conf: confidence threshold forwarded to the model.
        retina_masks: forwarded to the model's ``retina_masks`` option.

    Returns:
        ``(boxes, masks, cls, probs)``: boxes in xyxy format (N, 4); masks
        (N, H, W) rescaled to ``result.masks.orig_shape``; integer class ids
        (N,); confidence scores (N,). All four are ``None`` when the model
        returns no detections.
    """
    with torch.no_grad():
        prediction = model(img, conf=conf, retina_masks=retina_masks, scale=1)[0]
        detections = prediction.boxes
        if not detections.cls.size(0) > 0:
            return None, None, None, None

        class_ids = detections.cls.cpu().numpy().astype(np.int32)
        scores = detections.conf.cpu().numpy()
        xyxy = detections.xyxy.cpu().numpy()

        # scale_image works on (H, W, N), so move the mask axis last, rescale
        # to the original image shape, then move it back to (N, H, W).
        stacked = prediction.masks.masks.cpu().numpy().transpose(1, 2, 0)
        stacked = scale_image(stacked.shape[:2], stacked, prediction.masks.orig_shape)
        seg_masks = stacked.transpose(2, 0, 1)

    return xyxy, seg_masks, class_ids, scores
def overlay(image, mask, color, alpha, id, resize=None, *,
            outline_color=(46, 36, 225), contour_thickness=8):
    """Render a binary mask as a solid, slightly grown silhouette.

    The mask region of ``image`` is painted with ``color``. Every pixel that is
    non-black afterwards counts as foreground. The foreground contours are
    stroked with ``contour_thickness``, which grows the shape outward by
    roughly half that width. The resulting area is then filled with
    ``outline_color``.

    Args:
        image: 3-channel BGR image; it is not modified.
        mask: (H, W) mask; any non-zero value marks a masked pixel.
        color: RGB colour painted into the mask region. It only influences
            which pixels count as foreground, not the output colour.
        alpha: unused; kept for backward compatibility.
        id: unused; kept for backward compatibility.
        resize: unused; kept for backward compatibility.
        outline_color: BGR colour of the returned silhouette.
        contour_thickness: stroke width in pixels used to grow the silhouette.

    Returns:
        (H, W, 3) uint8 image with ``outline_color`` on the silhouette and 0
        elsewhere.
    """
    painted = image.copy()
    # `color` is given as RGB while the image is BGR.
    painted[np.asarray(mask).astype(bool)] = color[::-1]

    gray = cv2.cvtColor(painted, cv2.COLOR_BGR2GRAY)
    # The 4th argument of cv2.threshold is the threshold *type*; use an
    # explicit binary threshold so every non-black pixel becomes 255.
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # Stroking the contours grows the silhouette by about half the thickness.
    cv2.drawContours(binary, contours, -1, 255, contour_thickness)

    silhouette = np.zeros((*binary.shape, 3), dtype=np.uint8)
    silhouette[binary > 0] = outline_color
    return silhouette
async def process_mask(idx, mask_i, boxes, probs, yolo_model, blank_image, cls):
    """Build the JSON-ready description of detection ``idx``.

    The mask is rendered with ``overlay`` onto ``blank_image``. It is given an
    alpha channel that is opaque wherever any colour channel is non-zero,
    PNG-encoded and then base64-encoded. Rendering and encoding are
    CPU-bound, so both run in the default executor rather than on the event
    loop. This lets the callers that ``gather`` several masks overlap.

    Args:
        idx: index of this detection in ``boxes``/``probs``/``cls``.
        mask_i: (H, W) mask of this detection.
        boxes: boxes in xyxy format, (N, 4).
        probs: confidence scores, (N,).
        yolo_model: model whose ``names`` maps class ids to class names.
        blank_image: all-zero canvas the mask is drawn on.
        cls: integer class ids, (N,).

    Returns:
        dict with keys ``confi`` (confidence in percent, two decimals),
        ``boxe`` (box coordinates as ints), ``mask`` (base64 PNG string) and
        ``cls`` (class name).

    Raises:
        RuntimeError: if OpenCV fails to encode the PNG.
    """
    loop = asyncio.get_running_loop()
    mask = await loop.run_in_executor(None, _render_mask_png, mask_i, blank_image, idx)
    return {
        "confi": f'{probs[idx] * 100:.2f}',
        "boxe": [int(item) for item in list(boxes[idx])],
        "mask": mask,
        "cls": str(yolo_model.names[cls[idx]]),
    }


def _render_mask_png(mask_i, blank_image, idx):
    """Render one mask as an RGBA PNG and return it base64-encoded (blocking)."""
    rendered = overlay(blank_image, mask_i, color=(255, 155, 155), alpha=0.5, id=idx)
    # Opaque wherever any colour channel is set, transparent elsewhere.
    alpha = np.uint8((np.sum(rendered, axis=-1) > 0) * 255)
    rgba = np.dstack((rendered, alpha))
    ok, encoded = cv2.imencode('.png', rgba)
    if not ok:
        raise RuntimeError(f"PNG encoding failed for mask {idx}")
    return base64.b64encode(encoded).decode('utf-8')
# @app.middleware("http")
# async def check_auth_header(request: Request, call_next):
# token = request.headers.get('Authorization')
# if token != os.environ.get("SECRET"):
# return JSONResponse(content={'error': 'Authorization header missing or incorrect.'}, status_code=403)
# else:
# response = await call_next(request)
# return response
@app.post("/api/mask")
async def detect_mask(file: UploadFile = File()):
    """Segment objects in an uploaded image with the YOLOv8 model.

    Form fields:
        file: the image to analyse.

    Returns:
        A JSON-serialisable dict. On success it is
        ``{'code': 200, 'msg': "object detected", 'data': [...]}``, where each
        ``data`` entry is produced by ``process_mask`` and holds ``confi``,
        ``boxe``, ``mask`` (base64 PNG) and ``cls``. When nothing is detected
        it is ``{'code': 500, 'msg': 'No objects detected'}``. That code lives
        in the body only; the route sets no HTTP status, so the default applies.
    """
    payload = await file.read()
    image, _ = load_img(payload)

    boxes, masks, cls, probs = predict_on_image(yolo_model, image, conf=0.55, retina_masks=True)
    if boxes is None:
        return {'code': 500, 'msg': 'No objects detected'}

    # Every mask is drawn on its own copy of this all-black canvas.
    canvas = np.zeros(image.shape, dtype=np.uint8)
    jobs = (
        process_mask(idx, mask_i, boxes, probs, yolo_model, canvas, cls)
        for idx, mask_i in enumerate(masks)
    )
    data = list(await asyncio.gather(*jobs))
    return {'code': 200, 'msg': "object detected", 'data': data}
@app.post("/api/lama/paint")
async def paint(img: UploadFile = File(), mask: UploadFile = File()):
    """Inpaint the masked area of an image with the LaMa-cleaner server.

    Route: POST /api/lama/paint

    Form fields:
        img: the source image file.
        mask: the mask file marking the area to fill.

    Returns:
        ``{"image": ...}`` holding the value returned by ``server.process``
        (presumably a base64-encoded image; confirm in lama_cleaner.server2).
    """
    image_bytes = await img.read()
    mask_bytes = await mask.read()
    result = server.process(image_bytes, mask_bytes)
    return {"image": result}
@app.post("/api/remove")
async def remove(img: UploadFile = File()):
    """Run ``server.remove`` on an uploaded image (presumably background removal).

    Form fields:
        img: the image file to process.

    Returns:
        ``{"image": ...}`` holding the value returned by ``server.remove``.
    """
    raw_image = await img.read()
    processed = server.remove(raw_image)
    return {"image": processed}
@app.post("/api/lama/model")
def switch_model(new_name: str):
    """Ask the LaMa-cleaner server to switch to the model ``new_name``.

    ``new_name`` is a plain ``str`` parameter, so FastAPI reads it from the
    query string. Returns whatever ``server.switch_model`` returns.
    """
    outcome = server.switch_model(new_name)
    return outcome
@app.get("/api/lama/model")
def current_model():
    """Report the LaMa-cleaner server's active model (``server.current_model``)."""
    active = server.current_model()
    return active
@app.get("/api/lama/switchmode")
def get_is_disable_model_switch():
    """Report whether model switching is disabled, as told by the LaMa server."""
    disabled = server.get_is_disable_model_switch()
    return disabled
@app.on_event("startup")
def init_data():
    """Load the YOLOv8 segmentation model and the LaMa-cleaner model at startup.

    The YOLO weights path and device default to the container layout
    (``/app/yolov8x-seg.pt`` on ``cpu``). They can be overridden with the
    ``YOLO_MODEL_PATH`` and ``YOLO_MODEL_DEVICE`` environment variables, e.g.
    ``YOLO_MODEL_PATH=yolov8x-seg.pt`` for local development. The loaded model
    is stored in the module-level ``yolo_model`` used by the request handlers.
    """
    global yolo_model
    model_path = os.environ.get("YOLO_MODEL_PATH", "/app/yolov8x-seg.pt")
    model_device = os.environ.get("YOLO_MODEL_DEVICE", "cpu")
    yolo_model = YOLO(model_path)
    yolo_model.to(model_device)
    print(f"YOLO model {os.path.basename(model_path)} loaded.")
    server.initModel()
def create_app(args):
    """Run the uvicorn server that hosts this module's ``app``.

    Despite its name, this does not build the FastAPI app, which is created at
    import time. It starts the server on ``"app:app"``.

    Args:
        args: parsed command-line namespace; only ``host``, ``port`` and
            ``reload`` are read.
    """
    uvicorn.run(
        "app:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
    )
def str_to_bool(value):
    """Convert a command-line string such as ``"False"`` or ``"yes"`` to a bool.

    argparse with ``type=bool`` calls ``bool("False")``, which is True for every
    non-empty string, so ``--reload False`` would enable reload. This converter
    parses the text instead.

    Args:
        value: the raw option string; an actual bool is returned unchanged.

    Returns:
        True for 1/true/t/yes/y/on. False for 0/false/f/no/n/off or an empty
        string (matching ``bool("")``). Comparison is case-insensitive.

    Raises:
        ValueError: for any other string; argparse reports it as an invalid value.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise ValueError(f"invalid boolean value: {value!r}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, default='lama', help='Model name')
    parser.add_argument('--host', type=str, default="0.0.0.0")
    parser.add_argument('--port', type=int, default=5000)
    parser.add_argument('--reload', type=str_to_bool, default=True)
    parser.add_argument('--model_device', type=str, default='cpu', help='Model device')
    parser.add_argument('--disable_model_switch', type=str_to_bool, default=False, help='Disable model switch')
    parser.add_argument('--gui', type=str_to_bool, default=False, help='Enable GUI')
    parser.add_argument('--cpu_offload', type=str_to_bool, default=False, help='Enable CPU offload')
    parser.add_argument('--disable_nsfw', type=str_to_bool, default=False, help='Disable NSFW')
    parser.add_argument('--enable_xformers', type=str_to_bool, default=False, help='Enable xformers')
    parser.add_argument('--hf_access_token', type=str, default='', help='Hugging Face access token')
    parser.add_argument('--local_files_only', type=str_to_bool, default=False, help='Enable local files only')
    parser.add_argument('--no_half', type=str_to_bool, default=False, help='Disable half')
    parser.add_argument('--sd_cpu_textencoder', type=str_to_bool, default=False, help='Enable CPU text encoder')
    parser.add_argument('--sd_disable_nsfw', type=str_to_bool, default=False, help='Disable NSFW')
    parser.add_argument('--sd_enable_xformers', type=str_to_bool, default=False, help='Enable xformers')
    parser.add_argument('--sd_run_local', type=str_to_bool, default=False, help='Enable local files only')
    args = parser.parse_args()
    create_app(args)