glenn-jocher
commited on
Commit
·
2377e5f
1
Parent(s):
2b18924
FP16 EMA bug fix
Browse files- utils/torch_utils.py +2 -2
utils/torch_utils.py
CHANGED
@@ -195,8 +195,8 @@ class ModelEMA:
|
|
195 |
def __init__(self, model, decay=0.9999, updates=0):
|
196 |
# Create EMA
|
197 |
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
|
198 |
-
if next(model.parameters()).device.type != 'cpu':
|
199 |
-
|
200 |
self.updates = updates # number of EMA updates
|
201 |
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
|
202 |
for p in self.ema.parameters():
|
|
|
195 |
def __init__(self, model, decay=0.9999, updates=0):
|
196 |
# Create EMA
|
197 |
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
|
198 |
+
# if next(model.parameters()).device.type != 'cpu':
|
199 |
+
# self.ema.half() # FP16 EMA
|
200 |
self.updates = updates # number of EMA updates
|
201 |
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
|
202 |
for p in self.ema.parameters():
|