candlend commited on
Commit
a78de09
1 Parent(s): ae5cb64

consume_prefix_in_state_dict_if_present

Browse files
Files changed (3) hide show
  1. .gitignore +3 -1
  2. out_temp.wav +0 -0
  3. sovits/hubert_model.py +36 -1
.gitignore CHANGED
@@ -157,4 +157,6 @@ cython_debug/
157
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
- #.idea/
 
 
 
157
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
  # and can be added to the global gitignore or merged into this file. For a more nuclear
159
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ out_temp.wav
out_temp.wav DELETED
Binary file (236 kB)
 
sovits/hubert_model.py CHANGED
@@ -5,7 +5,42 @@ from typing import Optional, Tuple
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as t_func
8
- from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  class Hubert(nn.Module):
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as t_func
8
+ # from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
9
+ from typing import List, Dict, Any
10
+
11
+ def consume_prefix_in_state_dict_if_present(
12
+ state_dict: Dict[str, Any], prefix: str
13
+ ) -> None:
14
+ r"""Strip the prefix in state_dict in place, if any.
15
+
16
+ ..note::
17
+ Given a `state_dict` from a DP/DDP model, a local model can load it by applying
18
+ `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
19
+ :meth:`torch.nn.Module.load_state_dict`.
20
+
21
+ Args:
22
+ state_dict (OrderedDict): a state-dict to be loaded to the model.
23
+ prefix (str): prefix.
24
+ """
25
+ keys = sorted(state_dict.keys())
26
+ for key in keys:
27
+ if key.startswith(prefix):
28
+ newkey = key[len(prefix) :]
29
+ state_dict[newkey] = state_dict.pop(key)
30
+
31
+ # also strip the prefix in metadata if any.
32
+ if "_metadata" in state_dict:
33
+ metadata = state_dict["_metadata"]
34
+ for key in list(metadata.keys()):
35
+ # for the metadata dict, the key can be:
36
+ # '': for the DDP module, which we want to remove.
37
+ # 'module': for the actual model.
38
+ # 'module.xx.xx': for the rest.
39
+
40
+ if len(key) == 0:
41
+ continue
42
+ newkey = key[len(prefix) :]
43
+ metadata[newkey] = metadata.pop(key)
44
 
45
 
46
  class Hubert(nn.Module):