mattricesound commited on
Commit
647e1a1
1 Parent(s): 9589cd1

Fix diffusion sampler not working on gpu

Browse files
Files changed (2) hide show
  1. models.py +1 -1
  2. shell_vars.sh +1 -1
models.py CHANGED
@@ -147,7 +147,7 @@ class DiffusionGenerationModel(nn.Module):
147
  return self.model(x)
148
 
149
  def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
150
- noise = torch.randn(x.shape)
151
  return self.model.sample(noise, num_steps=num_steps)
152
 
153
 
 
147
  return self.model(x)
148
 
149
  def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
150
+ noise = torch.randn(x.shape).to(x)
151
  return self.model.sample(noise, num_steps=num_steps)
152
 
153
 
shell_vars.sh CHANGED
@@ -1,3 +1,3 @@
1
- export DATASET_ROOT="/Users/matthewrice/Developer/remfx/data/egfx"
2
  export WANDB_PROJECT="RemFX"
3
  export WANDB_ENTITY="mattricesound"
 
1
+ export DATASET_ROOT="./data/egfx"
2
  export WANDB_PROJECT="RemFX"
3
  export WANDB_ENTITY="mattricesound"