# DepthPro / model.py
# Last commit 6f5f8f8 by geetu040: fix hard-coded cuda device
from PIL import Image
import torch
# custom installation from this PR: https://github.com/huggingface/transformers/pull/34583
# !pip install git+https://github.com/geetu040/transformers.git@depth-pro-projects#egg=transformers
from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub repository and pinned revision that both the processor and model load from.
checkpoint = "geetu040/DepthPro"
revision = "project"

# Build the preprocessing pipeline and the depth-estimation model from the same
# checkpoint/revision, then move the model's parameters onto the chosen device.
image_processor = DepthProImageProcessorFast.from_pretrained(checkpoint, revision=revision)
model = DepthProForDepthEstimation.from_pretrained(checkpoint, revision=revision).to(device)
def predict(image):
    """Run DepthPro on ``image`` and return the depth map as an 8-bit PIL image.

    Args:
        image: Input picture. It must expose ``height`` and ``width`` (as a
            PIL ``Image`` does); they are used to resize the prediction back
            to the input resolution.

    Returns:
        A ``PIL.Image.Image`` built from a ``uint8`` array in which the
        smallest predicted depth maps to 0 and the largest maps to 255.
    """
    # Preprocess, then move every input tensor onto the model's device.
    inputs = image_processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Interpolate the prediction back to the original image size.
    post_processed_output = image_processor.post_process_depth_estimation(
        outputs, target_sizes=[(image.height, image.width)],
    )
    depth = post_processed_output[0]["predicted_depth"]

    # Min-max normalize to [0, 1]. The divisor must be (max - min), not max:
    # dividing by max alone never reaches 1 unless min == 0, and it yields
    # NaN/inf when max == 0. Clamping the range to eps also keeps a constant
    # depth map (max == min) from dividing by zero.
    # NOTE(review): torch.finfo assumes a floating-point depth tensor — confirm.
    depth_min = depth.min()
    depth_range = (depth.max() - depth_min).clamp(min=torch.finfo(depth.dtype).eps)
    depth = (depth - depth_min) / depth_range

    # Scale to 0-255 and clamp so float round-off cannot wrap on uint8 conversion.
    depth = (depth * 255.0).clamp(0, 255)
    depth = depth.detach().cpu().numpy()
    return Image.fromarray(depth.astype("uint8"))