Spaces:
Build error
Build error
import os | |
import io | |
import base64 | |
import numpy as np | |
import torch | |
import time | |
from PIL import Image | |
from pydantic import BaseModel | |
from fastapi import FastAPI | |
from fastapi.responses import Response, JSONResponse | |
from fastapi.exceptions import HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
from torchvision.transforms.functional import resize | |
from .model import ( | |
build_sam_predictor, | |
build_sam_hq_predictor, | |
build_mobile_sam_predictor, | |
get_multi_label_predictor, | |
) | |
from .data import Data | |
from .configs import DATA_ROOT, DEVICE, MODEL | |
from .transforms import ResizeLongestSide | |
from .mobile_sam.utils import batched_mask_to_box | |
# ASGI application object for this module.
app = FastAPI()

# CORS is left fully open: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Build the single shared predictor selected by the MODEL setting.
# Each backend loads its own checkpoint file.
if MODEL == "sam":
    SAM = build_sam_predictor(checkpoint="sam_vit_h_4b8939.pth")
elif MODEL == "sam_hq":
    SAM = build_sam_hq_predictor(checkpoint="sam_hq_vit_h.pth")
elif MODEL == "mobile_sam":
    SAM = build_mobile_sam_predictor(checkpoint="mobile_sam.pth")
else:
    # The message lists every accepted value, including mobile_sam.
    raise ValueError(
        f"MODEL must be one of sam, sam_hq, mobile_sam, got {MODEL}"
    )
# Store for images, embeddings, size metadata and labels, backed by data.pkl
# under DATA_ROOT.
DATA = Data(DATA_ROOT / "data.pkl")

# Longest-side resizer with target length 1024; upload_image uses it to pick
# the size at which uploaded images are stored.
T = ResizeLongestSide(1024)
class SamQuery(BaseModel):
    """Point prompt sent to the predictor by ``get_label_preds``."""

    # Prompt coordinates, forwarded to ``SAM.predict`` as ``point_coords``.
    points: list[list[int]]
    # Forwarded to ``SAM.predict`` as ``point_labels``; presumably one entry
    # per point -- TODO confirm.
    labels: list[int]
class MaskLabel(BaseModel):
    """A single reference mask together with the label it stands for."""

    # Base64-encoded image; a leading "data:image/png;base64," prefix is
    # stripped by the consumer before decoding.
    mask: str
    # Label text copied onto every mask predicted from this reference.
    label: str
class Masks(BaseModel):
    """A bare list of mask strings."""

    # Not referenced in the visible part of this module; presumably base64
    # PNGs like the other mask fields -- TODO confirm.
    masks: list[str]
class MaskLabels(BaseModel):
    """Masks and labels submitted for saving by ``label_image``."""

    # Base64-encoded images; a "data:image/png;base64," prefix is tolerated.
    masks: list[str]
    # Must have the same length as ``masks`` (checked by ``label_image``).
    labels: list[str]
class Box(BaseModel):
    """Integer box described by an origin point and its extent."""

    # Origin of the box; presumably the top-left corner in pixels --
    # TODO confirm against the front end.
    x: int
    y: int
    # Extent of the box along each axis.
    width: int
    height: int
class Boxes(BaseModel):
    """A bare list of boxes."""

    # Not referenced in the visible part of this module.
    bboxes: list[Box]
class MaskBoxes(BaseModel):
    """Masks and their boxes, as returned by ``get_label_preds``."""

    # Base64-encoded PNG bytes, one string per mask.
    masks: list[str]
    # Boxes built from the same masks, in the same order.
    bboxes: list[Box]
class MaskBoxLabels(BaseModel):
    """Masks, boxes and labels returned for an image."""

    # Base64-encoded PNG bytes, one string per mask.
    masks: list[str]
    # One box per mask.
    bboxes: list[Box]
    # One label per mask.
    labels: list[str]
class ImageData(BaseModel):
    """Payload carrying an image for ``upload_image``."""

    # Base64-encoded image; PNG and JPEG data-URL prefixes are stripped by
    # the consumer, any other "data:image/" prefix is rejected.
    image: str
