# Header captured from the Hugging Face file page (not code):
# author huntrezz, commit 772e909 "Update app.py" (verified), file size 3.02 kB.
import cv2
import torch
import numpy as np
from transformers import DPTForDepthEstimation, DPTImageProcessor
import gradio as gr
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256", torch_dtype=torch.float32)
model.eval()

# Apply global unstructured pruning: the 20% of weights with the smallest L1
# magnitude are zeroed, ranked jointly across every Conv2d and Linear layer
# (not per layer).
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,  # Prune 20% of weights
)
# Make the pruning permanent: fold each mask into its weight tensor and drop
# the pruning re-parametrisation so the modules hold plain weights again.
for module, _ in parameters_to_prune:
    prune.remove(module, "weight")

# Dynamic int8 quantization swaps in modules backed by CPU-only quantized
# kernels. Quantizing and then moving the model to CUDA makes every forward
# pass fail, so only quantize when the model will actually run on the CPU.
# NOTE(review): PyTorch's default dynamic-quantization mapping does not appear
# to cover Conv2d, so those layers are presumably left in float32 - confirm.
if device.type == "cpu":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8
    )
model = model.to(device)

# NOTE(review): `processor` and `color_map` are not referenced by any code
# visible in this file (preprocessing is done by hand in preprocess_image);
# they are kept in case something else imports them - confirm before removing.
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
# 256-entry INFERNO lookup table from OpenCV, copied to the compute device.
color_map = cv2.applyColorMap(np.arange(256, dtype=np.uint8), cv2.COLORMAP_INFERNO)
color_map = torch.from_numpy(color_map).to(device)
def preprocess_image(image):
    """Shrink a frame and convert it into a float batch tensor for the model.

    The frame is resized to 128 pixels wide by 72 pixels high (``cv2.resize``
    takes ``dsize`` as ``(width, height)``). Its channel axis is moved to the
    front, a leading batch axis of size 1 is added, and the values are cast to
    float32 on ``device`` and scaled by 1/255.

    Args:
        image: NumPy image array. Assumed to be H x W x C; the channel permute
            requires a 3-D array.

    Returns:
        Tensor of shape (1, C, 72, 128) on ``device``.
    """
    # NOTE(review): only a /255 scaling is applied. No mean/std normalisation
    # is done here, and the loaded DPTImageProcessor is not used. Confirm this
    # matches what the model expects.
    small = cv2.resize(image, (128, 72))
    chw = torch.from_numpy(small).permute(2, 0, 1)  # HWC -> CHW
    batch = chw[None].to(dtype=torch.float32, device=device)  # add batch axis
    return batch / 255.0
def plot_depth_map(depth_map, original_image):
    """Render a depth map as a 3-D surface textured with the original image.

    Args:
        depth_map: 2-D array of depth values. The z-axis is fixed to [0, 1],
            so values are expected to be normalised to that range.
        original_image: image used to colour the surface. It is resized to
            the depth map's height x width. Values are treated as 0-255 and
            the channels as RGB. Grayscale is expanded to three channels, and
            any channels beyond the third (e.g. alpha) are ignored.

    Returns:
        H x W x 3 uint8 array holding the rendered figure's RGB pixels.
    """
    fig = plt.figure(figsize=(16, 9))
    try:
        ax = fig.add_subplot(111, projection='3d')
        rows, cols = depth_map.shape[0], depth_map.shape[1]
        x, y = np.meshgrid(np.arange(cols), np.arange(rows))

        # Resize the original image to match the depth map; cv2.resize takes
        # dsize as (width, height).
        resized = cv2.resize(original_image, (cols, rows))
        if resized.ndim == 2:
            # Grayscale: replicate the single channel so facecolors is RGB.
            resized = np.repeat(resized[..., None], 3, axis=2)
        colors = resized[..., :3] / 255.0

        ax.plot_surface(x, y, depth_map, facecolors=colors, shade=False)
        ax.set_zlim(0, 1)
        # Look down at the surface from a higher angle, rotated 180 degrees.
        ax.view_init(elev=45, azim=180)
        # Act on this figure's axes directly. plt.axis() would act on pyplot's
        # global "current" axes, which another concurrent request could change.
        ax.set_axis_off()

        fig.canvas.draw()
        # canvas.tostring_rgb() was deprecated in Matplotlib 3.8 and removed
        # in 3.10. buffer_rgba() is the supported API, and its array shape is
        # the real rendered pixel size. Copy so the result does not alias the
        # canvas buffer.
        img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Always release the figure, even if rendering raised.
        plt.close(fig)
    return img
@torch.inference_mode()
def process_frame(image):
    """Estimate depth for one webcam frame and return a rendered 3-D view.

    Args:
        image: frame from the Gradio ``Image`` input as a NumPy array, or
            None when no frame is available. With Gradio's default numpy
            image mode the array is RGB.

    Returns:
        The RGB image array produced by ``plot_depth_map``, or None when
        ``image`` is None.
    """
    if image is None:
        return None

    # Gradio already delivers RGB, so no BGR->RGB swap is needed; swapping
    # would exchange red and blue in the plotted colours. Only normalise the
    # channel layout to 3-channel RGB, which both the model input and the
    # surface colouring require.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    preprocessed = preprocess_image(image)
    predicted_depth = model(preprocessed).predicted_depth
    depth_map = predicted_depth.squeeze().cpu().numpy()

    # Min-max normalise to [0, 1], the range plot_depth_map fixes its z-axis
    # to. A perfectly flat prediction would otherwise divide by zero and
    # produce NaNs, so map it to all zeros instead.
    d_min = depth_map.min()
    d_range = depth_map.max() - d_min
    if d_range > 0:
        depth_map = (depth_map - d_min) / d_range
    else:
        depth_map = np.zeros_like(depth_map)

    return plot_depth_map(depth_map, image)
# Live Gradio app: every streamed webcam frame is passed to process_frame and
# the returned array is shown as an image.
webcam_input = gr.Image(sources="webcam", streaming=True)
interface = gr.Interface(
    fn=process_frame,
    inputs=webcam_input,
    outputs="image",
    live=True,
)
interface.launch()