Spaces:
Configuration error
Configuration error
| import io, json | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request | |
| from pydantic import BaseModel | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import numpy as np | |
| from PIL import Image | |
| import base64 | |
| from typing import List, Optional | |
| import torch | |
| from core.processing import embed_text, get_dino_boxes_from_prompt, get_sam_mask, expand_coords_shape, embed_image_dino_large | |
| from core.storage import query_vector_db_by_image_embedding, query_vector_db_by_text_embedding, get_object_info_from_graph, set_object_primary_location_hierarchy, get_object_location_chain | |
| from core.storage import get_object_owners, add_owner_by_person_id, add_owner_by_person_name, get_all_locations_for_house | |
| from core.image_processing import apply_mask, crop_to_mask_size, encode_image_to_base64 | |
# Application instance that the rest of this module attaches to.
app = FastAPI()

# CORS is left fully open: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
class Point(BaseModel):
    """A 2D coordinate made of float ``x`` and ``y`` components."""

    x: float
    y: float
class Point3D(BaseModel):
    """A 3D coordinate made of float ``x``, ``y`` and ``z`` components."""

    x: float
    y: float
    z: float
class MaskRequest(BaseModel):
    """Input for mask generation: an image, point prompts and a text prompt."""

    # Base64-encoded image (PNG according to the original author's note).
    image_base64: str
    # Point prompts and their per-point integer labels.
    # NOTE(review): label meaning (e.g. foreground/background) is not visible
    # here - confirm against the segmentation helper.
    points: List[Point]
    labels: List[int]
    # Text prompt; when non-empty it is used to look up bounding boxes.
    prompt: str
    # Flags selecting which optional payloads the response should carry.
    return_raw_mask: bool = False
    return_rgb_mask: bool = False
    return_embeddings: bool = False
class BoundingBox(BaseModel):
    """Integer rectangle described by an ``x``/``y`` position and a size.

    NOTE(review): presumably pixel units with a top-left origin - confirm
    against the helper that produces the box.
    """

    x: int
    y: int
    width: int
    height: int
class MaskResponse(BaseModel):
    """Result of mask generation; every field is required."""

    # Base64 image strings; the mask endpoint sends "" for outputs that were
    # not requested.
    raw_mask_base64: str
    rgb_mask_base64: str
    # Embedding vector of the masked image.
    embedding: List[float]
    # Box reported by the segmentation step.
    bounding_box: BoundingBox
class ObjectQueryByEmbeddingRequest(BaseModel):
    """Vector-search request that carries a precomputed embedding."""

    # Which index to search: "image" or "text".
    embedding_type: str
    # The query vector itself.
    embedding: List[float]
    # Number of matches to request; falls back to 5 when omitted.
    k: int = 5
    # Optional house ID used to filter results.
    house_id: Optional[str] = None
class ObjectQueryByDescriptionRequest(BaseModel):
    """Vector-search request that carries a free-text description."""

    # Text that is embedded and then used as the query vector.
    description: str
    # Number of matches to request; falls back to 5 when omitted.
    k: int = 5
    # Optional house ID used to filter results. Declared Optional so that the
    # annotation agrees with the None default and an explicit null is accepted,
    # matching ObjectQueryByEmbeddingRequest.house_id.
    house_id: Optional[str] = None
class ObjectQueryResultEntry(BaseModel):
    """One object in a vector-search result, aggregated over its matches."""

    object_id: str
    # Best similarity score seen for this object among the returned matches.
    aggregated_similarity: float
    # Softmax of the aggregated similarities across all returned objects.
    probability: float
    # Description of every match that belonged to this object.
    descriptions: List[str]
class ObjectInfoRequest(BaseModel):
    """Identifies one object inside one house."""

    house_id: str
    object_id: str
class ObjectInfoResponse(BaseModel):
    """Echoes the requested identifiers together with the object description."""

    object_id: str
    house_id: str
    description: str
class SetPrimaryLocationRequest(BaseModel):
    """Assigns a primary location, given as a list of names, to an object."""

    house_id: str
    object_id: str
    # Example: ["Kitchen", "Left Upper Cabinet", "Middle Shelf"]
    location_hierarchy: List[str]
class ObjectLocationRequest(BaseModel):
    """Location lookup for a house, optionally narrowed to one object."""

    house_id: str
    # May be omitted; handlers that need an object must check for None.
    object_id: Optional[str] = None
    # Forwarded to the storage helpers as their include_images argument.
    include_images: bool = False
