|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
import numpy as np |
|
import torch |
|
from pytorch3d.io import load_objs_as_meshes |
|
from pytorch3d.renderer import ( |
|
BlendParams, |
|
FoVPerspectiveCameras, |
|
look_at_view_transform, |
|
MeshRasterizer, |
|
MeshRenderer, |
|
PointLights, |
|
RasterizationSettings, |
|
SoftPhongShader, |
|
SoftSilhouetteShader, |
|
) |
|
|
|
|
|
|
|
# Directory containing this file, with symlinks resolved, so the data
# folder is located relative to the real repository layout.
current_dir = os.path.dirname(os.path.realpath(__file__))

# Default location of the cow mesh assets (obj/mtl/texture), one level up
# from this file under data/cow_mesh.
DATA_DIR = os.path.join(current_dir, os.pardir, "data", "cow_mesh")
|
|
|
|
|
def generate_cow_renders(
    num_views: int = 40, data_dir: str = DATA_DIR, azimuth_range: float = 180
):
    """
    This function generates `num_views` renders of a cow mesh.
    The renders are generated from viewpoints sampled at uniformly distributed
    azimuth intervals. The elevation is kept constant so that the camera's
    vertical position coincides with the equator.

    For a more detailed explanation of this code, please refer to the
    docs/tutorials/fit_textured_mesh.ipynb notebook.

    Args:
        num_views: The number of generated renders.
        data_dir: The folder that contains the cow mesh files. If the cow mesh
            files do not exist in the folder, this function will automatically
            download them.
        azimuth_range: number of degrees on each side of the start position to
            take samples

    Returns:
        cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
            images are rendered.
        images: A tensor of shape `(num_views, height, width, 3)` containing
            the rendered images.
        silhouettes: A tensor of shape `(num_views, height, width)` containing
            the rendered silhouettes.

    Raises:
        urllib.error.URLError: if a missing mesh asset cannot be downloaded.
    """

    # Fetch any missing mesh assets. urllib is used instead of shelling out
    # to `wget` so the download works without an external binary and a
    # failure raises instead of being silently ignored (a failed `os.system`
    # call would previously surface only as a confusing load error below).
    mesh_file_names = ("cow.obj", "cow.mtl", "cow_texture.png")
    base_url = "https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/"
    missing = [
        fl for fl in mesh_file_names if not os.path.isfile(os.path.join(data_dir, fl))
    ]
    if missing:
        from urllib.request import urlretrieve  # stdlib; no wget dependency

        os.makedirs(data_dir, exist_ok=True)
        # Only the files that are actually absent are downloaded.
        for fl in missing:
            urlretrieve(base_url + fl, os.path.join(data_dir, fl))

    # Render on the GPU when one is available; CPU rendering works but is slow.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    # Load the obj file (together with its material and texture) as a
    # textured Meshes object on the chosen device.
    obj_filename = os.path.join(data_dir, "cow.obj")
    mesh = load_objs_as_meshes([obj_filename], device=device)

    # Normalize the mesh in place: center it at the origin and scale it so
    # its largest coordinate extent is 1. The in-place offset/scale ops keep
    # the textures attached to the mesh.
    verts = mesh.verts_packed()
    N = verts.shape[0]
    center = verts.mean(0)
    scale = max((verts - center).abs().max(0)[0])
    mesh.offset_verts_(-(center.expand(N, 3)))
    mesh.scale_verts_((1.0 / float(scale)))

    # Viewpoints: constant zero elevation (camera on the equator), azimuths
    # uniformly spaced in [-azimuth_range, azimuth_range] and shifted by 180
    # degrees so the sweep starts behind the object.
    elev = torch.zeros(num_views)
    azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0

    # A single point light placed at negative z, i.e. in front of the cow.
    lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

    # One camera per view, all at distance 2.7 looking at the origin.
    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

    # Rasterization settings for the RGB renders: no blur and a single face
    # per pixel give crisp target images.
    raster_settings = RasterizationSettings(
        image_size=128, blur_radius=0.0, faces_per_pixel=1
    )

    # Phong-shaded renderer with a black background so the rendered targets
    # composite cleanly against the silhouettes.
    blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=SoftPhongShader(
            device=device, cameras=cameras, lights=lights, blend_params=blend_params
        ),
    )

    # Replicate the mesh so each camera in the batch renders its own copy.
    meshes = mesh.extend(num_views)

    # Render RGBA images from all viewpoints in one batched call.
    target_images = renderer(meshes, cameras=cameras, lights=lights)

    # Silhouette rasterization: a wide blur radius and many faces per pixel
    # are required by the soft silhouette shader's probabilistic blending.
    sigma = 1e-4
    raster_settings_silhouette = RasterizationSettings(
        image_size=128, blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma, faces_per_pixel=50
    )

    renderer_silhouette = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings_silhouette
        ),
        shader=SoftSilhouetteShader(),
    )

    # Render soft silhouettes (RGBA; the alpha channel holds the silhouette).
    silhouette_images = renderer_silhouette(meshes, cameras=cameras, lights=lights)

    # Binarize the alpha channel to obtain hard silhouette masks.
    silhouette_binary = (silhouette_images[..., 3] > 1e-4).float()

    # Drop the alpha channel from the RGB renders before returning.
    return cameras, target_images[..., :3], silhouette_binary
|
|