Spaces:
Runtime error
Runtime error
candlend
commited on
Commit
•
a78de09
1
Parent(s):
ae5cb64
consume_prefix_in_state_dict_if_present
Browse files- .gitignore +3 -1
- out_temp.wav +0 -0
- sovits/hubert_model.py +36 -1
.gitignore
CHANGED
@@ -157,4 +157,6 @@ cython_debug/
|
|
157 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
-
#.idea/
|
|
|
|
|
|
157 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
161 |
+
|
162 |
+
out_temp.wav
|
out_temp.wav
DELETED
Binary file (236 kB)
|
|
sovits/hubert_model.py
CHANGED
@@ -5,7 +5,42 @@ from typing import Optional, Tuple
|
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as t_func
|
8 |
-
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
|
11 |
class Hubert(nn.Module):
|
|
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as t_func
|
8 |
+
# from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
9 |
+
from typing import List, Dict, Any
|
10 |
+
|
11 |
+
def consume_prefix_in_state_dict_if_present(
|
12 |
+
state_dict: Dict[str, Any], prefix: str
|
13 |
+
) -> None:
|
14 |
+
r"""Strip the prefix in state_dict in place, if any.
|
15 |
+
|
16 |
+
..note::
|
17 |
+
Given a `state_dict` from a DP/DDP model, a local model can load it by applying
|
18 |
+
`consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
|
19 |
+
:meth:`torch.nn.Module.load_state_dict`.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
state_dict (OrderedDict): a state-dict to be loaded to the model.
|
23 |
+
prefix (str): prefix.
|
24 |
+
"""
|
25 |
+
keys = sorted(state_dict.keys())
|
26 |
+
for key in keys:
|
27 |
+
if key.startswith(prefix):
|
28 |
+
newkey = key[len(prefix) :]
|
29 |
+
state_dict[newkey] = state_dict.pop(key)
|
30 |
+
|
31 |
+
# also strip the prefix in metadata if any.
|
32 |
+
if "_metadata" in state_dict:
|
33 |
+
metadata = state_dict["_metadata"]
|
34 |
+
for key in list(metadata.keys()):
|
35 |
+
# for the metadata dict, the key can be:
|
36 |
+
# '': for the DDP module, which we want to remove.
|
37 |
+
# 'module': for the actual model.
|
38 |
+
# 'module.xx.xx': for the rest.
|
39 |
+
|
40 |
+
if len(key) == 0:
|
41 |
+
continue
|
42 |
+
newkey = key[len(prefix) :]
|
43 |
+
metadata[newkey] = metadata.pop(key)
|
44 |
|
45 |
|
46 |
class Hubert(nn.Module):
|