class LocationInfo(BaseModel):
    """A named location with optional image, position and shape data."""

    name: str

    # Optional image, either by reference or inline as base64.
    image_uri: Optional[str] = None
    image_base64: Optional[str] = None

    # Position; each coordinate defaults to 0.
    location_x: Optional[float] = 0
    location_y: Optional[float] = 0
    location_z: Optional[float] = 0

    # Shape name plus dimensions; each dimension defaults to 0.
    # NOTE(review): which dimensions apply to which shape is not visible here.
    shape: Optional[str] = None
    radius: Optional[float] = 0
    height: Optional[float] = 0
    width: Optional[float] = 0
    depth: Optional[float] = 0
class ObjectLocationResponse(BaseModel):
    """Locations found for a house, optionally tied to one object."""

    # Left as None when the response is not about a single object.
    object_id: Optional[str] = None
    house_id: str
    locations: List[LocationInfo]
class Person(BaseModel):
    """An owner record; despite the name it is not limited to humans."""

    person_id: str
    # NOTE(review): these three Optionals have no explicit default. Whether
    # that makes them required depends on the pydantic major version - confirm.
    name: Optional[str]
    nickname: Optional[str]
    age: Optional[int]
    # e.g., "person", "dog", "robot", etc.
    type: str = "person"
    image_uri: Optional[str] = None
class ObjectOwnersResponse(BaseModel):
    """Owners recorded for one object in one house."""

    object_id: str
    house_id: str
    # Or a more complex model if needed.
    owners: List[Person]
class AddOwnerByIdRequest(BaseModel):
    """Adds an owner, referenced by ``person_id``, to an object."""

    house_id: str
    object_id: str
    person_id: str
class AddOwnerByNameRequest(BaseModel):
    """Adds an owner, referenced by name, to an object."""

    house_id: str
    object_id: str
    name: str
    # Kind of owner; falls back to "person" when omitted.
    type: Optional[str] = "person"
async def log_requests(request: Request, call_next):
    """Print the method and URL of *request*, then pass it down the chain.

    NOTE(review): the ``(request, call_next)`` signature looks like an HTTP
    middleware, but no registration is visible in this view - confirm.

    Returns:
        Whatever ``call_next`` produced for the request.
    """
    print(f"[REQ] {request.method} {request.url}")
    downstream_response = await call_next(request)
    return downstream_response
async def root():
    """Return the fixed greeting payload."""
    greeting = {"message": "Hello, World!"}
    return greeting
async def log_location(request: Point3D):
    """Print the received 3D point with two decimals per axis.

    Returns:
        A fixed confirmation string.

    Raises:
        HTTPException: 500 if formatting or printing the point fails.
    """
    try:
        formatted_axes = f"x:{request.x:.2f} y:{request.y:.2f} z:{request.z:.2f}"
        print(f"[LogLocation] {formatted_axes}")
        return "log location successful"
    except Exception as e:
        raise HTTPException(500, f"log location failed: {str(e)}")
async def mask_endpoint(request: MaskRequest):
    """Segment the prompted object in a base64 image and return mask data.

    The image is decoded to RGB, bounding boxes are looked up from the text
    prompt (when one is given), a mask is generated, and the masked image is
    cropped to the mask. Optional payloads are produced only when the
    matching ``return_*`` flag is set.

    Args:
        request: Image, point prompts, text prompt and output flags.

    Returns:
        MaskResponse with "" / [] in place of outputs that were not requested.

    Raises:
        HTTPException: 500 when any step of the pipeline fails.
    """
    try:
        # Decode the base64 payload into an RGB numpy array.
        image_bytes = base64.b64decode(request.image_base64)
        img_np = np.array(Image.open(io.BytesIO(image_bytes)).convert("RGB"))

        # Point prompts as arrays.
        point_coords = np.array([[p.x, p.y] for p in request.points], dtype=np.float32)
        point_labels = np.array(request.labels, dtype=np.int32)

        # Bounding boxes are only looked up when a prompt was supplied.
        sam_boxes = None
        if request.prompt:
            sam_boxes = get_dino_boxes_from_prompt(img_np, request.prompt)
            # Reads sam_boxes.shape, so it must stay under the prompt check.
            point_coords, point_labels = expand_coords_shape(
                point_coords, point_labels, sam_boxes.shape[0]
            )

        # NOTE(review): the point prompts prepared above are not forwarded;
        # None is passed for both. Kept as-is - confirm whether intentional.
        mask, bbox = get_sam_mask(img_np, None, None, sam_boxes)
        masked_img = apply_mask(img_np, mask, "remove")
        masked_img = crop_to_mask_size(masked_img, mask)

        raw_mask_b64 = encode_image_to_base64(mask * 255) if request.return_raw_mask else ""
        rgb_mask_b64 = encode_image_to_base64(masked_img) if request.return_rgb_mask else ""
        # MaskResponse.embedding is a required List[float]; None would fail
        # validation and turn every default request into a 500. An empty list
        # mirrors the "" placeholders used for the image fields.
        embedding = embed_image_dino_large(masked_img).tolist() if request.return_embeddings else []

        return MaskResponse(
            raw_mask_base64=raw_mask_b64,
            rgb_mask_base64=rgb_mask_b64,
            embedding=embedding,
            bounding_box=BoundingBox(**bbox),
        )
    except HTTPException:
        # Keep the status code of errors that are already HTTP errors.
        raise
    except Exception as e:
        raise HTTPException(500, f"Mask generation failed: {str(e)}") from e
