bluestyle97 commited on
Commit
498625f
·
verified ·
1 Parent(s): 88a8a53

Update freesplatter/utils/mesh_renderer.py

Browse files
freesplatter/utils/mesh_renderer.py CHANGED
@@ -289,14 +289,15 @@ class MeshRenderer(nn.Module):
289
  far=10,
290
  ssaa=1,
291
  texture_filter='linear-mipmap-linear',
292
- opengl=False):
 
293
  super().__init__()
294
  self.near = near
295
  self.far = far
296
  assert isinstance(ssaa, int) and ssaa >= 1
297
  self.ssaa = ssaa
298
  self.texture_filter = texture_filter
299
- self.glctx = dr.RasterizeGLContext() if opengl else dr.RasterizeCudaContext()
300
 
301
  def forward(self, meshes, poses, intrinsics, h, w, shading_fun=None,
302
  dilate_edges=0, normal_bg=[0.5, 0.5, 1.0], aa=True, render_vc=False):
 
289
  far=10,
290
  ssaa=1,
291
  texture_filter='linear-mipmap-linear',
292
+ opengl=False,
293
+ device='cuda'):
294
  super().__init__()
295
  self.near = near
296
  self.far = far
297
  assert isinstance(ssaa, int) and ssaa >= 1
298
  self.ssaa = ssaa
299
  self.texture_filter = texture_filter
300
+ self.glctx = dr.RasterizeCudaContext(device=device)
301
 
302
  def forward(self, meshes, poses, intrinsics, h, w, shading_fun=None,
303
  dilate_edges=0, normal_bg=[0.5, 0.5, 1.0], aa=True, render_vc=False):