PommesPeter commited on
Commit
8236505
1 Parent(s): eeb2ff8

Update transport/integrators.py

Browse files
Files changed (1) hide show
  1. transport/integrators.py +6 -4
transport/integrators.py CHANGED
@@ -99,10 +99,12 @@ class ode:
99
 
100
  self.drift = drift
101
  self.t = th.linspace(t0, t1, num_steps)
102
- if time_shifting_factor:
103
- self.t = self.t / (
104
- self.t + time_shifting_factor - time_shifting_factor * self.t
105
- )
 
 
106
  self.atol = atol
107
  self.rtol = rtol
108
  self.sampler_type = sampler_type
 
99
 
100
  self.drift = drift
101
  self.t = th.linspace(t0, t1, num_steps)
102
+ if time_shifting_factor == 0:
103
+ t_1 = 1 / (1 + th.exp(-6 * (self.t - 0.6)))
104
+ t_2 = 1 - 1 / (1 + th.exp(20 * (self.t - 0.6)))
105
+ self.t = th.where(self.t < 0.6, t_1, t_2)
106
+ else:
107
+ self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t)
108
  self.atol = atol
109
  self.rtol = rtol
110
  self.sampler_type = sampler_type