mrtlive's picture
server file and req
292c73a
from fastapi import FastAPI, status, File, Form, UploadFile
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from starlette.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
import numpy as np
from io import BytesIO
from PIL import Image
from base64 import b64encode, b64decode
def pil_image_to_base64(image):
    """Serialize *image* as PNG and return the bytes base64-encoded.

    Args:
        image: Object exposing ``save(fp, format=...)``; in this file it is
            always a ``PIL.Image.Image``.

    Returns:
        The PNG byte stream encoded as a base64 ``str`` (UTF-8 decoded).
    """
    # Render into an in-memory buffer so nothing touches the filesystem.
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return b64encode(png_bytes).decode("utf-8")
# --- Segment-Anything model setup (runs once, at import time) ---------------
# Checkpoint file and registry key must refer to the same model variant;
# the trailing comments list the alternatives the author considered.
# The checkpoint is given as a relative path.
sam_checkpoint = "sam_vit_b_01ec64.pth" # "sam_vit_l_0b3195.pth" or "sam_vit_h_4b8939.pth"
model_type = "vit_b" # "vit_l" or "vit_h"
# Inference device is hard-coded to CPU; the comment shows the intended
# CUDA auto-detection, which is currently disabled.
device = "cpu" # "cuda" if torch.cuda.is_available() else "cpu"
print("Loading model")
# Build the model selected by `model_type` from the checkpoint and move it to
# `device`.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
print("Finishing loading")
# Single predictor instance shared by every request handler below.
predictor = SamPredictor(sam)
# FastAPI application object; all routes below are registered on it.
# NOTE(review): debug=True is presumably meant for development only —
# confirm before deploying.
app = FastAPI(debug=True)
# Browser origins allowed to call this API cross-origin: localhost and
# 127.0.0.1 on the default port, port 8000 and port 5173
# (presumably a local front-end dev server — TODO confirm).
origins = [
"http://localhost",
"http://localhost:8000",
"http://127.0.0.1",
"http://127.0.0.1:8000",
"http://localhost:5173",
"http://127.0.0.1:5173",
]
# CORS policy: only the origins listed above, with credentials, and with
# every HTTP method and request header permitted.
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
)
# Module-level interaction state. These are plain globals, so every request
# (from any client) reads and mutates the same lists.
# [x, y] coordinates of each click, appended by /click and reset by /image.
input_point = []
# One label per click; /click always appends 1.
input_label = []
# Best-scoring mask chosen for each click; /click returns their union.
masks = []
# Only referenced by `global` declarations and commented-out code in this
# file; it is never read by active code.
mask_input = [None]
@app.post("/image")
async def process_images(
    image: UploadFile = File(...)
):
    """Receive a new image, reset the click session and embed the image.

    The upload is decoded and normalised to 3-channel RGB before it is handed
    to the predictor. The previous implementation sliced off the last channel
    unconditionally (``img[:, :, :-1]``), which is only correct for RGBA
    input: an RGB upload lost its blue channel and a grayscale/palette upload
    raised ``IndexError``.

    Args:
        image: The uploaded image file (multipart form field ``image``).

    Returns:
        A 200 ``JSONResponse`` with a confirmation message.
    """
    global input_point, input_label, masks

    # Decode first: if the upload is not a readable image this raises before
    # any session state has been touched.
    image_data = BytesIO(await image.read())
    # convert("RGB") yields exactly three channels for every input mode
    # (RGBA, RGB, L, P, ...); for RGBA it drops alpha like the old slice did.
    img = np.array(Image.open(image_data).convert("RGB"))
    print("get image", img.shape)

    # A new image invalidates all clicks and masks collected so far.
    input_point = []
    input_label = []
    masks = []

    # Produce the image embedding that later /click and /rect calls reuse.
    predictor.set_image(img)
    print("finish setting image")

    return JSONResponse(
        content={
            "message": "Images received successfully",
        },
        status_code=200,
    )