async def query_by_embedding(query: ObjectQueryByEmbeddingRequest):
    """Search the vector DB with a precomputed embedding and rank objects.

    Matches are grouped by ``object_id``: each object keeps its best score
    and the descriptions of all of its matches. A softmax over the best
    scores gives each object a probability.

    Args:
        query: Embedding, its type ("text" or "image"), ``k`` and an
            optional ``house_id`` filter.

    Returns:
        A list of ObjectQueryResultEntry, empty when nothing matched.

    Raises:
        HTTPException: 400 for an unknown embedding type or a non-positive
            ``k``; 500 when the search itself fails.
    """
    try:
        if query.embedding_type == "text":
            search = query_vector_db_by_text_embedding
        elif query.embedding_type == "image":
            search = query_vector_db_by_image_embedding
        else:
            raise HTTPException(status_code=400, detail="Invalid embedding type. Use 'text' or 'image'.")

        # Honour the requested k (default 5) instead of a hard-coded value.
        if query.k < 1:
            raise HTTPException(status_code=400, detail="k must be a positive integer.")

        query_vector = np.array(query.embedding, dtype=np.float32)
        hits = search(query_vector, query.k, query.house_id)

        # Group matches per object: best score plus every description seen.
        object_scores = {}
        object_views = {}
        for hit in hits:
            obj_id = hit.payload.get("object_id")
            desc = hit.payload.get("description") or "No description available"
            if obj_id in object_scores:
                object_scores[obj_id] = max(object_scores[obj_id], hit.score)
                object_views[obj_id].append(desc)
            else:
                object_scores[obj_id] = hit.score
                object_views[obj_id] = [desc]

        if not object_scores:
            return []

        # Softmax; subtracting the max leaves the result unchanged while
        # keeping exp() from overflowing.
        all_scores = np.array(list(object_scores.values()))
        exp_scores = np.exp(all_scores - np.max(all_scores))
        probabilities = exp_scores / np.sum(exp_scores)

        return [
            ObjectQueryResultEntry(
                object_id=obj_id,
                aggregated_similarity=float(score),
                probability=float(probability),
                descriptions=object_views[obj_id],
            )
            for (obj_id, score), probability in zip(object_scores.items(), probabilities)
        ]
    except HTTPException:
        # Without this the 400s above would be rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def query_by_description(query: ObjectQueryByDescriptionRequest):
    """Embed a text description and run the embedding-based object search.

    Args:
        query: Description text, ``k`` and an optional ``house_id`` filter.

    Returns:
        Whatever ``query_by_embedding`` returns for the text embedding.

    Raises:
        HTTPException: propagated unchanged from ``query_by_embedding``;
            500 when embedding the description fails.
    """
    try:
        # Turn the description into a text embedding.
        embedding_vector = embed_text(query.description)

        # Delegate to the embedding-based search.
        embedding_request = ObjectQueryByEmbeddingRequest(
            embedding_type="text",
            embedding=embedding_vector.tolist(),
            k=query.k,
            house_id=query.house_id,
        )
        return await query_by_embedding(embedding_request)
    except HTTPException:
        # Preserve the status and detail chosen by the delegated handler.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_object_info_endpoint(request: ObjectInfoRequest):
    """Look up an object's description in the graph store.

    Returns:
        ObjectInfoResponse echoing the identifiers with the description.

    Raises:
        HTTPException: 404 when the lookup yields ``None``.
    """
    found_description = get_object_info_from_graph(request.house_id, request.object_id)
    if found_description is not None:
        return ObjectInfoResponse(
            object_id=request.object_id,
            house_id=request.house_id,
            description=found_description,
        )
    raise HTTPException(status_code=404, detail="Object not found in household")
