File size: 3,556 Bytes
a03c9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
""" op.py """
import math
from packaging.version import parse as VersionParse

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm


def get_layer_norm(dim: int, layer_norm_type: str = "layer_norm", layer_norm_eps: float = 1e-5):
    """Build the normalization module selected by ``layer_norm_type``.

    Args:
        dim (int): Size of the feature dimension to normalize over.
        layer_norm_type (str): "rms_norm" selects RMSNorm; every other value
            (including the default "layer_norm") selects ``nn.LayerNorm``.
        layer_norm_eps (float): Epsilon forwarded to the chosen layer.

    Returns:
        nn.Module: The constructed normalization layer.
    """
    wants_rms = layer_norm_type == "rms_norm"
    if not wants_rms:
        # Anything that is not explicitly "rms_norm" falls back to LayerNorm.
        return nn.LayerNorm(normalized_shape=dim, eps=layer_norm_eps)

    # RMSNorm is this file's alias for T5LayerNorm, which is equivalent to
    # RMSNorm. https://arxiv.org/abs/1910.07467
    return RMSNorm(hidden_size=dim, eps=layer_norm_eps)


def check_all_elements_equal(x: torch.Tensor) -> bool:
    """Return True when every entry along dim 0 equals the first entry.

    For a 1-D tensor this checks that all elements are identical. For
    higher-rank tensors ``x[0]`` is broadcast against ``x``, so the check is
    that every slice along dim 0 matches the first slice.

    Args:
        x (torch.Tensor): Tensor to inspect.

    Returns:
        bool: True if all entries match the first one. Empty tensors and
            0-dim (scalar) tensors are treated as trivially equal.
    """
    # x[0] is not indexable for scalars or for tensors with no rows; there is
    # nothing to disagree with in either case.
    if x.dim() == 0 or x.numel() == 0:
        return True
    return bool(x.eq(x[0]).all().item())


def minmax_normalize(x: torch.Tensor, eps: float = 0.008) -> torch.FloatTensor:
    """Min-max normalization, computed independently for each batch item:

    x_norm = (x - x_min) / (x_max - x_min + eps)

    ``x_min`` and ``x_max`` are taken over all non-batch dimensions.

    Args:
        x (torch.Tensor): (B, T, F). Any input of shape (B, ...) with at
            least one non-batch dimension is also accepted.
        eps (float): Added to the denominator to avoid division by zero
            when an item is constant.
    Returns:
        torch.Tensor: Same shape as ``x``. The minimum of each item maps to 0
            and the maximum maps to (x_max - x_min) / (x_max - x_min + eps),
            which is slightly below 1 when eps > 0.
    """
    # Collapse every non-batch dimension so one reduction covers the whole item.
    flat = x.flatten(start_dim=1)  # (B, N)

    # Shape (B, 1, ..., 1) so the statistics broadcast against x.
    stat_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
    x_max = flat.max(dim=1)[0].reshape(stat_shape)
    x_min = flat.min(dim=1)[0].reshape(stat_shape)
    return (x - x_min) / (x_max - x_min + eps)


def count_parameters(model):
    """Count the parameters of ``model``.

    Args:
        model: Object exposing ``parameters()``, e.g. an ``nn.Module``.

    Returns:
        tuple: ``(num_trainable_params, num_params)`` where the first entry
            counts only parameters with ``requires_grad`` set and the second
            counts all parameters.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total


def adjust_b_to_gcd(a, b, min_gcd=16):
    """
    Adjust the value of b to ensure the GCD(a, b) is at least min_gcd with minimum change to b.

    The search moves outward from b one step at a time. At equal distance the
    larger candidate is preferred. Downward candidates are only considered
    while they stay positive: math.gcd(a, 0) equals a, so without this guard
    the search could return 0 as the "adjusted" value.

    Parameters:
    - a (int): A positive integer
    - b (int): A positive integer
    - min_gcd (int): The minimum desired GCD

    Returns:
    - int: The adjusted value of b (b itself if GCD(a, b) already satisfies min_gcd)

    Raises:
    - ValueError: If GCD(a, b) < min_gcd and a < min_gcd, since no b can then
      reach the requested GCD.
    """
    # If current GCD is already greater than or equal to min_gcd, return b as it is.
    if math.gcd(a, b) >= min_gcd:
        return b

    # If a is less than min_gcd, then it's impossible to get a GCD of at least min_gcd.
    if a < min_gcd:
        raise ValueError("a must be at least as large as min_gcd.")

    # Widen the search radius until a candidate qualifies. The upward search
    # always terminates because any multiple of a has GCD a >= min_gcd.
    step = 0
    while True:
        step += 1

        candidate_up = b + step
        if math.gcd(a, candidate_up) >= min_gcd:
            return candidate_up

        candidate_down = b - step
        if candidate_down > 0 and math.gcd(a, candidate_down) >= min_gcd:
            return candidate_down


def optional_compiler_disable(func):
    """Apply ``torch.compiler.disable`` to ``func`` when this torch build has it.

    Availability is detected by looking up the attribute rather than by
    comparing version strings: pre-release or vendor builds (e.g. a version
    like "2.1.0a0+...") sort below "2.1" and would otherwise skip the
    decorator even though the API exists.

    Args:
        func (Callable): Function to exclude from compilation.

    Returns:
        Callable: ``torch.compiler.disable(func)`` if available, otherwise
            ``func`` unchanged.
    """
    compiler_module = getattr(torch, "compiler", None)
    disable = getattr(compiler_module, "disable", None)
    if disable is None:
        # Older torch without torch.compiler.disable: return the original function.
        return func
    return disable(func)


def optional_compiler_dynamic(func):
    """Wrap ``func`` with ``torch.compile(..., dynamic=True)`` when available.

    On torch builds that do not provide ``torch.compile`` the function is
    returned unchanged, so this helper stays optional in the same way as
    ``torch.compiler.disable`` is optional in this module.

    Args:
        func (Callable): Function to compile.

    Returns:
        Callable: The compiled wrapper, or ``func`` itself if ``torch.compile``
            does not exist.
    """
    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        return func
    return compile_fn(func, dynamic=True)