File size: 2,175 Bytes
0252e48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3d8722
03b5980
0252e48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import base64
import io

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor
# Select the non-interactive Agg backend: figures are only ever rendered into
# in-memory buffers, never displayed.
matplotlib.use("Agg")

app = FastAPI()

# Checkpoint identifiers. The image processor comes from "Intel/dpt-large"; the
# depth model is loaded from the "model/" path (presumably a local directory
# holding fine-tuned weights — confirm).
_PROCESSOR_CHECKPOINT = "Intel/dpt-large"
_MODEL_CHECKPOINT = "model/"

# Both are loaded a single time at import so every request shares them.
processor = DPTImageProcessor.from_pretrained(_PROCESSOR_CHECKPOINT)
model = DPTForDepthEstimation.from_pretrained(_MODEL_CHECKPOINT)

# Camera intrinsics — adjust these to match the camera that took the photos.
sensor_width = 4.88  # physical sensor width; same unit as focal_length (presumably mm — confirm)
focal_length = 14.35  # lens focal length; same unit as sensor_width
image_width = 3072  # pixel width of the images these intrinsics describe

# Pinhole model: focal length in pixels = focal length scaled by the number
# of pixels per unit of sensor width.
focal_length_px = image_width * focal_length / sensor_width


def _estimate_depth(image):
    """Run the DPT model on *image* and return its output at the image's resolution.

    Args:
        image: An RGB PIL image.

    Returns:
        A NumPy array holding the model's ``predicted_depth`` for the image,
        bicubically resized to ``(image.height, image.width)``.
    """
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        predicted_depth = model(**inputs).predicted_depth
        # The model predicts at its own working resolution; upsample back to
        # the input size. PIL's ``size`` is (width, height) while
        # ``interpolate`` expects (height, width), hence the reversal.
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        )

    return prediction.squeeze().cpu().numpy()


def _render_heatmap_png(values):
    """Render a 2-D array as a "plasma" heat map with a colour bar; return PNG bytes."""
    fig, ax = plt.subplots()
    try:
        heat = ax.imshow(values, cmap="plasma")
        fig.colorbar(heat, ax=ax)
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # without this, each request would leak one figure.
        plt.close(fig)
    return buf.getvalue()


@app.post("/predict/")
async def predict_depth(file: UploadFile = File(...)):
    """Estimate depth for an uploaded image and return it as a heat-map PNG.

    Args:
        file: The uploaded image file.

    Returns:
        A JSON response ``{"depth_map": "data:image/png;base64,..."}`` whose
        value is a PNG heat map of ``focal_length_px / prediction``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    try:
        # The processor normalises exactly three channels, so RGBA, palette
        # and greyscale uploads are converted to RGB first.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # NOTE(review): inference and plotting below are synchronous and block the
    # event loop of this async handler for their whole duration.
    depth = _estimate_depth(image)

    # NOTE(review): DPT-style models usually predict *relative inverse* depth,
    # so focal_length_px / prediction is at best proportional to distance, not
    # calibrated centimetres. focal_length_px also assumes an image_width-pixel
    # wide upload — confirm both before trusting absolute values.
    depth_cm = focal_length_px / (depth + 1e-6)

    png_bytes = _render_heatmap_png(depth_cm)
    return JSONResponse({
        "depth_map": f"data:image/png;base64,{base64.b64encode(png_bytes).decode()}"
    })

# Entry point for running the service directly (e.g. `python <this file>`).
# NOTE(review): host "0.0.0.0" listens on every network interface, not just
# localhost — intended for container use? confirm.
if __name__ == "__main__":
    uvicorn.run(app, port=8000, host="0.0.0.0")