# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import deprecate
from diffusers.utils.import_utils import is_torch_npu_available


if is_torch_npu_available():
    import torch_npu

# Registry of prototype activation modules, keyed by lower-case name.
# `get_activation` hands out copies of these prototypes, never the prototypes themselves.
ACTIVATION_FUNCTIONS = {
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "mish": nn.Mish(),
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    The lookup is case-insensitive. Every call returns a new, independent module instance (a deep copy of the
    prototype registered in `ACTIVATION_FUNCTIONS`), so hooks or attributes attached to the returned module do
    not leak into other callers.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.

    Raises:
        ValueError: If `act_fn` is not a key of `ACTIVATION_FUNCTIONS`.
    """
    act_fn = act_fn.lower()
    if act_fn not in ACTIVATION_FUNCTIONS:
        raise ValueError(f"Unsupported activation function: {act_fn}")
    # Copy instead of returning the registry entry: a single module object shared by several parent modules
    # would share forward hooks and be de-duplicated when iterating over `named_modules()`.
    return copy.deepcopy(ACTIVATION_FUNCTIONS[act_fn])


def _gelu_mps_safe(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    """Apply `F.gelu` to `gate`, computing in float32 when the tensor lives on an MPS device.

    Args:
        gate (torch.Tensor): Input tensor.
        approximate (str): Forwarded to `F.gelu` (`"none"` or `"tanh"`).

    Returns:
        torch.Tensor: Activated tensor with the same dtype as `gate`.
    """
    if gate.device.type != "mps":
        return F.gelu(gate, approximate=approximate)
    # mps: gelu is not implemented for float16
    return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)


class FP32SiLU(nn.Module):
    r"""
    SiLU activation function with input upcasted to torch.float32.

    The result is cast back to the dtype of the input before being returned.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.silu(inputs.float(), inplace=False).to(inputs.dtype)


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    The input is first projected with a linear layer and the activation is applied to the projection.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU with this module's `approximate` setting (float32 compute on MPS devices)."""
        return _gelu_mps_safe(gate, approximate=self.approximate)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` with the linear layer, then apply GELU."""
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.

    The linear layer projects to `dim_out * 2` channels; the projection is split in two halves along the last
    dimension and the first half is multiplied by the GELU of the second half.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply exact (non-approximated) GELU to `gate` (float32 compute on MPS devices)."""
        return _gelu_mps_safe(gate)

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Project `hidden_states` and apply the gated GELU.

        Extra positional arguments and a non-`None` `scale` keyword are accepted only for backward
        compatibility: they trigger a deprecation notice and are otherwise ignored.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        hidden_states = self.proj(hidden_states)
        if is_torch_npu_available():
            # using torch_npu.npu_geglu can run faster and save memory on NPU.
            # NOTE(review): `approximate=1` presumably selects an approximated GELU, whereas the fallback
            # branch below uses the exact form - confirm the numerical difference is acceptable.
            return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
        else:
            hidden_states, gate = hidden_states.chunk(2, dim=-1)
            return hidden_states * self.gelu(gate)


class SwiGLU(nn.Module):
    r"""
    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
    but uses SiLU / Swish instead of GeLU.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
        self.activation = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, split the result in two halves and gate the first half with SiLU of the second."""
        hidden_states = self.proj(hidden_states)
        hidden_states, gate = hidden_states.chunk(2, dim=-1)
        return hidden_states * self.activation(gate)


class ApproximateGELU(nn.Module):
    r"""
    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
    [paper](https://arxiv.org/abs/1606.08415).

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` with the linear layer, then apply the sigmoid-based GELU approximation."""
        x = self.proj(x)
        # x * sigmoid(1.702 * x) is the sigmoid approximation of GELU.
        return x * torch.sigmoid(1.702 * x)