# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import warnings
from collections.abc import Mapping
from contextlib import contextmanager
from typing import Optional, Union

import numpy as np
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available


def flatten_dict(nested: dict, sep: str = "/") -> dict:
    """Flatten a dictionary, concatenating nested keys with the separator."""

    def recurse(nest: dict, prefix: str, into: dict) -> None:
        for k, v in nest.items():
            if sep in k:
                raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
            if isinstance(v, Mapping):
                recurse(v, prefix + k + sep, into)
            else:
                into[prefix + k] = v

    flat = {}
    recurse(nested, "", flat)
    return flat
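
# A minimal doctest-style sketch (hypothetical inputs) of the flattening:
#     >>> flatten_dict({"a": {"b": 1}, "c": 2})
#     {'a/b': 1, 'c': 2}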


def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Compute the mean of a tensor over the entries selected by `mask`."""
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()
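
# A minimal sketch of the masking semantics (hypothetical inputs):
#     >>> values = torch.tensor([1.0, 2.0, 3.0, 4.0])
#     >>> mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
#     >>> masked_mean(values, mask)  # only the first two entries count
#     tensor(1.5000)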


def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Compute the variance of a tensor over the entries selected by `mask`."""
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            raise ValueError(
                "The sum of the mask is zero, which can happen when `mini_batch_size=1`; "
                "try increasing the `mini_batch_size` or `gradient_accumulation_steps`."
            )
        # If mask_sum == 1, the Bessel correction below divides by zero;
        # avoid this by using a larger mini-batch size.
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance
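
# Continuing the sketch above: with two masked-in entries, the Bessel
# correction scales the biased variance 0.25 by 2 / (2 - 1):
#     >>> masked_var(values, mask)  # sample variance of [1.0, 2.0]
#     tensor(0.5000)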


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Whiten `values` using the mean and variance computed over the masked entries."""
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened
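
# Sketch (same hypothetical inputs): the masked-in entries are normalized to
# zero mean and unit (biased) variance; masked-out entries are transformed too.
#     >>> masked_whiten(values, mask)  # approximately tensor([-1., 1., 3., 5.])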


class LengthSampler:
    """Samples a length uniformly from the range `[min_value, max_value)`."""

    def __init__(self, min_value: int, max_value: int):
        self.values = list(range(min_value, max_value))

    def __call__(self) -> int:
        return int(np.random.choice(self.values))
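
# Usage sketch (bounds are hypothetical): each call draws a fresh length.
#     >>> sampler = LengthSampler(min_value=4, max_value=10)
#     >>> sampler()  # a uniform random int in [4, 10)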


class PPODecorators:
    optimize_device_cache = False

    @classmethod
    @contextmanager
    def empty_device_cache(cls):
        yield
        if cls.optimize_device_cache:
            if is_torch_xpu_available():
                gc.collect()
                torch.xpu.empty_cache()
                gc.collect()
            elif is_torch_npu_available():
                gc.collect()
                torch.npu.empty_cache()
                gc.collect()
            elif torch.cuda.is_available():
                gc.collect()
                torch.cuda.empty_cache()
                gc.collect()
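
# Usage sketch (assuming a hypothetical `train_step`): the device cache is
# only cleared after the block exits, and only if `optimize_device_cache` is on.
#     >>> PPODecorators.optimize_device_cache = True
#     >>> with PPODecorators.empty_device_cache():
#     ...     train_step()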


def randn_tensor(
    shape: Union[tuple, list],
    generator: Optional[Union[list[torch.Generator], torch.Generator]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
) -> torch.Tensor:
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch item individually. If CPU generators are passed, the
    tensor is always created on the CPU and then moved to `device`.
    """
    # The device on which the tensor is created defaults to the requested `device`.
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device.type != "mps":
                warnings.warn(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device.",
                    UserWarning,
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # Make sure a generator list of length 1 is treated like a single generator.
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # Sample each batch item separately so each generator seeds one item.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
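
# Usage sketch (shapes and seeds are hypothetical): a single CPU generator
# yields one reproducible tensor; a list of generators seeds each batch item.
#     >>> g = torch.Generator().manual_seed(0)
#     >>> noise = randn_tensor((2, 4), generator=g)
#     >>> gens = [torch.Generator().manual_seed(i) for i in range(2)]
#     >>> per_item = randn_tensor((2, 4), generator=gens)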