"""LitServe endpoint that segments an uploaded image with LangSAM and returns a PNG overlay."""

from io import BytesIO

import litserve as ls
import numpy as np
from fastapi import Response, UploadFile
from PIL import Image

from lang_sam import LangSAM
from lang_sam.utils import draw_image

PORT = 8000


class LangSAMAPI(ls.LitAPI):
    """LitServe API wrapping a LangSAM model for text-prompted segmentation."""

    def setup(self, device: str) -> None:
        """Initialize or load the LangSAM model.

        Args:
            device: Device string supplied by LitServe. NOTE(review): it is not
                currently forwarded to ``LangSAM``; the model is constructed with
                only ``sam_type`` — confirm whether device placement matters.
        """
        self.model = LangSAM(sam_type="sam2.1_hiera_small")
        print("LangSAM model initialized.")

    def decode_request(self, request) -> dict:
        """Decode the incoming request to extract parameters and image bytes.

        Assumes the request is sent as multipart/form-data with fields:
            - sam_type: str (optional; missing/empty keeps the loaded model)
            - box_threshold: float (default 0.3)
            - text_threshold: float (default 0.25)
            - text_prompt: str (default "")
            - image: UploadFile

        Returns:
            dict with keys ``sam_type``, ``box_threshold``, ``text_threshold``,
            ``image_bytes`` and ``text_prompt``.

        Raises:
            ValueError: If no image file is present, or a threshold is not a
                valid float.
        """
        # Extract form data
        sam_type = request.get("sam_type")
        box_threshold = float(request.get("box_threshold", 0.3))
        text_threshold = float(request.get("text_threshold", 0.25))
        text_prompt = request.get("text_prompt", "")

        # Extract image file
        image_file: UploadFile = request.get("image")
        if image_file is None:
            raise ValueError("No image file provided in the request.")
        image_bytes = image_file.file.read()

        return {
            "sam_type": sam_type,
            "box_threshold": box_threshold,
            "text_threshold": text_threshold,
            "image_bytes": image_bytes,
            "text_prompt": text_prompt,
        }

    def predict(self, inputs: dict) -> dict:
        """Perform prediction using the LangSAM model.

        Args:
            inputs: The dict produced by ``decode_request``.

        Returns:
            dict: ``{"output_image": PIL.Image.Image}`` — the input image with
            masks/boxes drawn on it, or the unmodified RGB image when no masks
            are detected.

        Raises:
            ValueError: If ``image_bytes`` cannot be decoded as an image.
        """
        print("Starting prediction with parameters:")
        # Implicit concatenation instead of backslash continuations inside the
        # f-string, which used to embed the source indentation in the log line.
        print(
            f"sam_type: {inputs['sam_type']}, "
            f"box_threshold: {inputs['box_threshold']}, "
            f"text_threshold: {inputs['text_threshold']}, "
            f"text_prompt: {inputs['text_prompt']}"
        )

        # A missing/empty sam_type means "keep whatever model is loaded" rather
        # than trying to build a model named None or "".
        requested_sam_type = inputs.get("sam_type") or self.model.sam_type
        if requested_sam_type != self.model.sam_type:
            print(f"Updating SAM model type to {requested_sam_type}")
            self.model.sam.build_model(requested_sam_type)
            # Record the switch so subsequent requests for the same type do not
            # rebuild the model again.
            self.model.sam_type = requested_sam_type

        try:
            image_pil = Image.open(BytesIO(inputs["image_bytes"])).convert("RGB")
        except Exception as e:
            raise ValueError(f"Invalid image data: {e}") from e

        results = self.model.predict(
            images_pil=[image_pil],
            texts_prompt=[inputs["text_prompt"]],
            box_threshold=inputs["box_threshold"],
            text_threshold=inputs["text_threshold"],
        )
        # One image was submitted, so take the first (only) result.
        results = results[0]

        if len(results["masks"]) == 0:
            print("No masks detected. Returning original image.")
            return {"output_image": image_pil}

        # Draw results on the image
        image_array = np.asarray(image_pil)
        output_image = draw_image(
            image_array,
            results["masks"],
            results["boxes"],
            results["scores"],
            results["labels"],
        )
        output_image = Image.fromarray(np.uint8(output_image)).convert("RGB")

        return {"output_image": output_image}

    def encode_response(self, output: dict) -> Response:
        """Encode the prediction result into an HTTP response.

        Args:
            output: The dict returned by ``predict``.

        Returns:
            Response: Contains the processed image in PNG format.

        Raises:
            ValueError: If ``output`` does not carry an ``output_image``.
        """
        try:
            image = output["output_image"]
        except (KeyError, TypeError) as err:
            raise ValueError("No output generated by the prediction.") from err

        buffer = BytesIO()
        image.save(buffer, format="PNG")
        # getvalue() returns the whole buffer regardless of position; no seek needed.
        return Response(content=buffer.getvalue(), media_type="image/png")


lit_api = LangSAMAPI()
server = ls.LitServer(lit_api)


if __name__ == "__main__":
    print(f"Starting LitServe and Gradio server on port {PORT}...")
    server.run(port=PORT)