async def index():
    """Serve the built front end's ``index.html`` as an HTML response."""
    # The file lives under $HOME; a missing HOME variable raises KeyError.
    page_path = f"{os.environ['HOME']}/app/instance-labeler/out/index.html"
    return FileResponse(path=page_path, media_type="text/html")
async def get_label_preds(image: str, q: SamQuery) -> MaskBoxes:
    """Predict masks and their bounding boxes for an image from point prompts.

    The stored embedding and size metadata for ``image`` are read from
    ``DATA`` and installed directly on the shared ``SAM`` predictor, so
    ``SAM.set_image`` is not called again for each prompt.

    Args:
        image: Key of the image in ``DATA``.
        q: Prompt points and labels, forwarded to ``SAM.predict`` as
            ``point_coords`` and ``point_labels``.

    Returns:
        The predicted masks as base64-encoded PNGs, with one box per mask.

    Raises:
        HTTPException: 404 when ``image`` is not in ``DATA``.
    """
    if image not in DATA:
        raise HTTPException(status_code=404, detail="Image not found")

    # NOTE(review): assumes the embedding and metadata were stored earlier
    # (get_image stores them on first access) -- nothing here creates them.
    if MODEL == "sam" or MODEL == "mobile_sam":
        SAM.features = torch.from_numpy(DATA.get_emb(image)).to(DEVICE)
    elif MODEL == "sam_hq":
        features = DATA.get_hq_emb(image)
        # The first entry is the main embedding; the rest are intermediate
        # features.
        SAM.features = torch.from_numpy(features[0]).to(DEVICE)
        SAM.interm_features = [torch.from_numpy(f).to(DEVICE) for f in features[1:]]
    meta_data = DATA.get_meta_data(image)
    SAM.original_size = meta_data["original_size"]
    SAM.input_size = meta_data["input_size"]
    # Mark the predictor as ready without running set_image.
    SAM.is_image_set = True  # type: ignore

    masks, _, _ = SAM.predict(  # type: ignore
        point_coords=np.array(q.points),
        point_labels=np.array(q.labels),
        multimask_output=False,
    )

    boxes_xyxy = batched_mask_to_box(torch.as_tensor(masks).to(DEVICE)).cpu().numpy()
    # Boxes are unpacked as (x1, y1, x2, y2): width spans x, height spans y.
    # The previous code had the two swapped, which disagreed with the boxes
    # that label_image stores and get_labels returns.
    bboxes = [
        Box(x=x1, y=y1, width=x2 - x1, height=y2 - y1)
        for x1, y1, x2, y2 in boxes_xyxy.tolist()
    ]

    masks_out = []
    for i in range(masks.shape[0]):
        mask_img = Image.fromarray(masks[i, :, :])
        with io.BytesIO() as buf:
            mask_img.save(buf, format="PNG")
            png_bytes = buf.getvalue()
        masks_out.append(base64.b64encode(png_bytes).decode("utf-8"))

    return MaskBoxes(masks=masks_out, bboxes=bboxes)
async def get_labels(image: str) -> MaskBoxLabels:
    """Return the saved masks, boxes and labels of an image.

    Args:
        image: Key of the image in ``DATA``.

    Returns:
        The stored masks as base64-encoded PNGs, plus their boxes and labels.

    Raises:
        HTTPException: 404 when the image is unknown or has no saved masks;
            400 when the stored masks and labels differ in length.
    """
    if image not in DATA:
        raise HTTPException(status_code=404, detail="Image not found")

    masks, bboxes, labels = DATA.get_labels(image)
    if not masks:
        raise HTTPException(status_code=404, detail="Label not found")
    if len(masks) != len(labels):
        raise HTTPException(
            status_code=400, detail="Corrupted data, masks not equal to labels"
        )

    out_masks = []
    for mask in masks:
        # Each stored mask is serialized to PNG and then base64-encoded.
        with io.BytesIO() as buf:
            mask.save(buf, format="PNG")
            png_bytes = buf.getvalue()
        out_masks.append(base64.b64encode(png_bytes).decode("utf-8"))

    # Stored boxes are (x, y, height, width) tuples, the order label_image
    # writes them in, so height is unpacked before width.
    boxes = [Box(x=x1, y=y1, width=w, height=h) for x1, y1, h, w in bboxes]
    return MaskBoxLabels(masks=out_masks, bboxes=boxes, labels=labels)
