SayanoAI commited on
Commit
b7e26dd
1 Parent(s): 8bbea4d

moved imports to top

Browse files
Files changed (1) hide show
  1. loaders.py +40 -39
loaders.py CHANGED
@@ -1,40 +1,41 @@
1
- import torch
2
- from transformers import HubertModel, HubertConfig
3
-
4
- class HubertModelWithFinalProj(HubertModel):
5
- def __init__(self, config):
6
- super().__init__(config)
7
-
8
- # The final projection layer is only used for backward compatibility.
9
- # Following https://github.com/auspicious3000/contentvec/issues/6
10
- # Remove this layer is necessary to achieve the desired outcome.
11
- self.final_proj = torch.nn.Linear(config.hidden_size, config.classifier_proj_size)
12
-
13
- @staticmethod
14
- def load_safetensors(path: str, device="cpu"):
15
- assert path.endswith(".safetensors"), f"{path} must end with '.safetensors'"
16
- from safetensors import safe_open
17
- import json
18
- with safe_open(path, framework="pt", device="cpu") as f:
19
- metadata = f.metadata()
20
- state_dict = {}
21
- for key in f.keys():
22
- state_dict[key] = f.get_tensor(key)
23
- model = HubertModelWithFinalProj(HubertConfig.from_dict(json.loads(metadata["config"])))
24
- model.load_state_dict(state_dict=state_dict)
25
- return model.to(device)
26
-
27
- def save_safetensors(self, path: str):
28
- assert path.endswith(".safetensors"), f"{path} must end with '.safetensors'"
29
- import safetensors.torch as st
30
- import json
31
- with open(path,"wb") as f:
32
- state_dict = self.state_dict()
33
- f.write(st.save(state_dict,dict(config=json.dumps(self.config.to_dict()))))
34
-
35
- def extract_features(self, source: torch.Tensor, version="v2", **kwargs):
36
- with torch.no_grad():
37
- output_layer = 9 if version == "v1" else 12
38
- output = self(source.to(self.config.torch_dtype), output_hidden_states=True)["hidden_states"][output_layer]
39
- features = self.final_proj(output) if version == "v1" else output
 
40
  return features
 
1
+ import torch
2
+ from transformers import HubertModel, HubertConfig
3
+ import safetensors.torch as st
4
+ from safetensors import safe_open
5
+ import json
6
+
7
+ class HubertModelWithFinalProj(HubertModel):
8
+ def __init__(self, config):
9
+ super().__init__(config)
10
+
11
+ # The final projection layer is only used for backward compatibility.
12
+ # Following https://github.com/auspicious3000/contentvec/issues/6
13
+ # Remove this layer is necessary to achieve the desired outcome.
14
+ self.final_proj = torch.nn.Linear(config.hidden_size, config.classifier_proj_size)
15
+
16
+ @staticmethod
17
+ def load_safetensors(path: str, device="cpu"):
18
+ assert path.endswith(".safetensors"), f"{path} must end with '.safetensors'"
19
+
20
+ with safe_open(path, framework="pt", device="cpu") as f:
21
+ metadata = f.metadata()
22
+ state_dict = {}
23
+ for key in f.keys():
24
+ state_dict[key] = f.get_tensor(key)
25
+ model = HubertModelWithFinalProj(HubertConfig.from_dict(json.loads(metadata["config"])))
26
+ model.load_state_dict(state_dict=state_dict)
27
+ return model.to(device)
28
+
29
+ def save_safetensors(self, path: str):
30
+ assert path.endswith(".safetensors"), f"{path} must end with '.safetensors'"
31
+
32
+ with open(path,"wb") as f:
33
+ state_dict = self.state_dict()
34
+ f.write(st.save(state_dict,dict(config=json.dumps(self.config.to_dict()))))
35
+
36
+ def extract_features(self, source: torch.Tensor, version="v2", **kwargs):
37
+ with torch.no_grad():
38
+ output_layer = 9 if version == "v1" else 12
39
+ output = self(source.to(self.config.torch_dtype), output_hidden_states=True)["hidden_states"][output_layer]
40
+ features = self.final_proj(output) if version == "v1" else output
41
  return features