| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						import comfy.model_management | 
					
					
						
						| 
							 | 
						import numbers | 
					
					
						
						| 
							 | 
						import logging | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						RMSNorm = None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						try: | 
					
					
						
						| 
							 | 
						    rms_norm_torch = torch.nn.functional.rms_norm | 
					
					
						
						| 
							 | 
						    RMSNorm = torch.nn.RMSNorm | 
					
					
						
						| 
							 | 
						except: | 
					
					
						
						| 
							 | 
						    rms_norm_torch = None | 
					
					
						
						| 
							 | 
						    logging.warning("Please update pytorch to use native RMSNorm") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def rms_norm(x, weight=None, eps=1e-6): | 
					
					
						
						| 
							 | 
						    if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): | 
					
					
						
						| 
							 | 
						        if weight is None: | 
					
					
						
						| 
							 | 
						            return rms_norm_torch(x, (x.shape[-1],), eps=eps) | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) | 
					
					
						
						| 
							 | 
						        if weight is None: | 
					
					
						
						| 
							 | 
						            return r | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if RMSNorm is None: | 
					
					
						
						| 
							 | 
						    class RMSNorm(torch.nn.Module): | 
					
					
						
						| 
							 | 
						        def __init__( | 
					
					
						
						| 
							 | 
						            self, | 
					
					
						
						| 
							 | 
						            normalized_shape, | 
					
					
						
						| 
							 | 
						            eps=1e-6, | 
					
					
						
						| 
							 | 
						            elementwise_affine=True, | 
					
					
						
						| 
							 | 
						            device=None, | 
					
					
						
						| 
							 | 
						            dtype=None, | 
					
					
						
						| 
							 | 
						        ): | 
					
					
						
						| 
							 | 
						            factory_kwargs = {"device": device, "dtype": dtype} | 
					
					
						
						| 
							 | 
						            super().__init__() | 
					
					
						
						| 
							 | 
						            if isinstance(normalized_shape, numbers.Integral): | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                normalized_shape = (normalized_shape,)   | 
					
					
						
						| 
							 | 
						            self.normalized_shape = tuple(normalized_shape)   | 
					
					
						
						| 
							 | 
						            self.eps = eps | 
					
					
						
						| 
							 | 
						            self.elementwise_affine = elementwise_affine | 
					
					
						
						| 
							 | 
						            if self.elementwise_affine: | 
					
					
						
						| 
							 | 
						                self.weight = torch.nn.Parameter( | 
					
					
						
						| 
							 | 
						                    torch.empty(self.normalized_shape, **factory_kwargs) | 
					
					
						
						| 
							 | 
						                ) | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                self.register_parameter("weight", None) | 
					
					
						
						| 
							 | 
						            self.bias = None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        def forward(self, x): | 
					
					
						
						| 
							 | 
						            return rms_norm(x, self.weight, self.eps) | 
					
					
						
						| 
							 | 
						
 |