async def get_multi_label_preds(image: str, q: MaskLabel) -> MaskBoxLabels:
    """Predict further masks on an image from one labelled reference mask.

    A predictor is built from the shared ``SAM`` model, the image and the
    reference mask, then run on the same image. Every predicted mask is given
    the label of the reference.

    Args:
        image: Key of the image in ``DATA``.
        q: Reference mask (base64 image, optional PNG data-URL prefix) and
            the label to assign to the predictions.

    Returns:
        Predicted masks as base64-encoded PNGs with their boxes and labels;
        all three lists are empty when the predictor returns no masks.

    Raises:
        HTTPException: 404 when ``image`` is not in ``DATA``; 422 when the
            reference mask has no non-zero pixel.
    """
    if image not in DATA:
        raise HTTPException(status_code=404, detail="Image not found")

    image_pil = DATA.get_image(image)
    image_np = np.array(image_pil.convert("RGB"))

    mask_data = q.mask.replace("data:image/png;base64,", "")
    mask = np.array(Image.open(io.BytesIO(base64.b64decode(mask_data))).convert("L"))
    if mask.sum() == 0:
        raise HTTPException(status_code=422, detail="Mask is empty")

    per_sam_model = get_multi_label_predictor(SAM, image_np, mask)
    start = time.perf_counter()
    masks, bboxes, _ = per_sam_model(image_np)
    print(f"inference time {time.perf_counter() - start}")
    if masks is None:
        return MaskBoxLabels(masks=[], bboxes=[], labels=[])

    masks_out = []
    for predicted in masks:
        mask_img = Image.fromarray(predicted)
        with io.BytesIO() as buf:
            mask_img.save(buf, format="PNG")
            png_bytes = buf.getvalue()
        masks_out.append(base64.b64encode(png_bytes).decode("utf-8"))

    # Boxes are unpacked as (x1, y1, x2, y2): width spans x, height spans y.
    # The previous code had the two swapped.
    # NOTE(review): assumes the predictor returns XYXY boxes, as the
    # unpacking names suggest -- confirm against get_multi_label_predictor.
    boxes = [
        Box(x=x1, y=y1, width=x2 - x1, height=y2 - y1)
        for x1, y1, x2, y2 in bboxes.tolist()
    ]
    return MaskBoxLabels(
        masks=masks_out, bboxes=boxes, labels=[q.label] * len(masks)
    )
async def label_image(image: str, mask_labels: MaskLabels) -> Response:
    """Save the submitted masks, their boxes and labels for an image.

    Args:
        image: Key of the image in ``DATA``.
        mask_labels: Base64-encoded masks (optional PNG data-URL prefix) and
            one label per mask.

    Returns:
        A plain-text ``"saved"`` response.

    Raises:
        HTTPException: 404 when ``image`` is not in ``DATA``; 400 when the
            number of masks and labels differ.
    """
    if image not in DATA:
        raise HTTPException(status_code=404, detail="Image not found")
    if len(mask_labels.masks) != len(mask_labels.labels):
        raise HTTPException(status_code=400, detail="Invalid input")

    save_masks = []
    for encoded in mask_labels.masks:
        encoded = encoded.replace("data:image/png;base64,", "")
        mask_img = Image.open(io.BytesIO(base64.b64decode(encoded))).convert("L")
        # Clamp pixel values of 0 or 1 to 0 so they do not count as mask
        # pixels. The original lambda was missing its left operand
        # (`0 if <= 1 else p`) and did not parse.
        mask_img = mask_img.point(lambda p: 0 if p <= 1 else p)
        save_masks.append(mask_img)

    if save_masks:
        # Stack the masks into one boolean tensor; this assumes all masks
        # share one size.
        stacked = torch.as_tensor(np.array([np.array(m) for m in save_masks]))
        boxes_xyxy = batched_mask_to_box(stacked.to(DEVICE).bool()).cpu().numpy()
        # Stored order is (x, y, height, width); get_labels unpacks the same
        # order.
        bboxes = [
            (x1, y1, (y2 - y1), (x2 - x1)) for x1, y1, x2, y2 in boxes_xyxy.tolist()
        ]
    else:
        # With no masks there is no image stack to extract boxes from, so
        # skip the tensor round trip.
        bboxes = []

    DATA.save_labels(image, save_masks, bboxes, mask_labels.labels)
    return Response(content="saved", media_type="text/plain")
