File size: 3,012 Bytes
83d190a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from pathlib import Path
from typing import Any, Optional, Union

import torch
from safetensors import safe_open
from safetensors.torch import save_file

from style_bert_vits2.logging import logger


def load_safetensors(
    checkpoint_path: Union[str, Path],
    model: torch.nn.Module,
    for_infer: bool = False,
) -> tuple[torch.nn.Module, Optional[int]]:
    """
    指定されたパスから safetensors モデルを読み込み、モデルとイテレーションを返す。

    Args:
        checkpoint_path (Union[str, Path]): モデルのチェックポイントファイルのパス
        model (torch.nn.Module): 読み込む対象のモデル
        for_infer (bool): 推論用に読み込むかどうかのフラグ

    Returns:
        tuple[torch.nn.Module, Optional[int]]: 読み込まれたモデルとイテレーション回数(存在する場合)
    """

    tensors: dict[str, Any] = {}
    iteration: Optional[int] = None
    with safe_open(str(checkpoint_path), framework="pt", device="cpu") as f:  # type: ignore
        for key in f.keys():
            if key == "iteration":
                iteration = f.get_tensor(key).item()
            tensors[key] = f.get_tensor(key)
    if hasattr(model, "module"):
        result = model.module.load_state_dict(tensors, strict=False)
    else:
        result = model.load_state_dict(tensors, strict=False)
    for key in result.missing_keys:
        if key.startswith("enc_q") and for_infer:
            continue
        logger.warning(f"Missing key: {key}")
    for key in result.unexpected_keys:
        if key == "iteration":
            continue
        logger.warning(f"Unexpected key: {key}")
    if iteration is None:
        logger.info(f"Loaded '{checkpoint_path}'")
    else:
        logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})")

    return model, iteration


def save_safetensors(
    model: torch.nn.Module,
    iteration: int,
    checkpoint_path: Union[str, Path],
    is_half: bool = False,
    for_infer: bool = False,
) -> None:
    """
    モデルを safetensors 形式で保存する。

    Args:
        model (torch.nn.Module): 保存するモデル
        iteration (int): イテレーション回数
        checkpoint_path (Union[str, Path]): 保存先のパス
        is_half (bool): モデルを半精度で保存するかどうかのフラグ
        for_infer (bool): 推論用に保存するかどうかのフラグ
    """

    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    keys = []
    for k in state_dict:
        if "enc_q" in k and for_infer:
            continue
        keys.append(k)

    new_dict = (
        {k: state_dict[k].half() for k in keys}
        if is_half
        else {k: state_dict[k] for k in keys}
    )
    new_dict["iteration"] = torch.LongTensor([iteration])
    logger.info(f"Saved safetensors to {checkpoint_path}")

    save_file(new_dict, checkpoint_path)