async def set_primary_location(request: SetPrimaryLocationRequest):
    """Store the primary location hierarchy for an object.

    Returns:
        A dict with a ``status`` of "success" and a confirmation ``message``.

    Raises:
        HTTPException: 500 when the storage call fails.
    """
    try:
        set_object_primary_location_hierarchy(
            object_id=request.object_id,
            house_id=request.house_id,
            location_hierarchy=request.location_hierarchy,
        )
        confirmation = f"Primary location set for object {request.object_id}"
        return {"status": "success", "message": confirmation}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def get_object_location(request: ObjectLocationRequest):
    """Return the location chain of one object in a house.

    Returns:
        ObjectLocationResponse carrying the object's locations.

    Raises:
        HTTPException: 500 when the lookup or response construction fails.
    """
    try:
        location_chain = get_object_location_chain(
            house_id=request.house_id,
            object_id=request.object_id,
            include_images=request.include_images,
        )
        payload = ObjectLocationResponse(
            object_id=request.object_id,
            house_id=request.house_id,
            locations=location_chain,
        )
        return payload
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): this function reuses the name of the handler defined just
# before it and therefore rebinds that module-level name. It actually lists
# every location of a house; consider renaming once callers are checked.
async def get_object_location(request: ObjectLocationRequest):
    """Return every location known for a house.

    Only ``house_id`` and ``include_images`` are read from the request; the
    response's ``object_id`` is left at its default.

    Raises:
        HTTPException: 500 when the lookup or response construction fails.
    """
    try:
        house_locations = get_all_locations_for_house(
            house_id=request.house_id,
            include_images=request.include_images,
        )
        payload = ObjectLocationResponse(
            house_id=request.house_id,
            locations=house_locations,
        )
        return payload
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def get_object_owners_handler(request: ObjectLocationRequest):
    """Return the owners recorded for one object in a house.

    Args:
        request: Supplies ``house_id`` and ``object_id``; ``include_images``
            is not used by this handler.

    Returns:
        ObjectOwnersResponse with the owners found.

    Raises:
        HTTPException: 400 when ``object_id`` is missing; 500 when the
            lookup or response construction fails.
    """
    # The request model lets object_id be omitted, but ObjectOwnersResponse
    # requires a string, so a missing id could only ever end in an opaque 500.
    if request.object_id is None:
        raise HTTPException(status_code=400, detail="object_id is required.")

    try:
        owners = get_object_owners(
            house_id=request.house_id,
            object_id=request.object_id,
        )
        return ObjectOwnersResponse(
            object_id=request.object_id,
            house_id=request.house_id,
            owners=owners,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def api_add_owner_by_id(request: AddOwnerByIdRequest):
    """Attach an existing person, identified by ``person_id``, as an owner.

    Returns:
        The value returned by ``add_owner_by_person_id``.

    Raises:
        HTTPException: 404 when the helper returns a falsy result; 500 when
            the helper itself fails.
    """
    try:
        person = add_owner_by_person_id(
            house_id=request.house_id,
            object_id=request.object_id,
            person_id=request.person_id,
        )
        if not person:
            raise HTTPException(status_code=404, detail="Person not found.")
        return person
    except HTTPException:
        # Let the 404 above through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def api_add_owner_by_name(request: AddOwnerByNameRequest):
    """Attach an owner identified by name (and owner type) to an object.

    Returns:
        The value returned by ``add_owner_by_person_name``.

    Raises:
        HTTPException: 500 when the helper returns a falsy result or fails.
    """
    try:
        person = add_owner_by_person_name(
            house_id=request.house_id,
            object_id=request.object_id,
            name=request.name,
            type=request.type,
        )
        if not person:
            raise HTTPException(status_code=500, detail="Failed to create owner.")
        return person
    except HTTPException:
        # Re-raise untouched so the detail is not rewritten by the generic
        # handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Imported inside the guard, so uvicorn is only needed when this file is
    # run as a script.
    import uvicorn

    uvicorn.run(
        "api.hud_server:app",
        host="0.0.0.0",
        port=8000,
    )