async def get_image(image: str) -> Response:
    """Return an image as base64 text, storing its embedding on first access.

    When ``DATA`` holds no embedding for the image, the image is run through
    the shared ``SAM`` predictor and the resulting features and size metadata
    are saved.

    Args:
        image: Key of the image in ``DATA``.

    Returns:
        A response whose body is the base64-encoded PNG, sent with media
        type ``image/png``.

    Raises:
        HTTPException: 404 when ``image`` is not in ``DATA``.
    """
    if image not in DATA:
        raise HTTPException(status_code=404, detail="Image not found")

    pil_image = DATA.get_image(image)

    if not DATA.emb_exists(image):
        rgb_pixels = np.asarray(pil_image.convert("RGB"))
        SAM.set_image(rgb_pixels)  # type: ignore
        if MODEL in ("sam", "mobile_sam"):
            embedding = SAM.get_image_embedding().detach().cpu().numpy()  # type: ignore
            DATA.save_emb(image, embedding)
        elif MODEL == "sam_hq":
            # The HQ variant stores the main embedding followed by its
            # intermediate features.
            tensors = [SAM.features] + SAM.interm_features  # type: ignore
            DATA.save_hq_emb(image, [t.detach().cpu().numpy() for t in tensors])
        # These sizes are restored onto the predictor by get_label_preds.
        sizes = {"original_size": SAM.original_size, "input_size": SAM.input_size}
        DATA.save_meta_data(image, sizes)

    with io.BytesIO() as buf:
        pil_image.save(buf, format="PNG")
        png_bytes = buf.getvalue()
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return Response(content=encoded, media_type="image/png")
async def upload_image(image: str, image_data: ImageData) -> Response:
    """Decode an uploaded image, resize it and save it in ``DATA``.

    Args:
        image: Name passed to ``DATA.save_image`` for the new image.
        image_data: Payload holding the base64-encoded image, optionally
            prefixed with a PNG or JPEG data-URL header.

    Returns:
        A plain-text response containing the id returned by
        ``DATA.save_image``.

    Raises:
        HTTPException: 400 when the payload carries another ``data:image/``
            prefix, or when it cannot be decoded into an image.
    """
    image_b64 = image_data.image
    image_b64 = image_b64.replace("data:image/png;base64,", "")
    image_b64 = image_b64.replace("data:image/jpeg;base64,", "")
    if "data:image/" in image_b64:
        raise HTTPException(
            status_code=400, detail="Invalid image format, only accepts png and jpeg"
        )

    try:
        image_pil = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
    except (ValueError, OSError) as err:
        # binascii.Error (bad base64) is a ValueError and Pillow's
        # UnidentifiedImageError is an OSError. Both mean the client sent
        # unusable data, so answer 400 rather than an unhandled error.
        raise HTTPException(status_code=400, detail="Invalid image data") from err

    # PIL reports size as (width, height); the resizer takes height first.
    target_size = T.get_preprocess_shape(
        image_pil.size[1], image_pil.size[0], T.target_length
    )
    image_pil = resize(image_pil, target_size)
    image_id = DATA.save_image(image, image_pil)
    return Response(content=image_id, media_type="text/plain")
async def get_all_images() -> Response:
    """Return every image known to ``DATA`` as JSON under the ``images`` key."""
    payload = {"images": DATA.get_all_images()}
    return JSONResponse(content=payload)
# Serve the built front end from the application root. This is registered
# last in the file; presumably so the routes declared above are matched
# before the catch-all mount -- TODO confirm.
app.mount(
    "/",
    StaticFiles(directory=f"{os.environ['HOME']}/app/instance-labeler/out", html=True),
    name="static",
)