@app.post("/undo")
async def process_images():
    """Discard the most recent click and its mask.

    Previously the three history lists were popped unconditionally, so calling
    this endpoint with no clicks recorded (or with lists of unequal length)
    raised ``IndexError`` and surfaced as a 500. Now an empty history yields a
    400 response and each list is only popped when it has an entry.

    NOTE(review): this handler shares its function name with the ``/image``
    handler. Both routes are registered at decoration time, but the module
    attribute ``process_images`` ends up bound to this function only. The name
    is kept for interface stability; consider renaming.

    Returns:
        A 200 ``JSONResponse`` on success, or a 400 ``JSONResponse`` when
        there is no mask to remove.
    """
    # `masks` is what /click renders, so it decides whether undo is possible.
    if not masks:
        return JSONResponse(
            content={
                "message": "Nothing to undo",
            },
            status_code=400,
        )

    masks.pop()
    # The prompt histories normally move in lock-step with `masks`; guard the
    # pops anyway so a desynchronised list cannot crash the request.
    for history in (input_point, input_label):
        if history:
            history.pop()

    return JSONResponse(
        content={
            "message": "Clear successfully",
        },
        status_code=200,
    )
@app.post("/click")
async def click_images(
    x: int = Form(...),  # horizontal
    y: int = Form(...)  # vertical
):
    """Segment the object under a click and return the union of all masks.

    The click is used as a single foreground point prompt (label 1). Of the
    candidate masks the predictor returns, the highest-scoring one is kept
    and OR-ed with the masks of all earlier clicks on the current image.

    The click is recorded in the shared history only after prediction
    succeeds. Before, it was appended first, so a failing prediction left
    ``input_point``/``input_label`` one entry longer than ``masks``.

    Args:
        x: Horizontal pixel coordinate of the click.
        y: Vertical pixel coordinate of the click.

    Returns:
        A 200 ``JSONResponse`` whose ``masks`` field is the combined mask as
        a base64-encoded PNG.
    """
    print("get click", x, y)

    candidates, scores, _ = predictor.predict(
        point_coords=np.array([[x, y]]),
        point_labels=np.array([1]),
        multimask_output=True,  # several candidates; keep the best-scoring one
    )
    best_mask = candidates[np.argmax(scores)]

    # Commit to the shared history only now that a mask exists for the click.
    input_point.append([x, y])
    input_label.append(1)
    masks.append(best_mask)
    print("input_point", input_point)
    print("input_label", input_label)

    # Union of every mask collected for the current image.
    combined = np.logical_or.reduce(masks)
    res = Image.fromarray(combined)

    return JSONResponse(
        content={
            "masks": pil_image_to_base64(res),
            "message": "Images processed successfully"
        },
        status_code=200,
    )
@app.post("/rect")
async def rect_images(
    start_x: int = Form(...),  # horizontal
    start_y: int = Form(...),  # vertical
    end_x: int = Form(...),  # horizontal
    end_y: int = Form(...)  # vertical
):
    """Segment the object inside a user-drawn rectangle.

    The two corners may arrive in any order (e.g. a drag from bottom-right to
    top-left). They are normalised to (x_min, y_min, x_max, y_max) before
    being used as the box prompt; previously they were forwarded as drawn,
    which handed the predictor an inverted box for such drags.

    This endpoint does not read or modify the click history.

    Args:
        start_x: Horizontal coordinate where the drag started.
        start_y: Vertical coordinate where the drag started.
        end_x: Horizontal coordinate where the drag ended.
        end_y: Vertical coordinate where the drag ended.

    Returns:
        A 200 ``JSONResponse`` whose ``masks`` field is the predicted mask as
        a base64-encoded PNG.
    """
    # Order the corners so the box is always top-left -> bottom-right.
    left, right = min(start_x, end_x), max(start_x, end_x)
    top, bottom = min(start_y, end_y), max(start_y, end_y)

    box_masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array([[left, top, right, bottom]]),
        multimask_output=False,  # a single mask is requested for a box prompt
    )
    res = Image.fromarray(box_masks[0])

    return JSONResponse(
        content={
            "masks": pil_image_to_base64(res),
            "message": "Images processed successfully"
        },
        status_code=200,
    )
@app.get('/')
def home():
    """Return a short description of this API for the root route."""
    description = 'This is API for uses Segment-Anything Model from facebook. You can use it to segment anything.'
    return description
import uvicorn
# Script entry point: when the file is executed directly (not imported),
# serve `app` with uvicorn on all network interfaces, port 7860.
if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=7860)