File size: 3,122 Bytes
977ddbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import itertools
from typing import List, Optional, Tuple, Union
import safetensors
import torch
from torch import Tensor
import os
from pathlib import Path
from omegaconf import DictConfig, OmegaConf


def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first tensor found on ``parameter``.

    Registered parameters are inspected first, followed by registered
    buffers. If the module has neither, plain tensor attributes stored in
    the ``__dict__`` of the module (and its submodules) are used instead.

    Raises:
        StopIteration: if no tensor of any kind can be found.
    """
    candidates = itertools.chain(parameter.parameters(), parameter.buffers())
    try:
        first_tensor = next(candidates)
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5
        def _plain_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            return [(name, value) for name, value in vars(module).items() if torch.is_tensor(value)]

        members = parameter._named_members(get_members_fn=_plain_tensor_attributes)
        _, fallback_tensor = next(members)
        return fallback_tensor.device
    return first_tensor.device


def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first tensor found on ``parameter``.

    Lookup order:
      1. the first registered parameter,
      2. the first registered buffer,
      3. the first plain tensor attribute stored in the ``__dict__`` of the
         module or one of its submodules.

    Returns:
        The ``torch.dtype`` of the first tensor found, or ``None`` when the
        module holds no tensor at all.
    """
    first_param = next(parameter.parameters(), None)
    if first_param is not None:
        return first_param.dtype

    first_buffer = next(parameter.buffers(), None)
    if first_buffer is not None:
        return first_buffer.dtype

    # For torch.nn.DataParallel compatibility in PyTorch 1.5.
    # The previous implementation guarded this fallback with
    # ``except StopIteration``, which ``tuple(...)`` can never raise, so the
    # fallback was unreachable; it now runs whenever the lookups above find
    # nothing.
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
        return tuples

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    first_tuple = next(gen, None)
    if first_tuple is None:
        # No tensor anywhere: keep returning None as before.
        return None
    return first_tuple[1].dtype


def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
    """Return the directory that contains ``save_path`` as a ``Path``."""
    return Path(save_path).parent

def get_base_name(save_path: Union[str, os.PathLike]) -> str:
    """Return the final path component of ``save_path`` (file name with extension)."""
    return Path(save_path).name

def load_state_dict_from_path(path: Union[str, os.PathLike]):
    """Load a state dict from ``path``.

    Paths containing the substring ``'safetensors'`` are read with
    ``safetensors.torch.load_file``; everything else is read with
    ``torch.load`` mapped onto the CPU.

    Args:
        path: Location of the checkpoint, as a string or ``os.PathLike``.

    Returns:
        The loaded state dict.
    """
    # Normalise to a plain string so the substring test below also works for
    # pathlib.Path objects (``'x' in Path(...)`` raises TypeError).
    path_str = os.fspath(path)
    if 'safetensors' in path_str:
        # ``import safetensors`` alone does not load the ``torch`` submodule,
        # so import the loader explicitly.
        from safetensors.torch import load_file

        state_dict = load_file(path_str)
    else:
        # NOTE(security): torch.load relies on pickle; only use it on
        # checkpoints from trusted sources.
        state_dict = torch.load(path_str, map_location="cpu")
    return state_dict

def replace_extension(path, new_extension):
    """Return ``path`` with its extension swapped for ``new_extension``.

    ``new_extension`` may be given with or without a leading dot.
    """
    suffix = new_extension if new_extension.startswith('.') else '.' + new_extension
    stem, _old_suffix = os.path.splitext(path)
    return stem + suffix

def make_config_path(save_path):
    """Return the ``.yaml`` config path that sits alongside ``save_path``."""
    return replace_extension(save_path, '.yaml')

def save_config(config, config_path):
    """Write ``config`` to ``config_path`` via OmegaConf.

    ``config`` must be a plain ``dict`` or a ``DictConfig``; a plain dict is
    converted with ``OmegaConf.create`` before saving. The parent directory
    of ``config_path`` is created if it does not exist.
    """
    assert isinstance(config, (dict, DictConfig))
    target_dir = get_parent_directory(config_path)
    os.makedirs(target_dir, exist_ok=True)
    conf = OmegaConf.create(config) if isinstance(config, dict) else config
    OmegaConf.save(conf, config_path)


def save_state_dict_and_config(state_dict, config, save_path):
    """Save a state dict to ``save_path`` and its config next to it.

    The config is written to ``save_path`` with its extension replaced by
    ``.yaml``. The weights are written with ``safetensors.torch.save_file``
    when ``save_path`` contains the substring ``'safetensors'`` and with
    ``torch.save`` otherwise.

    Args:
        state_dict: Mapping of tensors to persist.
        config: ``dict`` or ``DictConfig`` describing the model.
        save_path: Destination of the weights, as a string or ``os.PathLike``.
    """
    # Normalise to a plain string so the substring test below also works for
    # pathlib.Path objects (``'x' in Path(...)`` raises TypeError).
    save_path = os.fspath(save_path)
    os.makedirs(get_parent_directory(save_path), exist_ok=True)

    # save config dict
    config_path = make_config_path(save_path)
    save_config(config, config_path)

    # Save the model
    if 'safetensors' in save_path:
        # ``import safetensors`` alone does not load the ``torch`` submodule,
        # so import the writer explicitly.
        from safetensors.torch import save_file

        save_file(state_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, save_path)