|
from functools import partial |
|
from typing import Any, List |
|
|
|
import bitsandbytes as bnb |
|
from torch import optim |
|
|
|
__all__ = ["Optimizers"] |
|
|
|
|
|
class Optimizers:
    """Optimizers factory.

    Maps an optimizer name to a callable that builds that optimizer. Each
    entry is either an optimizer class or a ``functools.partial`` with some
    keyword arguments pre-bound. In both cases the caller supplies the model
    parameters and any remaining keyword arguments (e.g. ``lr``).
    """

    # Registry of name -> optimizer constructor. Shared at class level, so
    # treat it as read-only.
    _optimizers = {
        "Adam": optim.Adam,
        "AdamW": optim.AdamW,
        # SGD with Nesterov momentum pre-configured. Keyword arguments passed
        # by the caller override these pre-bound values (partial semantics).
        "SGD": partial(optim.SGD, momentum=0.9, nesterov=True),
        # RMSprop with momentum=0.9 and smoothing constant alpha=0.9
        # pre-configured. The caller may override them.
        "RMSprop": partial(optim.RMSprop, momentum=0.9, alpha=0.9),
        "Adadelta": optim.Adadelta,
        # 8-bit AdamW from bitsandbytes. This must be the AdamW variant (not
        # Adam8bit) so the key matches the optimizer actually constructed.
        "AdamW8bit": bnb.optim.AdamW8bit,
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered optimizer names in sorted order."""
        return sorted(cls._optimizers.keys())

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Optimizers.

        Args:
            name: optimizer name, one of ``names()``.

        Returns:
            A callable (class or ``functools.partial``) to build the
            optimizer, or ``None`` if ``name`` is not registered.
        """
        # dict.get returns None for unknown names. Callers that need a hard
        # failure should check the result against None.
        return cls._optimizers.get(name)
|
|