import spaces
import tempfile
import gradio as gr
import numpy as np
import torch
from PIL import Image
import trimesh
from huggingface_hub import hf_hub_download
from depth_anything_v2.dpt import DepthAnythingV2
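# CSS to cap the displayed image heights and size the download button in the UI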
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
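# Per-encoder configurations for the Depth Anything V2 variants (ViT-S/B/L/G)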
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large',
    'vitg': 'Giant',  # we are undergoing company review procedures to release our giant model checkpoint
}
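# Use the ViT-L (Large) encoder and fetch its checkpoint from the Hugging Face Hub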
encoder = 'vitl'
model_name = encoder2name[encoder]
model = DepthAnythingV2(**model_configs[encoder])
filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
state_dict = torch.load(filepath, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
title = "# Depth-Anything-V2-DepthPop"
description = """
このツールを使用すると、写真やイラストを飛び出す絵本風にすることができます。
"""
@spaces.GPU
def predict_depth(image):
    return model.infer_image(image)
def generate_point_cloud(color_img, resolution):
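    # Gradio supplies an RGB array; reverse the channel order to the BGR layout infer_image expects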
    depth_img = predict_depth(color_img[:, :, ::-1])
    # Resize the images to the requested resolution, preserving aspect ratio
    height, width = color_img.shape[:2]
    new_height = resolution
    new_width = int(width * (new_height / height))
    color_img_resized = np.array(Image.fromarray(color_img).resize((new_width, new_height), Image.LANCZOS))
    depth_img_resized = np.array(Image.fromarray(depth_img).resize((new_width, new_height), Image.LANCZOS))
    # Normalize the depth map to the [0, 1] range
    depth_min = np.min(depth_img_resized)
    depth_max = np.max(depth_img_resized)
    # Note: the epsilon is an added guard against division by zero on constant-depth images
    normalized_depth = (depth_img_resized - depth_min) / (depth_max - depth_min + 1e-8)
    # Non-linear transform (adjust as needed)
    adjusted_depth = np.power(normalized_depth, 0.1)  # gamma correction
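    # x**0.1 expands differences among small normalized values and compresses large ones,
    # which exaggerates the layered, pop-up-book separation of the scene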
    # Camera intrinsics (adjust based on the camera in use)
    fx, fy = 300, 300  # focal lengths
    cx, cy = color_img_resized.shape[1] / 2, color_img_resized.shape[0] / 2  # principal point
    # Build a mesh grid of pixel coordinates
    rows, cols = adjusted_depth.shape
    u, v = np.meshgrid(range(cols), range(rows))
    # Compute the 3D coordinates (with the X axis flipped)
    Z = adjusted_depth
    X = -((u - cx) * Z / fx)  # flip the X axis
    Y = (v - cy) * Z / fy
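    # Pinhole back-projection: a pixel (u, v) at depth Z maps to ((u - cx) * Z / fx, (v - cy) * Z / fy, Z)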
    # Stack the X, Y, Z coordinates
    points = np.stack((X, Y, Z), axis=-1)
    # Flatten into a list of points
    points = points.reshape(-1, 3)
    # Look up the color of each point
    colors = color_img_resized.reshape(-1, 3)
    # Normalize the colors to the 0-1 range
    colors = colors.astype(np.float32) / 255.0
    # Create the PointCloud object
    cloud = trimesh.PointCloud(vertices=points, colors=colors)
    # Apply a 180-degree rotation around the Z axis (clockwise)
    rotation = trimesh.transformations.rotation_matrix(np.pi, [0, 0, 1])
    cloud.apply_transform(rotation)
    # Apply a 180-degree rotation around the Y axis (flips the scene upside down)
    flip_y = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
    cloud.apply_transform(flip_y)
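    # Net effect: the two rotations compose to a 180-degree rotation about the X axis (Y and Z are negated)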
    # Export in GLB format; NamedTemporaryFile replaces the deprecated, race-prone tempfile.mktemp
    output_path = tempfile.NamedTemporaryFile(suffix='.glb', delete=False).name
    cloud.export(output_path)
    return output_path
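# Assemble the Gradio interface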
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction & Point Cloud Generation")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
    with gr.Row():
        resolution_slider = gr.Slider(minimum=512, maximum=1600, value=512, step=1, label="Resolution")
    submit = gr.Button(value="Generate")
    output_3d = gr.Model3D(
        clear_color=[0.0, 0.0, 0.0, 0.0],
        label="3D Model",
    )
    submit.click(fn=generate_point_cloud, inputs=[input_image, resolution_slider], outputs=[output_3d])
if __name__ == '__main__':
    demo.queue().launch(share=True)