import torch.nn as nn


def get_activation_layer(act_type):
    """Return a zero-argument factory for the requested activation layer.

    Args:
        act_type (str): the activation type; one of "gelu", "gelu_tanh",
            "relu", or "silu".

    Returns:
        Callable[[], nn.Module]: a factory that, when called, constructs a
        fresh activation layer instance.

    Raises:
        ValueError: if ``act_type`` is not one of the supported names.
    """
    if act_type == "gelu":
        # Plain class reference is an equivalent zero-arg factory; keeps
        # this branch consistent with the "relu"/"silu" branches below.
        return nn.GELU
    elif act_type == "gelu_tanh":
        # `approximate="tanh"` requires torch >= 1.13, so a lambda is
        # needed here to bind the keyword argument.
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")